sanchit-gandhi (HF staff) committed
Commit 9cd4306
Parent: 298635d

Saving train state of step 5000

.gitignore ADDED
@@ -0,0 +1 @@
+ wandb
accelerate_config.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
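
This Accelerate config describes an 8-process, single-node bf16 run and pairs with config_mistral_fineweb.yaml further down. As a minimal launch sketch (not part of the commit; it assumes accelerate is installed and that accelerate_config.yaml, config_mistral_fineweb.yaml and run_distillation.py sit in the working directory):

    # Hypothetical launch command, expressed via subprocess for illustration only.
    import subprocess

    subprocess.run(
        [
            "accelerate", "launch",
            "--config_file", "accelerate_config.yaml",  # 8 processes, bf16 mixed precision
            "run_distillation.py",
            "config_mistral_fineweb.yaml",              # parsed by HfArgumentParser.parse_yaml_file
        ],
        check=True,
    )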
checkpoint-5000-epoch-0/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-v0.1",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.40.2",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
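
The student config above is a shallow Mistral variant: 6 hidden layers at the full 4096 hidden size, versus 32 layers in the Mistral-7B-v0.1 teacher. A small sketch for inspecting the saved checkpoint config locally (assumes transformers is installed and the checkpoint folder has been downloaded):

    # Sketch only: read back the student architecture stored at step 5000.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("checkpoint-5000-epoch-0")
    print(cfg.model_type, cfg.num_hidden_layers, cfg.hidden_size)
    # -> mistral 6 4096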
checkpoint-5000-epoch-0/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 4096,
+ "transformers_version": "4.40.2"
+ }
checkpoint-5000-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:882041588b1bfea594d47a35eb70c8f704275451a600f901c670bb8e3f48393b
+ size 3141646744
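
This entry, like the other binaries below, is a Git LFS pointer rather than the weights themselves: the diff records only the SHA-256 of the object and its size (about 3.1 GB here). A hedged sketch of pulling the real file through huggingface_hub; the repo id is an assumption for illustration, since it is not stated in this diff:

    # Sketch only: resolve an LFS pointer to the actual file via the Hub.
    from huggingface_hub import hf_hub_download

    REPO_ID = "sanchit-gandhi/distil-mistral"  # hypothetical repo id
    path = hf_hub_download(
        repo_id=REPO_ID,
        filename="checkpoint-5000-epoch-0/model.safetensors",
    )
    print(path)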
checkpoint-5000-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa681ff78f80f5898262cfc263e375064613fc40aa8b147cc1a5423ee5661da1
+ size 4450837792
checkpoint-5000-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f88645f4694690415f71a8269aee4cecf3a2917f3de73a644eb5a8d5b931831f
+ size 6283329590
checkpoint-5000-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aca503cc09e63ca033e29a437a20cc580a9c1db27fef2174e533f58ba275879
+ size 16100
checkpoint-5000-epoch-0/random_states_1.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31831c2134536b1e81ba1e763e72b2ff98a14a83774fcfb30d153a66dca7879c
+ size 16100
checkpoint-5000-epoch-0/random_states_2.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a628258539b4090ce50e9faf5fda4d613f523ca957f3e837c02d316e4b20122
+ size 16100
checkpoint-5000-epoch-0/random_states_3.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d594aa54f68e8eb41c3deb9753bf43474028f44edb92db1930ebdf967f708a7c
+ size 16100
checkpoint-5000-epoch-0/random_states_4.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28ca4240374ff4b93ad0537aca2f28bfc293153a29ee8069cf09d088ca30fee7
+ size 16100
checkpoint-5000-epoch-0/random_states_5.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d6f3577977e8c32eac49b1c5136c6718fcd9c66051b703ba6e305cca03a8fb0
+ size 16100
checkpoint-5000-epoch-0/random_states_6.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0ef1d86e60e6cedda41454cd08e0b3652ab6a6eb017b4eed0d6b84866ed7d46
+ size 16100
checkpoint-5000-epoch-0/random_states_7.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d860c07ef8d57c8162394106fcd87c34e7924d859b28b4b292e9e792a96af2
+ size 16100
checkpoint-5000-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c25f7255aa53945ccffbdb6904da689924024cb2e693a6c6739ade9fae0454a2
+ size 1064
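
Taken together, the model*.safetensors, optimizer.bin, scheduler.bin and eight random_states_*.pkl files (one per process) form an Accelerate training state, matching the commit message "Saving train state of step 5000". A hedged sketch of how such a state is typically restored, assuming an Accelerator set up as in run_distillation.py below, with the model, optimizer, scheduler and dataloaders already wrapped by accelerator.prepare:

    # Sketch only: resume from the saved training state.
    from accelerate import Accelerator

    accelerator = Accelerator()
    # ... build and accelerator.prepare(model, optimizer, lr_scheduler, train_dataloader) ...
    accelerator.load_state("checkpoint-5000-epoch-0")  # restores weights, optimizer, scheduler and RNG states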
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-v0.1",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.2",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
config_mistral_fineweb.yaml ADDED
@@ -0,0 +1,38 @@
+ # Model arguments
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-v0.1
+ teacher_model_name_or_path: mistralai/Mistral-7B-v0.1
+ dtype: bfloat16
+ load_teacher_in_4bit: true
+
+ # Data arguments
+ train_dataset_name: HuggingFaceFW/fineweb-edu
+ train_dataset_config_name: sample-100BT
+ train_split_name: train[2000:]
+ eval_split_name: train[:2000]
+ max_steps: 200000
+ max_train_samples: 15000000
+
+ # Training arguments
+ do_train: true
+ do_eval: true
+ per_device_eval_batch_size: 8
+ per_device_train_batch_size: 8
+ max_label_length: 4096
+ learning_rate: 0.0001
+ warmup_steps: 500
+ gradient_checkpointing: true
+ dataloader_num_workers: 4
+ preprocessing_num_workers: 32
+ ddp_timeout: 7200
+ save_strategy: steps
+ save_steps: 5000
+ evaluation_strategy: steps
+ eval_steps: 5000
+ logging_steps: 25
+ output_router_logits: true
+ report_to: all
+ output_dir: ./
+ overwrite_output_dir: false
+ save_total_limit: 1
+ wandb_project: distil-mistral
+ push_to_hub: true
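
One consistency check worth noting: with num_processes: 8 from accelerate_config.yaml, per_device_train_batch_size: 8 here and gradient_accumulation_steps: 1 (see the hparams.yml files below), the effective global batch size is 8 × 8 = 64, which matches the logged global_batch_size. A minimal sketch, assuming PyYAML is available and both config files are local:

    # Sketch only: recompute the effective global batch size from the committed configs.
    import yaml

    with open("accelerate_config.yaml") as f:
        accel = yaml.safe_load(f)
    with open("config_mistral_fineweb.yaml") as f:
        train = yaml.safe_load(f)

    print(accel["num_processes"] * train["per_device_train_batch_size"])  # 8 * 8 = 64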
distil-mistral/1717437256.0075724/events.out.tfevents.1717437256.ip-26-0-160-216.2067815.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d138f754d8ba4e0b3d958d54ed57eb416143d6a6f4e9530b9b47a29b25d393c
+ size 1142
distil-mistral/1717437256.0108476/hparams.yml ADDED
@@ -0,0 +1,16 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-v0.1
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-v0.1
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
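
Note that lr_scheduler_type is serialised with a Python-specific YAML tag (!!python/object/apply:transformers.trainer_utils.SchedulerType applied to "linear"), which yaml.safe_load refuses. A hedged sketch of reading the file back, assuming transformers is importable; unsafe_load executes the tag, so only use it on files you trust:

    # Sketch only: the python/object/apply tag reconstructs SchedulerType("linear").
    import yaml

    with open("distil-mistral/1717437256.0108476/hparams.yml") as f:
        hparams = yaml.unsafe_load(f)

    print(hparams["lr_scheduler_type"])  # SchedulerType.LINEAR
    print(hparams["global_batch_size"])  # 64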
distil-mistral/1717437539.7803326/events.out.tfevents.1717437539.ip-26-0-160-216.2070693.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d298a6613fe6720c7c70048d24997d478a79aa706a9015e2d457f5416503529b
+ size 1142
distil-mistral/1717437539.7839332/hparams.yml ADDED
@@ -0,0 +1,16 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-v0.1
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-v0.1
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
distil-mistral/events.out.tfevents.1717437245.ip-26-0-160-216.2067815.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a8fd86cab31aca953565dd7fc58ba0a95ac7a7f5d6974533531a59efe9be3bf
+ size 88
distil-mistral/events.out.tfevents.1717437532.ip-26-0-160-216.2070693.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf4c81caf7767182b70095b3ce42bd03955f71e626d0a86e4d5296925983ff6b
+ size 62058
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 4096,
+ "transformers_version": "4.40.2"
+ }
run_distillation.py ADDED
@@ -0,0 +1,1549 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training language models for conditional language modelling tasks via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import shutil
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from functools import partial
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ Dataset,
42
+ DatasetDict,
43
+ IterableDataset,
44
+ IterableDatasetDict,
45
+ concatenate_datasets,
46
+ interleave_datasets,
47
+ load_dataset,
48
+ )
49
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
50
+ from peft import LoraConfig, get_peft_model
51
+ from torch.utils.data import DataLoader
52
+ from tqdm import tqdm
53
+ from transformers import (
54
+ AutoConfig,
55
+ AutoModelForCausalLM,
56
+ AutoTokenizer,
57
+ BatchEncoding,
58
+ BitsAndBytesConfig,
59
+ HfArgumentParser,
60
+ PreTrainedTokenizerBase,
61
+ Seq2SeqTrainingArguments,
62
+ get_scheduler,
63
+ set_seed,
64
+ )
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained student model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ cache_dir: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
100
+ )
101
+ use_fast_tokenizer: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
+ )
105
+ model_revision: str = field(
106
+ default="main",
107
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
108
+ )
109
+ subfolder: str = field(
110
+ default="",
111
+ metadata={
112
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
113
+ "specify the folder name here."
114
+ },
115
+ )
116
+ token: str = field(
117
+ default=None,
118
+ metadata={
119
+ "help": (
120
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
121
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
122
+ )
123
+ },
124
+ )
125
+ attn_implementation: Optional[str] = field(
126
+ default=None,
127
+ metadata={
128
+ "help": (
129
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
130
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
131
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
132
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
133
+ )
134
+ },
135
+ )
136
+ load_teacher_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the teacher model."})
137
+ load_teacher_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the teacher model."})
138
+ load_student_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the student model."})
139
+ load_student_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the student model."})
140
+ bnb_4bit_quant_type: Optional[str] = field(
141
+ default="nf4", metadata={"help": "Quantization type if the teacher is quantized (fp4 or nf4)"}
142
+ )
143
+ use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether or not to use nested quantization."})
144
+ lora_r: Optional[int] = field(
145
+ default=16,
146
+ metadata={"help": "LoRA R value."},
147
+ )
148
+ lora_alpha: Optional[int] = field(
149
+ default=32,
150
+ metadata={"help": "LoRA alpha."},
151
+ )
152
+ lora_dropout: Optional[float] = field(
153
+ default=0.05,
154
+ metadata={"help": "LoRA dropout."},
155
+ )
156
+ lora_target_modules: Optional[List[str]] = field(
157
+ default=None,
158
+ metadata={"help": "LoRA target modules."},
159
+ )
160
+ lora_modules_to_save: Optional[List[str]] = field(
161
+ default=None,
162
+ metadata={"help": "Model layers to unfreeze & train"},
163
+ )
164
+ instruction_model: Optional[bool] = field(
165
+ default=None,
166
+ metadata={"help": "Whether or not the pre-trained model is instruction tuned"},
167
+ )
168
+
169
+
170
+ @dataclass
171
+ class DataTrainingArguments:
172
+ """
173
+ Arguments pertaining to what data we are going to input our model for training and eval.
174
+ """
175
+
176
+ train_dataset_name: List[str] = field(
177
+ default=None,
178
+ metadata={
179
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
180
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
181
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
182
+ },
183
+ )
184
+ train_dataset_config_name: Optional[List[str]] = field(
185
+ default=None,
186
+ metadata={
187
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
188
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
189
+ "match the order of the datasets."
190
+ },
191
+ )
192
+ train_dataset_samples: Optional[List[str]] = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
196
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
197
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
198
+ "sample from every dataset is used once per epoch."
199
+ },
200
+ )
201
+ eval_dataset_name: Optional[List[str]] = field(
202
+ default=None,
203
+ metadata={
204
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
205
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
206
+ "ids by a '+' symbol."
207
+ },
208
+ )
209
+ eval_dataset_config_name: Optional[List[str]] = field(
210
+ default=None,
211
+ metadata={
212
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
213
+ "training dataset config name if unspecified."
214
+ },
215
+ )
216
+ dataset_cache_dir: Optional[str] = field(
217
+ default=None,
218
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
219
+ )
220
+ overwrite_cache: bool = field(
221
+ default=False,
222
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
223
+ )
224
+ preprocessing_num_workers: Optional[int] = field(
225
+ default=None,
226
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
227
+ )
228
+ max_train_samples: Optional[int] = field(
229
+ default=None,
230
+ metadata={
231
+ "help": (
232
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
233
+ )
234
+ },
235
+ )
236
+ max_eval_samples: Optional[int] = field(
237
+ default=None,
238
+ metadata={
239
+ "help": (
240
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
241
+ )
242
+ },
243
+ )
244
+ text_column_name: str = field(
245
+ default=None,
246
+ metadata={"help": "The name of the dataset column containing the generated text data in the training set."},
247
+ )
248
+ prompt_column_name: str = field(
249
+ default=None,
250
+ metadata={"help": "The name of the dataset column containing the prompt data. Defaults to 'prompt'"},
251
+ )
252
+ eval_text_column_name: str = field(
253
+ default=None,
254
+ metadata={"help": "The name of the dataset column containing the generated text data in the evaluation set."},
255
+ )
256
+ eval_prompt_column_name: str = field(
257
+ default=None,
258
+ metadata={"help": "The name of the dataset column containing the prompt data in the evaluation set."},
259
+ )
260
+ max_label_length: int = field(
261
+ default=2048,
262
+ metadata={"help": "Truncate target labels that are longer `max_label_length` tokens."},
263
+ )
264
+ pad_target_to_multiple_of: Optional[int] = field(
265
+ default=None,
266
+ metadata={
267
+ "help": (
268
+ "If set will pad the target sequence to a multiple of the provided value. This is important to "
269
+ "avoid triggering recompilations when using torch compile. If unspecified, will default to padding "
270
+ "the targets to max length."
271
+ )
272
+ },
273
+ )
274
+ preprocessing_only: bool = field(
275
+ default=False,
276
+ metadata={
277
+ "help": (
278
+ "Whether to only do data preprocessing and skip training. This is especially useful when data "
279
+ "preprocessing errors out in distributed training due to timeout. In this case, one should run the "
280
+ "preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets "
281
+ "can consequently be loaded in distributed training"
282
+ )
283
+ },
284
+ )
285
+ train_split_name: Optional[List[str]] = field(
286
+ default=lambda: ["train"],
287
+ metadata={
288
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
289
+ },
290
+ )
291
+ eval_split_name: Optional[List[str]] = field(
292
+ default=lambda: ["validation"],
293
+ metadata={
294
+ "help": (
295
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
296
+ )
297
+ },
298
+ )
299
+ streaming: bool = field(
300
+ default=False,
301
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
302
+ )
303
+ wandb_project: str = field(
304
+ default="distil-mistral",
305
+ metadata={"help": "The name of the wandb project."},
306
+ )
307
+
308
+
309
+ @dataclass
310
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
311
+ freeze_lm_head: Optional[bool] = field(
312
+ default=False, metadata={"help": "Whether to freeze the LM head of the student model."}
313
+ )
314
+ temperature: Optional[float] = field(
315
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
316
+ )
317
+ kl_weight: Optional[float] = field(
318
+ default=1.0,
319
+ metadata={
320
+ "help": (
321
+ "Weighting assigned to the KL divergence loss in the KD formulation. The KL loss is "
322
+ "computed between the teacher and student logits."
323
+ )
324
+ },
325
+ )
326
+ output_router_logits: Optional[bool] = field(
327
+ default=False,
328
+ metadata={
329
+ "help": "Whether or not to return the router logits in the forward pass. Enabling this will "
330
+ "also configure the model to compute the auxiliary loss."
331
+ },
332
+ )
333
+ dtype: Optional[str] = field(
334
+ default="float32",
335
+ metadata={
336
+ "help": (
337
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
338
+ "`float16` or `bfloat16` (both half-precision)."
339
+ )
340
+ },
341
+ )
342
+ completions_only: Optional[bool] = field(
343
+ default=False,
344
+ metadata={
345
+ "help": "Whether to train only on the target completions, or the prompt + completions."
346
+ },
347
+ )
348
+
349
+
350
+ @dataclass
351
+ class DataCollatorCausalLMWithPadding:
352
+ """
353
+ Data collator that will dynamically pad the inputs received.
354
+ Args:
355
+ tokenizer ([`PreTrainedTokenizer`])
356
+ The tokenizer used for tokenizing the data.
357
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
358
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
359
+ See above for details.
360
+ max_target_length (:obj:`int`, `optional`):
361
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
362
+ completions_only (:obj:`bool`, `optional`):
363
+ Whether to train on the assistant responses (completions) only, or the combination of prompt + responses.
364
+ """
365
+
366
+ tokenizer: PreTrainedTokenizerBase
367
+ target_padding: Union[bool, str] = "max_length"
368
+ max_target_length: Optional[int] = None
369
+ completions_only: Optional[bool] = False
370
+
371
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> BatchEncoding:
372
+ # dataloader returns a list of features which we convert to a dict
373
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
374
+ label_lengths = [len(feature["labels"]) for feature in features]
375
+ prompt_lengths = [feature["prompt_length"] for feature in features]
376
+
377
+ batch = self.tokenizer.pad(
378
+ label_features,
379
+ max_length=self.max_target_length,
380
+ padding=self.target_padding,
381
+ return_tensors="pt",
382
+ )
383
+
384
+ labels_mask = batch["attention_mask"]
385
+
386
+ if self.completions_only:
387
+ # don't include prompts in loss calculation
388
+ for idx in range(len(prompt_lengths)):
389
+ padding_length = labels_mask.shape[1] - label_lengths[idx]
390
+ labels_mask[idx, padding_length : padding_length + prompt_lengths[idx]] = 0
391
+
392
+ # replace padding with -100 to ignore loss correctly
393
+ labels = batch["input_ids"].masked_fill(labels_mask.ne(1), -100)
394
+
395
+ batch["labels"] = labels
396
+
397
+ return batch
398
+
399
+
400
+ def log_metric(
401
+ accelerator,
402
+ metrics: Dict,
403
+ train_time: float,
404
+ step: int,
405
+ epoch: int,
406
+ learning_rate: float = None,
407
+ prefix: str = "train",
408
+ ):
409
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
410
+ log_metrics = {}
411
+ for k, v in metrics.items():
412
+ log_metrics[f"{prefix}/{k}"] = v
413
+ log_metrics[f"{prefix}/time"] = train_time
414
+ log_metrics[f"{prefix}/epoch"] = epoch
415
+ if learning_rate is not None:
416
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
417
+ accelerator.log(log_metrics, step=step)
418
+
419
+
420
+ def log_pred(
421
+ accelerator,
422
+ pred_str: List[str],
423
+ label_str: List[str],
424
+ step: int,
425
+ epoch: int,
426
+ evaluation_strategy: str,
427
+ prefix: str = "eval",
428
+ num_lines: int = 200000,
429
+ ):
430
+ """Helper function to log target/predicted generations to weights and biases (wandb)."""
431
+ if accelerator.is_main_process:
432
+ wandb_tracker = accelerator.get_tracker("wandb")
433
+ # pretty name for current step: step 50000 -> step 50k
434
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
435
+ prefix_pretty = prefix.replace("/", "-")
436
+
437
+ if evaluation_strategy == "epoch":
438
+ table_name = f"predictions/{prefix_pretty}-epoch-{epoch}"
439
+ else:
440
+ table_name = f"predictions/{prefix_pretty}-step-{cur_step_pretty}"
441
+
442
+ # convert str data to a wandb compatible format
443
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
444
+ # log as a table with the appropriate headers
445
+ wandb_tracker.log_table(
446
+ table_name=table_name,
447
+ columns=["Target", "Pred"],
448
+ data=str_data[:num_lines],
449
+ step=step,
450
+ )
451
+
452
+
453
+ def convert_dataset_str_to_list(
454
+ dataset_names,
455
+ dataset_config_names,
456
+ splits=None,
457
+ text_column_names=None,
458
+ prompt_column_names=None,
459
+ dataset_samples=None,
460
+ default_split="train",
461
+ ) -> List[Dict]:
462
+ """
463
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
464
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
465
+ function returns a list of dictionaries, one for each dataset.
466
+ """
467
+ if isinstance(dataset_names, str):
468
+ dataset_names = [dataset_names]
469
+ splits = [splits] if splits else None
470
+ text_column_names = [text_column_names] if text_column_names else None
471
+ prompt_column_names = [prompt_column_names] if prompt_column_names else None
472
+ if isinstance(dataset_config_names, str):
473
+ dataset_config_names = [dataset_config_names]
474
+
475
+ if len(dataset_names) == 1 and len(dataset_config_names) > 1:
476
+ dataset_names = len(dataset_config_names) * dataset_names
477
+
478
+ if isinstance(splits, list) and len(splits) == 1 and len(dataset_config_names) > 1:
479
+ splits = len(dataset_config_names) * splits
480
+
481
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
482
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
483
+ raise ValueError(
484
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
485
+ f" {len(dataset_config_names)} configs."
486
+ )
487
+
488
+ if splits is not None and len(splits) != len(dataset_names):
489
+ raise ValueError(
490
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
491
+ )
492
+
493
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
494
+ raise ValueError(
495
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
496
+ f" {len(text_column_names)} text column names."
497
+ )
498
+
499
+ if prompt_column_names is not None and len(prompt_column_names) != len(dataset_names):
500
+ raise ValueError(
501
+ f"Ensure one prompt column name is passed for each dataset, got {len(dataset_names)} datasets and"
502
+ f" {len(prompt_column_names)} prompt column names."
503
+ )
504
+
505
+ if dataset_samples is not None:
506
+ if len(dataset_samples) != len(dataset_names):
507
+ raise ValueError(
508
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
509
+ f"{len(dataset_samples)} samples."
510
+ )
511
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
512
+ else:
513
+ dataset_samples = [None] * len(dataset_names)
514
+
515
+ dataset_config_names = (
516
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
517
+ )
518
+ text_column_names = (
519
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
520
+ )
521
+ prompt_column_names = (
522
+ prompt_column_names if prompt_column_names is not None else [None for _ in range(len(dataset_names))]
523
+ )
524
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
525
+
526
+ dataset_names_dict = []
527
+ for i, ds_name in enumerate(dataset_names):
528
+ dataset_names_dict.append(
529
+ {
530
+ "name": ds_name,
531
+ "config": dataset_config_names[i],
532
+ "split": splits[i],
533
+ "text_column_name": text_column_names[i],
534
+ "prompt_column_name": prompt_column_names[i],
535
+ "samples": dataset_samples[i],
536
+ }
537
+ )
538
+ return dataset_names_dict
539
+
540
+
541
+ def load_multiple_datasets(
542
+ dataset_names: Union[List, str],
543
+ dataset_config_names: Union[List, str],
544
+ splits: Optional[Union[List, str]] = None,
545
+ text_column_names: Optional[List] = None,
546
+ prompt_column_names: Optional[List] = None,
547
+ stopping_strategy: Optional[str] = "first_exhausted",
548
+ dataset_samples: Optional[Union[List, np.array]] = None,
549
+ streaming: Optional[bool] = False,
550
+ seed: Optional[int] = None,
551
+ accelerator: Optional[Accelerator] = None,
552
+ **kwargs,
553
+ ) -> Union[Dataset, IterableDataset]:
554
+ dataset_names_dict = convert_dataset_str_to_list(
555
+ dataset_names, dataset_config_names, splits, text_column_names, prompt_column_names, dataset_samples
556
+ )
557
+
558
+ if dataset_samples is not None:
559
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
560
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
561
+ else:
562
+ probabilities = None
563
+
564
+ all_datasets = []
565
+ # iterate over the datasets we want to interleave
566
+ for dataset_dict in tqdm(
567
+ dataset_names_dict,
568
+ desc="Combining datasets...",
569
+ disable=not accelerator.is_main_process,
570
+ ):
571
+ dataset = load_dataset(
572
+ dataset_dict["name"],
573
+ dataset_dict["config"],
574
+ split=dataset_dict["split"],
575
+ streaming=streaming,
576
+ **kwargs,
577
+ )
578
+
579
+ columns_to_keep = {"text"}
580
+ dataset_features = dataset.features.keys()
581
+
582
+ if dataset_dict["text_column_name"] not in dataset_features:
583
+ raise ValueError(
584
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
585
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
586
+ f" correct text column - one of {', '.join(dataset_features)}."
587
+ )
588
+
589
+ # blanket renaming of the configured text column to "text"
590
+ if dataset_dict["text_column_name"] != "text":
591
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
592
+
593
+ # blanket renaming of all prompt columns to prompt
594
+ if dataset_dict["prompt_column_name"] is not None:
595
+ if dataset_dict["prompt_column_name"] not in dataset_features:
596
+ raise ValueError(
597
+ f"Prompt column name {dataset_dict['prompt_column_name']} not found in dataset"
598
+ f" '{dataset_dict['name']}'. Make sure to set `--prompt_column_name` to the"
599
+ f" correct prompt column - one of {', '.join(dataset_features)}."
600
+ )
601
+ elif dataset_dict["prompt_column_name"] != "prompt":
602
+ dataset = dataset.rename_column(dataset_dict["prompt_column_name"], "prompt")
603
+ columns_to_keep.add("prompt")
604
+
605
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
606
+ all_datasets.append(dataset)
607
+
608
+ if len(all_datasets) == 1:
609
+ # we have a single dataset so just return it as is
610
+ return all_datasets[0]
611
+
612
+ if streaming:
613
+ interleaved_dataset = interleave_datasets(
614
+ all_datasets,
615
+ stopping_strategy=stopping_strategy,
616
+ probabilities=probabilities,
617
+ seed=seed,
618
+ )
619
+ else:
620
+ interleaved_dataset = concatenate_datasets(all_datasets)
621
+
622
+ # shuffle mixed dataset prior to potentially truncating it
623
+ interleaved_dataset = interleaved_dataset.shuffle(seed)
624
+ return interleaved_dataset
625
+
626
+
627
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
628
+ """Helper function to sort saved checkpoints from oldest to newest."""
629
+ ordering_and_checkpoint_path = []
630
+
631
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
632
+
633
+ for path in glob_checkpoints:
634
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
635
+ if regex_match is not None and regex_match.groups() is not None:
636
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
637
+
638
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
639
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
640
+ return checkpoints_sorted
641
+
642
+
643
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> Union[List, None]:
644
+ """Helper function to delete old checkpoints."""
645
+ if save_total_limit is None or save_total_limit <= 0:
646
+ return
647
+ # Check if we should delete older checkpoint(s)
648
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
649
+ if len(checkpoints_sorted) <= save_total_limit:
650
+ return
651
+
652
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
653
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
654
+ for checkpoint in checkpoints_to_be_deleted:
655
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
656
+ shutil.rmtree(checkpoint, ignore_errors=True)
657
+ checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
658
+ return checkpoints_to_be_deleted
659
+
660
+
661
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
662
+
663
+
664
+ def get_last_checkpoint(folder):
665
+ content = os.listdir(folder)
666
+ checkpoints = [
667
+ path
668
+ for path in content
669
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
670
+ ]
671
+ if len(checkpoints) == 0:
672
+ return
673
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
674
+
675
+
676
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
677
+ """
678
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
679
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
680
+ (e.g. if the module is frozen).
681
+ """
682
+ result = []
683
+ for name, child in model.named_children():
684
+ result += [
685
+ f"{name}.{n}"
686
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
687
+ if not (
688
+ isinstance(child, tuple(forbidden_layer_types))
689
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
690
+ )
691
+ ]
692
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
693
+ result += list(model._parameters.keys())
694
+ return result
695
+
696
+
697
+ def get_quantization_config(
698
+ model_args: ModelArguments, torch_dtype: torch.dtype
699
+ ) -> tuple[BitsAndBytesConfig | None, BitsAndBytesConfig | None]:
700
+ if model_args.load_teacher_in_4bit:
701
+ quantization_config_teacher = BitsAndBytesConfig(
702
+ load_in_4bit=True,
703
+ bnb_4bit_compute_dtype=torch_dtype,
704
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
705
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
706
+ )
707
+ elif model_args.load_teacher_in_8bit:
708
+ quantization_config_teacher = BitsAndBytesConfig(load_in_8bit=True)
709
+ else:
710
+ quantization_config_teacher = None
711
+
712
+ if model_args.load_student_in_4bit:
713
+ quantization_config_student = BitsAndBytesConfig(
714
+ load_in_4bit=True,
715
+ bnb_4bit_compute_dtype=torch_dtype,
716
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
717
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
718
+ )
719
+ elif model_args.load_student_in_8bit:
720
+ quantization_config_student = BitsAndBytesConfig(load_in_8bit=True)
721
+ else:
722
+ quantization_config_student = None
723
+
724
+ return quantization_config_teacher, quantization_config_student
725
+
726
+
727
+ def main():
728
+ # 1. Parse input arguments
729
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
730
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
731
+
732
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
733
+ # If we pass only one argument to the script and it's the path to a json file,
734
+ # let's parse it to get our arguments.
735
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
736
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
737
+ # If we pass only one argument to the script and it's the path to a yaml file,
738
+ # let's parse it to get our arguments.
739
+ model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
740
+ else:
741
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
742
+
743
+ # 2. Initialize the accelerator
744
+ # We will let the accelerator handle device placement for us in this example
745
+ # We simply have to specify the training precision and any trackers being used
746
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
747
+ # it to accelerate format
748
+ if training_args.dtype == "float16":
749
+ mixed_precision = "fp16"
750
+ teacher_dtype = torch.float16
751
+ elif training_args.dtype == "bfloat16":
752
+ mixed_precision = "bf16"
753
+ teacher_dtype = torch.bfloat16
754
+ else:
755
+ mixed_precision = "no"
756
+ teacher_dtype = torch.float32
757
+
758
+ accelerator = Accelerator(
759
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
760
+ mixed_precision=mixed_precision,
761
+ log_with=training_args.report_to,
762
+ project_dir=training_args.output_dir,
763
+ )
764
+
765
+ accelerator.init_trackers(
766
+ project_name=data_args.wandb_project,
767
+ config={
768
+ "learning_rate": training_args.learning_rate,
769
+ "model_name_or_path": model_args.model_name_or_path,
770
+ "teacher_name_or_path": model_args.teacher_model_name_or_path,
771
+ "num_train_epochs": training_args.num_train_epochs,
772
+ "max_steps": training_args.max_steps,
773
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
774
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
775
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
776
+ "mixed_precision": mixed_precision,
777
+ "lr_scheduler_type": training_args.lr_scheduler_type,
778
+ "warmup_steps": training_args.warmup_steps,
779
+ "weight_decay": training_args.weight_decay,
780
+ "adam_beta1": training_args.adam_beta1,
781
+ "adam_beta2": training_args.adam_beta2,
782
+ "temperature": training_args.temperature,
783
+ },
784
+ )
785
+
786
+ # 3. Set-up basic logging
787
+ # Create one log on every process with the configuration for debugging
788
+ logging.basicConfig(
789
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
790
+ datefmt="%m/%d/%Y %H:%M:%S",
791
+ level=logging.INFO,
792
+ )
793
+ # Log a small summary on each process
794
+ logger.warning(
795
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
796
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
797
+ )
798
+
799
+ # Set the verbosity to info of the Transformers logger (on main process only)
800
+ if accelerator.is_local_main_process:
801
+ datasets.utils.logging.set_verbosity_warning()
802
+ transformers.utils.logging.set_verbosity_info()
803
+ else:
804
+ datasets.utils.logging.set_verbosity_error()
805
+ transformers.utils.logging.set_verbosity_error()
806
+ logger.info("Training/evaluation parameters %s", training_args)
807
+
808
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
809
+ last_checkpoint = None
810
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
811
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
812
+ if last_checkpoint is None and len(sorted_checkpoints(training_args.output_dir)) > 0:
813
+ raise ValueError(
814
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
815
+ "Use --overwrite_output_dir to overcome."
816
+ )
817
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
818
+ logger.info(
819
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
820
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
821
+ )
822
+
823
+ # 5. Handle the repository creation
824
+ if accelerator.is_main_process:
825
+ if training_args.output_dir is not None:
826
+ os.makedirs(training_args.output_dir, exist_ok=True)
827
+ if training_args.push_to_hub:
828
+ if training_args.hub_model_id is None:
829
+ repo_name = get_full_repo_name(
830
+ Path(training_args.output_dir).absolute().name,
831
+ token=training_args.hub_token,
832
+ )
833
+ else:
834
+ repo_name = training_args.hub_model_id
835
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
836
+
837
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
838
+ if "wandb" not in gitignore:
839
+ gitignore.write("wandb\n")
840
+ accelerator.wait_for_everyone()
841
+
842
+ # 6. Load dataset - either streaming or non-streaming (offline)
843
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
844
+
845
+ # set seed for determinism
846
+ set_seed(training_args.seed)
847
+
848
+ if training_args.do_train:
849
+ raw_datasets["train"] = load_multiple_datasets(
850
+ data_args.train_dataset_name,
851
+ data_args.train_dataset_config_name,
852
+ splits=data_args.train_split_name,
853
+ text_column_names=data_args.text_column_name,
854
+ prompt_column_names=data_args.prompt_column_name,
855
+ streaming=data_args.streaming,
856
+ dataset_samples=data_args.train_dataset_samples,
857
+ seed=training_args.seed,
858
+ accelerator=accelerator,
859
+ cache_dir=data_args.dataset_cache_dir,
860
+ token=model_args.token,
861
+ num_proc=data_args.preprocessing_num_workers,
862
+ )
863
+ raw_datasets_train_features = set(raw_datasets["train"].features.keys())
864
+
865
+ if training_args.do_eval:
866
+ dataset_names_dict = convert_dataset_str_to_list(
867
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
868
+ (
869
+ data_args.eval_dataset_config_name
870
+ if data_args.eval_dataset_config_name
871
+ else data_args.train_dataset_config_name
872
+ ),
873
+ splits=data_args.eval_split_name,
874
+ text_column_names=data_args.eval_text_column_name,
875
+ prompt_column_names=data_args.eval_prompt_column_name,
876
+ )
877
+ all_eval_splits = []
878
+ if len(dataset_names_dict) == 1:
879
+ # load a single eval set
880
+ dataset_dict = dataset_names_dict[0]
881
+ all_eval_splits.append("eval")
882
+ raw_datasets["eval"] = load_dataset(
883
+ dataset_dict["name"],
884
+ dataset_dict["config"],
885
+ split=dataset_dict["split"],
886
+ cache_dir=data_args.dataset_cache_dir,
887
+ token=model_args.token,
888
+ streaming=data_args.streaming,
889
+ )
890
+ if dataset_dict["text_column_name"] != "text":
891
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
892
+ if dataset_dict["prompt_column_name"] and dataset_dict["prompt_column_name"] != "prompt":
893
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_prompt_column_name, "prompt")
894
+ else:
895
+ # load multiple eval sets
896
+ for dataset_dict in dataset_names_dict:
897
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['config'].replace('.', '-')}"
898
+ all_eval_splits.append(pretty_name)
899
+ raw_datasets[pretty_name] = load_dataset(
900
+ dataset_dict["name"],
901
+ dataset_dict["config"],
902
+ split=dataset_dict["split"],
903
+ cache_dir=data_args.dataset_cache_dir,
904
+ token=model_args.token,
905
+ streaming=data_args.streaming,
906
+ )
907
+ # make column names consistent (text, prompt)
908
+ columns_to_keep = {"text"}
909
+ if dataset_dict["text_column_name"] != "text":
910
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
911
+ dataset_dict["text_column_name"], "text"
912
+ )
913
+ if dataset_dict["prompt_column_name"] and dataset_dict["prompt_column_name"] != "prompt":
914
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
915
+ dataset_dict["prompt_column_name"], "prompt"
916
+ )
917
+ columns_to_keep.add("prompt")
918
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
919
+ set(raw_datasets[pretty_name].features.keys()) - columns_to_keep
920
+ )
921
+
922
+ if not training_args.do_train and not training_args.do_eval:
923
+ raise ValueError(
924
+ "Cannot skip both training and evaluation. At least one of training or evaluation has to be performed."
925
+ )
926
+
927
+ # 7. Load pretrained model, tokenizer, and feature extractor
928
+ config = AutoConfig.from_pretrained(
929
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
930
+ cache_dir=model_args.cache_dir,
931
+ revision=model_args.model_revision,
932
+ token=model_args.token,
933
+ )
934
+ if training_args.output_router_logits:
935
+ config.output_router_logits = True
936
+
937
+ tokenizer = AutoTokenizer.from_pretrained(
938
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
939
+ cache_dir=model_args.cache_dir,
940
+ use_fast=model_args.use_fast_tokenizer,
941
+ revision=model_args.model_revision,
942
+ token=model_args.token,
943
+ )
944
+ if tokenizer.pad_token_id is None:
945
+ tokenizer.pad_token = tokenizer.eos_token
946
+
947
+ quantization_config_teacher, quantization_config_student = get_quantization_config(
948
+ model_args, torch_dtype=teacher_dtype
949
+ )
950
+
951
+ # The teacher model can safely be cast to the dtype of training since we don't
952
+ # update the params
953
+ teacher_model = AutoModelForCausalLM.from_pretrained(
954
+ model_args.teacher_model_name_or_path,
955
+ cache_dir=model_args.cache_dir,
956
+ token=model_args.token,
957
+ low_cpu_mem_usage=True,
958
+ torch_dtype=teacher_dtype,
959
+ attn_implementation=model_args.attn_implementation,
960
+ quantization_config=quantization_config_teacher,
961
+ )
962
+
963
+ student_model = AutoModelForCausalLM.from_pretrained(
964
+ model_args.model_name_or_path,
965
+ config=config,
966
+ cache_dir=model_args.cache_dir,
967
+ revision=model_args.model_revision,
968
+ subfolder=model_args.subfolder,
969
+ token=model_args.token,
970
+ torch_dtype=teacher_dtype,
971
+ low_cpu_mem_usage=True,
972
+ attn_implementation=model_args.attn_implementation,
973
+ quantization_config=quantization_config_student,
974
+ )
975
+
976
+ if quantization_config_student is not None:
977
+ lora_config = LoraConfig(
978
+ r=model_args.lora_r,
979
+ lora_alpha=model_args.lora_alpha,
980
+ target_modules=model_args.lora_target_modules,
981
+ lora_dropout=model_args.lora_dropout,
982
+ bias="none",
983
+ task_type="CAUSAL_LM",
984
+ )
985
+ student_model = get_peft_model(student_model, lora_config)
986
+
987
+ if student_model.generation_config.bos_token_id is None or teacher_model.generation_config.bos_token_id is None:
988
+ raise ValueError(
989
+ f"Make sure that `generation_config.bos_token_id` is correctly defined for both the "
990
+ f"student and teacher model. Got {student_model.generation_config.bos_token_id} for the "
991
+ f"student and {teacher_model.generation_config.bos_token_id} for the teacher."
992
+ )
993
+
994
+ # enable gradient checkpointing if necessary
995
+ if training_args.gradient_checkpointing:
996
+ student_model.gradient_checkpointing_enable()
997
+
998
+ def set_trainable_parameters(module, requires_grad=False):
999
+ for param in module.parameters():
1000
+ param.requires_grad = requires_grad
1001
+ module._requires_grad = requires_grad
1002
+
1003
+ # freeze student lm head if necessary
1004
+ if training_args.freeze_lm_head:
1005
+ set_trainable_parameters(student_model.lm_head, requires_grad=False)
1006
+
1007
+ student_model.generation_config.max_length = data_args.max_label_length
1008
+
1009
+ # 8. Save all pre-processed tokenizers/config/generation configs
1010
+ if accelerator.is_main_process:
1011
+ tokenizer.save_pretrained(training_args.output_dir)
1012
+ # save the config and generation config as well
1013
+ config.save_pretrained(training_args.output_dir)
1014
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1015
+
1016
+ accelerator.wait_for_everyone()
1017
+
1018
+
1019
+ # 10. Preprocessing the datasets: we need to combine the prompt and generations and tokenize the targets.
1020
+ # 10.1: Define the pre-processing constants
1021
+ max_label_length = (
1022
+ data_args.max_label_length if data_args.max_label_length is not None else config.max_length
1023
+ )
1024
+ num_workers = data_args.preprocessing_num_workers
1025
+ dataloader_num_workers = training_args.dataloader_num_workers
1026
+ prefetch_factor = training_args.dataloader_prefetch_factor
1027
+ eos_token_id = tokenizer.eos_token_id
1028
+ if model_args.instruction_model is not None:
1029
+ instruction_model = model_args.instruction_model
1030
+ else:
1031
+ instruction_model = "instruct" in model_args.model_name_or_path.lower()
1032
+ if instruction_model and "prompt" not in raw_datasets_train_features:
1033
+ raise ValueError(
1034
+ "Distilling an instruction model, but training dataset does not contain prompt-response pairs. Ensure "
1035
+ "the dataset includes both prompts and responses, which should be specified with the `--prompt_column_name` "
1036
+ f"and `--text_column_name` arguments respectively. Got the following columns: {' '.join(list(raw_datasets_train_features))}."
1037
+ )
1038
+
1039
+ # 10.2: filter based on maximum number of training/evaluation samples
1040
+ if training_args.do_train and data_args.max_train_samples is not None:
1041
+ raw_datasets["train"] = (
1042
+ raw_datasets["train"].take(data_args.max_train_samples)
1043
+ if data_args.streaming
1044
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1045
+ )
1046
+
1047
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1048
+ for eval_split in all_eval_splits:
1049
+ raw_datasets[eval_split] = (
1050
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1051
+ if data_args.streaming
1052
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1053
+ )
1054
+
1055
+ # 10.3: pre-process training/evaluation datasets
1056
+ def prepare_dataset(example):
1057
+ prompt = example.get("prompt")
1058
+ target_text = prompt + example["text"] if prompt is not None else example["text"]
1059
+ example["labels"] = tokenizer(target_text).input_ids
1060
+ if example["labels"][-1] != eos_token_id:
1061
+ example["labels"] += [eos_token_id]
1062
+ example["prompt_length"] = len(tokenizer(prompt).input_ids) if prompt else 0
1063
+ return example
1064
+
1065
+ def prepare_instruction_dataset(example):
1066
+ messages = [
1067
+ {"role": "user", "content": example["prompt"]},
1068
+ {"role": "assistant", "content": example["text"]},
1069
+ ]
1070
+ example["labels"] = tokenizer.apply_chat_template(messages)
1071
+ if example["labels"][-1] != eos_token_id:
1072
+ example["labels"] = example["labels"][:-1]
1073
+
1074
+ example["prompt_length"] = len(tokenizer.apply_chat_template([messages[0]]))
1075
+ return example
1076
+
1077
+ prepare_dataset = prepare_instruction_dataset if instruction_model else prepare_dataset
1078
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1079
+ if training_args.do_train:
1080
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1081
+ # we can use `num_workers` (which is much faster)
1082
+ # We gate the pre-processing function accordingly
1083
+ map_fn_train = partial(
1084
+ raw_datasets["train"].map,
1085
+ function=prepare_dataset,
1086
+ remove_columns=raw_datasets_train_features,
1087
+ )
1088
+ with accelerator.main_process_first():
1089
+ vectorized_datasets["train"] = (
1090
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1091
+ if not data_args.streaming
1092
+ else map_fn_train()
1093
+ )
1094
+ if training_args.do_eval:
1095
+ for eval_split in all_eval_splits:
1096
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1097
+ map_fn_eval = partial(
1098
+ raw_datasets[eval_split].map, function=prepare_dataset, remove_columns=raw_datasets_eval_features
1099
+ )
1100
+ with accelerator.main_process_first():
1101
+ vectorized_datasets[eval_split] = (
1102
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1103
+ if not data_args.streaming
1104
+ else map_fn_eval()
1105
+ )
1106
+
1107
+ # 10.4: Filter training data with labels longer than `max_label_length`
1108
+ def is_labels_in_length_range(labels):
1109
+ return 0 < len(labels) <= max_label_length
1110
+
1111
+ filter_by_labels_fn = partial(
1112
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1113
+ )
1114
+ with accelerator.main_process_first():
1115
+ vectorized_datasets = (
1116
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1117
+ if not data_args.streaming
1118
+ else filter_by_labels_fn()
1119
+ )
1120
+
1121
+ # Pre-processing complete!
1122
+ # For large datasets it is advised to run the preprocessing on a
1123
+ # single machine first with `--preprocessing_only` since there will most likely
1124
+ # be a timeout when running the script in distributed mode.
1125
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1126
+ # cached dataset
1127
+ if data_args.preprocessing_only:
1128
+ if data_args.streaming:
1129
+ raise ValueError(
1130
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1131
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1132
+ "on the fly with streaming mode."
1133
+ )
1134
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1135
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1136
+ return
1137
+
1138
+ # 11. Define Evaluation Metrics
1139
+ def compute_metrics(preds, labels):
1140
+ # TODO(SG): better metrics for performance?
1141
+ # replace the padded label ids (-100) with the pad token id so they can be decoded
1142
+ for idx in range(len(labels)):
1143
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1144
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
1145
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1146
+ return pred_str, label_str
1147
+
1148
+ # 12. Define Training Schedule
1149
+ # 12.1: Store some constants
1150
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1151
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1152
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1153
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1154
+
1155
+ # 12.2: Set max training steps
1156
+ if not data_args.streaming and training_args.max_steps < 0:
1157
+ num_epochs = int(training_args.num_train_epochs)
1158
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1159
+ total_train_steps = steps_per_epoch * num_epochs
1160
+ elif training_args.max_steps > 0:
1161
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1162
+ total_train_steps = int(training_args.max_steps)
1163
+ if not data_args.streaming:
1164
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1165
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1166
+ else:
1167
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1168
+ num_epochs = sys.maxsize
1169
+ steps_per_epoch = total_train_steps
1170
+ else:
1171
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1172
+
1173
+ # 12.3: Set evaluation steps
1174
+ if training_args.evaluation_strategy == "epoch":
1175
+ eval_steps = steps_per_epoch
1176
+ elif training_args.eval_steps is None:
1177
+ logger.info(
1178
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1179
+ )
1180
+ eval_steps = steps_per_epoch
1181
+ else:
1182
+ eval_steps = training_args.eval_steps
1183
+
1184
+ # 12.4: Set save steps
1185
+ if training_args.save_strategy == "epoch":
1186
+ save_steps = steps_per_epoch
1187
+ elif training_args.save_strategy == "steps":
1188
+ save_steps = training_args.save_steps
1189
+ else:
1190
+ save_steps = sys.maxsize
1191
+
1192
+ # 13. Define optimizer, LR scheduler, collator
1193
+ decay_parameters = get_parameter_names(
1194
+ student_model,
1195
+ [nn.LayerNorm],
1196
+ )
1197
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1198
+ optimizer_grouped_parameters = [
1199
+ {
1200
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1201
+ "weight_decay": training_args.weight_decay,
1202
+ },
1203
+ {
1204
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1205
+ "weight_decay": 0.0,
1206
+ },
1207
+ ]
1208
+ optimizer = torch.optim.AdamW(
1209
+ params=optimizer_grouped_parameters,
1210
+ lr=training_args.learning_rate,
1211
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1212
+ eps=training_args.adam_epsilon,
1213
+ )
1214
+
1215
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1216
+ lr_scheduler = get_scheduler(
1217
+ name=training_args.lr_scheduler_type,
1218
+ optimizer=optimizer,
1219
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1220
+ num_training_steps=total_train_steps * accelerator.num_processes,
1221
+ )
1222
+
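# Worked example of the scaling in the comment above (hypothetical numbers, not taken from
# the config): with `warmup_steps=500` and 8 processes, the scheduler is built with
# 500 * 8 = 4000 internal warmup steps. Since the accelerate-prepared scheduler is advanced
# `num_processes` times per optimizer step, warmup still completes after 500 optimizer
# steps, and the decay likewise spans `total_train_steps` optimizer steps.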
1223
+ data_collator = DataCollatorCausalLMWithPadding(
1224
+ tokenizer=tokenizer,
1225
+ target_padding="max_length",
1226
+ max_target_length=max_label_length,
1227
+ completions_only=training_args.completions_only,
1228
+ )
1229
+
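# Minimal sketch of the `completions_only` behaviour assumed here (an illustration, not the
# script's actual `DataCollatorCausalLMWithPadding`): the first `prompt_length` label
# positions are set to -100 so that only response tokens contribute to the CE / KL losses
# and to the padding mask used in `kl_divergence` below. Token ids are hypothetical.
import torch

def mask_prompt(labels, prompt_length):
    labels = torch.tensor(labels)
    labels[:prompt_length] = -100  # prompt tokens are ignored by the loss
    return labels

print(mask_prompt([1, 10, 11, 12, 2], prompt_length=2))  # -> labels become [-100, -100, 11, 12, 2]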
1230
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1231
+ # so that we can still access the configs
1232
+ num_beams = (
1233
+ training_args.generation_num_beams
1234
+ if training_args.generation_num_beams is not None
1235
+ else getattr(student_model.generation_config, "num_beams", 1)
1236
+ )
1237
+
1238
+ # 15. Prepare everything with accelerate
1239
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1240
+ student_model, teacher_model, optimizer, lr_scheduler
1241
+ )
1242
+
1243
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1244
+ kl_loss = nn.KLDivLoss(reduction="none")
1245
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1246
+ # mask out padded tokens from the divergence, i.e. positions where the labels are set to -100
1247
+ padding_mask = labels >= 0
1248
+ padding_mask = padding_mask.unsqueeze(-1)
1249
+ divergence = divergence * padding_mask
1250
+ # take the average over the mini-batch
1251
+ divergence = divergence.sum() / padding_mask.sum()
1252
+ return divergence
1253
+
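# Toy check of the masked mean above (shapes and values are hypothetical): the divergence is
# summed over the vocabulary at every position, then averaged over the non-padded positions.
import torch
import torch.nn as nn

toy_teacher = nn.functional.softmax(torch.randn(1, 4, 8), dim=-1)      # (batch, seq, vocab)
toy_student = nn.functional.log_softmax(torch.randn(1, 4, 8), dim=-1)
toy_labels = torch.tensor([[5, 2, -100, -100]])                        # last two positions are padding

toy_div = nn.KLDivLoss(reduction="none")(toy_student, toy_teacher)
toy_mask = (toy_labels >= 0).unsqueeze(-1)
toy_kl = (toy_div * toy_mask).sum() / toy_mask.sum()                   # average over the 2 non-padded positions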
1254
+ # Define gradient update step fn
1255
+ def train_step(
1256
+ batch,
1257
+ temperature=2.0,
1258
+ ):
1259
+ student_model.train()
1260
+ teacher_model.eval()
1261
+
1262
+ student_outputs = student_model(**batch)
1263
+ with torch.no_grad():
1264
+ teacher_outputs = teacher_model(**batch)
1265
+
1266
+ # CE (data) loss
1267
+ ce_loss = student_outputs.loss
1268
+ # rescale distribution by temperature to ensure gradients scale correctly
1269
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1270
+ # log softmax of student predictions for numerical stability
1271
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1272
+ # KL-divergence loss (scaled by temperature)
1273
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1274
+
1275
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1276
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1277
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1278
+ return loss, metrics
1279
+
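# Worked example of the weighting above (hypothetical values): with kl_weight=1.0,
# ce_loss=3.2 and kl_loss=1.5, the distillation objective is
# loss = 0.8 * 3.2 + 1.0 * 1.5 = 2.56 + 1.5 = 4.06.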
1280
+ # Define eval fn
1281
+ def eval_step(batch):
1282
+ student_model.eval()
1283
+ teacher_model.eval()
1284
+
1285
+ with torch.no_grad():
1286
+ student_outputs = student_model(**batch)
1287
+ teacher_outputs = teacher_model(**batch)
1288
+
1289
+ # CE (data) loss
1290
+ ce_loss = student_outputs.loss
1291
+
1292
+ # log softmax / softmax for numerical stability
1293
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1294
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1295
+ # temperature is always 1 for eval
1296
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1297
+
1298
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1299
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1300
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1301
+ return metrics
1302
+
1303
+ def generate_step(batch):
1304
+ student_model.eval()
1305
+ output_ids = accelerator.unwrap_model(student_model).generate(
1306
+ **batch, max_length=max_label_length, num_beams=num_beams
1307
+ )
1308
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1309
+ return output_ids
1310
+
1311
+ logger.info("***** Running training *****")
1312
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1313
+ if not data_args.streaming:
1314
+ logger.info(f" Num epochs = {num_epochs}")
1315
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1316
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1317
+ logger.info(
1318
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1319
+ )
1320
+ logger.info(f" Total optimization steps = {total_train_steps}")
1321
+
1322
+ # ======================== Training ================================
1323
+ train_time = 0
1324
+ train_start = time.time()
1325
+ steps_trained_progress_bar = tqdm(
1326
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1327
+ )
1328
+ continue_training = True
1329
+ epochs_trained = 0
1330
+ cur_step = 0
1331
+
1332
+ checkpoint = None
1333
+ if training_args.resume_from_checkpoint is not None:
1334
+ checkpoint = training_args.resume_from_checkpoint
1335
+ elif last_checkpoint is not None:
1336
+ checkpoint = last_checkpoint
1337
+
1338
+ if checkpoint is not None:
1339
+ accelerator.load_state(checkpoint)
1340
+ # Find num steps and epoch from saved state string pattern
1341
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1342
+ match = re.search(pattern, checkpoint)
1343
+ cur_step = int(match.group(1))
1344
+ epochs_trained = int(match.group(2))
1345
+
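# e.g. resuming from a directory named "checkpoint-5000-epoch-0" (the naming pattern used
# when saving below) gives cur_step = 5000 and epochs_trained = 0.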
1346
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1347
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1348
+ logger.info(f" Continuing training from global step {cur_step}")
1349
+
1350
+ steps_trained_progress_bar.update(cur_step)
1351
+
1352
+ for epoch in range(0, epochs_trained):
1353
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1354
+
1355
+ if not data_args.streaming and training_args.max_steps < 0:
1356
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1357
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1358
+ else:
1359
+ # Currently we don't know how many steps we've taken in the current epoch
1360
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1361
+ # This is "good enough" for our purposes but not fully correct
1362
+ resume_step = None
1363
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1364
+ else:
1365
+ resume_step = None
1366
+
1367
+ for epoch in range(epochs_trained, num_epochs):
1368
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1369
+ train_dataloader = DataLoader(
1370
+ vectorized_datasets["train"],
1371
+ collate_fn=data_collator,
1372
+ batch_size=per_device_train_batch_size,
1373
+ num_workers=dataloader_num_workers,
1374
+ prefetch_factor=prefetch_factor,
1375
+ pin_memory=training_args.dataloader_pin_memory,
1376
+ )
1377
+ train_dataloader = accelerator.prepare(train_dataloader)
1378
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1379
+ train_dataloader.dataset.set_epoch(epoch)
1380
+
1381
+ if resume_step is not None:
1382
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1383
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1384
+ resume_step = None
1385
+
1386
+ for batch in train_dataloader:
1387
+ with accelerator.accumulate(student_model):
1388
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1389
+ accelerator.backward(loss)
1390
+ if accelerator.sync_gradients:
1391
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1392
+ optimizer.step()
1393
+ lr_scheduler.step()
1394
+ optimizer.zero_grad()
1395
+
1396
+ # Check if the accelerator has performed an optimization step behind the scenes
1397
+ if accelerator.sync_gradients:
1398
+ steps_trained_progress_bar.update(1)
1399
+ cur_step += 1
1400
+
1401
+ if cur_step % training_args.logging_steps == 0:
1402
+ steps_trained_progress_bar.write(
1403
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1404
+ f" {train_metric['loss']}, Learning Rate:"
1405
+ f" {lr_scheduler.get_last_lr()[0]})"
1406
+ )
1407
+ log_metric(
1408
+ accelerator,
1409
+ metrics=train_metric,
1410
+ learning_rate=lr_scheduler.get_last_lr()[0],
1411
+ train_time=train_time + time.time() - train_start,
1412
+ step=cur_step,
1413
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1414
+ prefix="train",
1415
+ )
1416
+
1417
+ # save checkpoint and weights after each save_steps and at the end of training
1418
+ if (cur_step % save_steps == 0) or cur_step == total_train_steps:
1419
+ accelerator.wait_for_everyone()
1420
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1421
+ accelerator.save_state(output_dir=intermediate_dir)
1422
+ unwrapped_model = accelerator.unwrap_model(student_model)
1423
+ unwrapped_model.save_pretrained(
1424
+ intermediate_dir,
1425
+ is_main_process=accelerator.is_main_process,
1426
+ save_function=accelerator.save,
1427
+ )
1428
+ if accelerator.is_main_process:
1429
+ checkpoint_to_be_deleted = rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1430
+ if training_args.push_to_hub:
1431
+ upload_folder(
1432
+ folder_path=training_args.output_dir,
1433
+ repo_id=repo_name,
1434
+ repo_type="model",
1435
+ commit_message=f"Saving train state of step {cur_step}",
1436
+ delete_patterns=checkpoint_to_be_deleted,
1437
+ )
1438
+
1439
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1440
+ train_time += time.time() - train_start
1441
+ student_model.eval()
1442
+ # ======================== Evaluating ==============================
1443
+ for eval_split in all_eval_splits:
1444
+ eval_metrics = []
1445
+ eval_preds = []
1446
+ eval_labels = []
1447
+ eval_start = time.time()
1448
+
1449
+ validation_dataloader = DataLoader(
1450
+ vectorized_datasets[eval_split],
1451
+ collate_fn=data_collator,
1452
+ batch_size=per_device_eval_batch_size,
1453
+ drop_last=False,
1454
+ num_workers=dataloader_num_workers,
1455
+ prefetch_factor=prefetch_factor,
1456
+ pin_memory=training_args.dataloader_pin_memory,
1457
+ )
1458
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1459
+
1460
+ for batch in tqdm(
1461
+ validation_dataloader,
1462
+ desc=f"Evaluating {eval_split}...",
1463
+ position=2,
1464
+ disable=not accelerator.is_local_main_process,
1465
+ ):
1466
+ # Model forward
1467
+ eval_metric = eval_step(batch)
1468
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1469
+ eval_metrics.append(eval_metric)
1470
+
1471
+ # generation
1472
+ if training_args.predict_with_generate:
1473
+ generated_ids = generate_step(batch)
1474
+ # Gather all predictions and targets
1475
+ generated_ids, labels = accelerator.gather_for_metrics(
1476
+ (generated_ids, batch["labels"])
1477
+ )
1478
+ eval_preds.extend(generated_ids)
1479
+ eval_labels.extend(labels)
1480
+
1481
+ eval_time = time.time() - eval_start
1482
+ stack = torch.stack if accelerator.num_processes == 1 else torch.concatenate
1483
+ # normalize eval metrics
1484
+ eval_metrics = {
1485
+ key: torch.mean(stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1486
+ }
1487
+ try:
1488
+ eval_metrics["perplexity"] = math.exp(eval_metrics["ce_loss"])
1489
+ except OverflowError:
1490
+ eval_metrics["perplexity"] = float("inf")
1491
+
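# e.g. a mean eval ce_loss of 3.2 corresponds to a perplexity of exp(3.2) ~= 24.5 (hypothetical value).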
1492
+ if training_args.predict_with_generate:
1493
+ pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1494
+ log_pred(
1495
+ accelerator,
1496
+ pred_str,
1497
+ label_str,
1498
+ step=cur_step,
1499
+ epoch=epoch,
1500
+ evaluation_strategy=training_args.evaluation_strategy,
1501
+ prefix=eval_split,
1502
+ )
1503
+
1504
+ # Print metrics and update progress bar
1505
+ logger_desc = " ".join([f"Eval {key}: {value} |" for key, value in eval_metrics.items()])
1506
+ steps_trained_progress_bar.write(
1507
+ f"Eval results for step ({cur_step} / {total_train_steps} | {logger_desc}"
1508
+ )
1509
+
1510
+ log_metric(
1511
+ accelerator,
1512
+ metrics=eval_metrics,
1513
+ train_time=eval_time,
1514
+ step=cur_step,
1515
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1516
+ prefix=eval_split,
1517
+ )
1518
+
1519
+ # flush the train metrics
1520
+ train_start = time.time()
1521
+
1522
+ # break condition
1523
+ if cur_step == total_train_steps:
1524
+ accelerator.wait_for_everyone()
1525
+ # un-wrap student model for save
1526
+ student_model = accelerator.unwrap_model(student_model)
1527
+ student_model.save_pretrained(
1528
+ training_args.output_dir,
1529
+ is_main_process=accelerator.is_main_process,
1530
+ save_function=accelerator.save,
1531
+ )
1532
+ if training_args.push_to_hub and accelerator.is_main_process:
1533
+ upload_folder(
1534
+ folder_path=training_args.output_dir,
1535
+ repo_id=repo_name,
1536
+ repo_type="model",
1537
+ commit_message=f"Saving final weights of step {cur_step}",
1538
+ )
1539
+ continue_training = False
1540
+ break
1541
+
1542
+ if not continue_training:
1543
+ break
1544
+
1545
+ accelerator.end_training()
1546
+
1547
+
1548
+ if __name__ == "__main__":
1549
+ main()
slurm_job.slurm ADDED
@@ -0,0 +1,76 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=distil-mistral
3
+ #SBATCH --nodes=1
4
+ # set 48h for job wall time limit
5
+ #SBATCH --time=48:00:00
6
+ #SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for distributed training!
7
+ #SBATCH --cpus-per-task=32
8
+ #SBATCH --gres=gpu:8
9
+ #SBATCH --exclusive
10
+ #SBATCH --partition=hopper-prod
11
+ #SBATCH --output=/fsx/sanchit/logs/%x-%j.out
12
+
13
+ set -x -e
14
+
15
+ # START EDIT
16
+ source ~/.bashrc
17
+ source /fsx/sanchit/miniconda3/bin/activate venv
18
+
19
+ LOG_PATH="/fsx/sanchit/logs/main_log.txt"
20
+ SAVE_DIR="/fsx/sanchit"
21
+ # END EDIT
22
+
23
+ echo "START TIME: $(date)"
24
+
25
+ GPUS_PER_NODE=8
26
+ NNODES=$SLURM_NNODES
27
+
28
+ # so processes know who to talk to
29
+ MASTER_ADDR=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1`
30
+
31
+ # From https://i.hsfzxjy.site/2021-03-10-obtain-a-random-unused-tcp-port-with-bash/
32
+ function unused_port() {
33
+ N=${1:-1}
34
+ comm -23 \
35
+ <(seq "1025" "65535" | sort) \
36
+ <(ss -Htan |
37
+ awk '{print $4}' |
38
+ cut -d':' -f2 |
39
+ sort -u) |
40
+ shuf |
41
+ head -n "$N"
42
+ }
43
+ MASTER_PORT=$(unused_port)
44
+
45
+ # export TORCH_CPP_LOG_LEVEL=INFO
46
+ # export TORCH_DISTRIBUTED_DEBUG=DETAIL
47
+
48
+ export LAUNCHER="python -u -m accelerate.commands.launch --config_file ./accelerate_config.yaml"
49
+
50
+ export PROGRAM="./run_distillation.py ./config_mistral_fineweb.yaml"
51
+ export CMD="$LAUNCHER $PROGRAM"
52
+ echo $CMD
53
+
54
+ SRUN_ARGS=" \
55
+ --wait=60 \
56
+ --kill-on-bad-exit=1 \
57
+ "
58
+
59
+ # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
60
+ clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
61
+
62
+
63
+ # srun error handling:
64
+ # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
65
+ # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
66
+
67
+ # SRUN_ARGS=" \
68
+ # --wait=60 \
69
+ # --kill-on-bad-exit=1 \
70
+ # "
71
+ #
72
+ # # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
73
+ # clear; srun $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
74
+
75
+ echo "END TIME: $(date)"
76
+
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "legacy": true,
35
+ "model_max_length": 1000000000000000019884624838656,
36
+ "pad_token": "</s>",
37
+ "sp_model_kwargs": {},
38
+ "spaces_between_special_tokens": false,
39
+ "tokenizer_class": "LlamaTokenizer",
40
+ "unk_token": "<unk>",
41
+ "use_default_system_prompt": false
42
+ }