sanchit-gandhi (HF staff) committed
Commit
253da18
1 Parent(s): 6dc7eab

Saving train state of step 5000

Files changed (40)
  1. .gitignore +1 -0
  2. accelerate_config.yaml +17 -0
  3. checkpoint-5000-epoch-0/config.json +27 -0
  4. checkpoint-5000-epoch-0/generation_config.json +7 -0
  5. checkpoint-5000-epoch-0/model-00001-of-00002.safetensors +3 -0
  6. checkpoint-5000-epoch-0/model-00002-of-00002.safetensors +3 -0
  7. checkpoint-5000-epoch-0/model.safetensors +3 -0
  8. checkpoint-5000-epoch-0/model.safetensors.index.json +64 -0
  9. checkpoint-5000-epoch-0/model_1.safetensors +3 -0
  10. checkpoint-5000-epoch-0/optimizer.bin +3 -0
  11. checkpoint-5000-epoch-0/random_states_0.pkl +3 -0
  12. checkpoint-5000-epoch-0/random_states_1.pkl +3 -0
  13. checkpoint-5000-epoch-0/random_states_2.pkl +3 -0
  14. checkpoint-5000-epoch-0/random_states_3.pkl +3 -0
  15. checkpoint-5000-epoch-0/random_states_4.pkl +3 -0
  16. checkpoint-5000-epoch-0/random_states_5.pkl +3 -0
  17. checkpoint-5000-epoch-0/random_states_6.pkl +3 -0
  18. checkpoint-5000-epoch-0/random_states_7.pkl +3 -0
  19. checkpoint-5000-epoch-0/scheduler.bin +3 -0
  20. config.json +27 -0
  21. config_mistral.yaml +51 -0
  22. distil-mistral/1715161591.5675907/events.out.tfevents.1715161591.ip-26-0-168-34.2221624.1 +3 -0
  23. distil-mistral/1715161591.571527/hparams.yml +18 -0
  24. distil-mistral/1715175038.3339753/events.out.tfevents.1715175038.ip-26-0-163-127.494306.1 +3 -0
  25. distil-mistral/1715175038.3382986/hparams.yml +18 -0
  26. distil-mistral/1715175965.4032657/events.out.tfevents.1715175965.ip-26-0-167-9.1040300.1 +3 -0
  27. distil-mistral/1715175965.40856/hparams.yml +18 -0
  28. distil-mistral/1715186628.352039/events.out.tfevents.1715186628.ip-26-0-168-30.3708820.1 +3 -0
  29. distil-mistral/1715186628.3564456/hparams.yml +18 -0
  30. distil-mistral/events.out.tfevents.1715161582.ip-26-0-168-34.2221624.0 +3 -0
  31. distil-mistral/events.out.tfevents.1715175028.ip-26-0-163-127.494306.0 +3 -0
  32. distil-mistral/events.out.tfevents.1715175944.ip-26-0-167-9.1040300.0 +3 -0
  33. distil-mistral/events.out.tfevents.1715186617.ip-26-0-168-30.3708820.0 +3 -0
  34. generation_config.json +7 -0
  35. run_distillation.py +1623 -0
  36. slurm_job.slurm +75 -0
  37. special_tokens_map.json +24 -0
  38. tokenizer.json +0 -0
  39. tokenizer.model +3 -0
  40. tokenizer_config.json +43 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ wandb
accelerate_config.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
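This Accelerate config pins training to a single node with 8 GPU processes in bf16 mixed precision; presumably it is consumed by an `accelerate launch --config_file accelerate_config.yaml run_distillation.py config_mistral.yaml` invocation wired up in `slurm_job.slurm` (the exact launch command is an assumption, not shown here). A minimal sketch of the equivalent programmatic setup, mirroring what `run_distillation.py` does in `main()`:

```python
from accelerate import Accelerator

# bf16 mixed precision and gradient accumulation mirror accelerate_config.yaml /
# config_mistral.yaml; device placement across the 8 processes is left to Accelerate
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision="bf16",
    log_with="all",    # report_to: all in config_mistral.yaml
    project_dir="./",  # output_dir in config_mistral.yaml
)
print(accelerator.num_processes, accelerator.mixed_precision)
```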
checkpoint-5000-epoch-0/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.0.dev0",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
checkpoint-5000-epoch-0/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 4096,
+ "transformers_version": "4.40.0.dev0"
+ }
checkpoint-5000-epoch-0/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fce5fe22c7ceb73ee7bae7a9e1fe31f9a1ef7d2d00264003ade9e04e31bddd21
+ size 4987196936
checkpoint-5000-epoch-0/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07aa70a89248b4efdc4b4dc4f6e9cfa62f7febbae5a0ed1886166e356483d37e
+ size 1296089984
checkpoint-5000-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6adc663990edb408b721efebed4f83689a9810a17b87fea53fa97e13617eea90
+ size 6283286904
checkpoint-5000-epoch-0/model.safetensors.index.json ADDED
@@ -0,0 +1,64 @@
+ {
+ "metadata": {
+ "total_size": 6283280384
+ },
+ "weight_map": {
+ "lm_head.weight": "model-00002-of-00002.safetensors",
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.5.input_layernorm.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
+ "model.norm.weight": "model-00002-of-00002.safetensors"
+ }
+ }
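The index maps every student parameter to one of the two shards (6.28 GB total); `transformers` resolves the shards automatically when pointed at the checkpoint directory. A minimal sketch, assuming the checkpoint folder has been downloaded locally (note the tokenizer files live at the repo root, not inside the checkpoint folder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# from_pretrained reads model.safetensors.index.json and loads both shards
model = AutoModelForCausalLM.from_pretrained("checkpoint-5000-epoch-0", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(".")  # tokenizer.json / tokenizer.model at the repo root

print(sum(p.numel() for p in model.parameters()) / 1e9, "B parameters")
```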
checkpoint-5000-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65173ac419081e25f0d5f93ed77393cf05f5158325ee154a5cbb3e14b47ece07
+ size 4450837792
checkpoint-5000-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3561e9158d61994bdb15ca2180ccc6ea1d5c4649ea70ac711066d807f4627950
+ size 3147845370
checkpoint-5000-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aca503cc09e63ca033e29a437a20cc580a9c1db27fef2174e533f58ba275879
+ size 16100
checkpoint-5000-epoch-0/random_states_1.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31831c2134536b1e81ba1e763e72b2ff98a14a83774fcfb30d153a66dca7879c
+ size 16100
checkpoint-5000-epoch-0/random_states_2.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a628258539b4090ce50e9faf5fda4d613f523ca957f3e837c02d316e4b20122
+ size 16100
checkpoint-5000-epoch-0/random_states_3.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d594aa54f68e8eb41c3deb9753bf43474028f44edb92db1930ebdf967f708a7c
+ size 16100
checkpoint-5000-epoch-0/random_states_4.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28ca4240374ff4b93ad0537aca2f28bfc293153a29ee8069cf09d088ca30fee7
+ size 16100
checkpoint-5000-epoch-0/random_states_5.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d6f3577977e8c32eac49b1c5136c6718fcd9c66051b703ba6e305cca03a8fb0
+ size 16100
checkpoint-5000-epoch-0/random_states_6.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0ef1d86e60e6cedda41454cd08e0b3652ab6a6eb017b4eed0d6b84866ed7d46
+ size 16100
checkpoint-5000-epoch-0/random_states_7.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d860c07ef8d57c8162394106fcd87c34e7924d859b28b4b292e9e792a96af2
+ size 16100
checkpoint-5000-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c25f7255aa53945ccffbdb6904da689924024cb2e693a6c6739ade9fae0454a2
+ size 1064
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.0.dev0",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
config_mistral.yaml ADDED
@@ -0,0 +1,51 @@
+ # Model arguments
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ teacher_model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ dtype: bfloat16
+ load_teacher_in_4bit: true
+ optim: adamw_bnb_8bit
+
+ # Data arguments
+ train_dataset_name: sanchit-gandhi/cosmopedia-logprobs
+ train_dataset_config_name:
+ - auto_math_text
+ - khanacademy
+ - openstax
+ - stanford
+ - stories
+ - web_samples_v1
+ - web_samples_v2
+ - wikihow
+ train_split_name: train[1000:]
+ eval_split_name: train[:1000]
+ prompt_column_name: prompt
+ eval_prompt_column_name: prompt
+ max_steps: 200000
+ logprob_threshold: -1.5
+
+ # Training arguments
+ do_train: true
+ do_eval: true
+ per_device_eval_batch_size: 8
+ per_device_train_batch_size: 8
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: true
+ max_label_length: 4096
+ learning_rate: 0.0001
+ warmup_steps: 500
+ dataloader_num_workers: 4
+ preprocessing_num_workers: 32
+ ddp_timeout: 7200
+ save_strategy: steps
+ save_steps: 5000
+ evaluation_strategy: steps
+ eval_steps: 5000
+ logging_steps: 25
+ output_router_logits: true
+ report_to: all
+ output_dir: ./
+ overwrite_output_dir: false
+ save_total_limit: 1
+ wandb_project: distil-mistral
+ push_to_hub: true
+
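When `run_distillation.py` (added further down in this commit) is invoked with a single `.yaml` argument, it hands this file to `HfArgumentParser.parse_yaml_file`, which maps each key onto the `ModelArguments`, `DataTrainingArguments` and `DistillationTrainingArguments` dataclasses. A minimal sketch of that parsing step in isolation:

```python
from transformers import HfArgumentParser

# the three dataclasses are defined in run_distillation.py (same commit)
from run_distillation import (
    DataTrainingArguments,
    DistillationTrainingArguments,
    ModelArguments,
)

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
model_args, data_args, training_args = parser.parse_yaml_file("config_mistral.yaml")

print(model_args.teacher_model_name_or_path)  # mistralai/Mistral-7B-Instruct-v0.2
print(training_args.save_steps)               # 5000
```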
distil-mistral/1715161591.5675907/events.out.tfevents.1715161591.ip-26-0-168-34.2221624.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a29f6e1080b7adf540f82bccd4f0724e99020740f191cd4328f104db18efd383
+ size 1268
distil-mistral/1715161591.571527/hparams.yml ADDED
@@ -0,0 +1,18 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ logprob_threshold: -1.5
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_label_length: 4096
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
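Note that `lr_scheduler_type` was serialised as a Python enum via the `!!python/object/apply` tag, so PyYAML's safe loader will reject this file. A minimal sketch for reading it back (assumes `transformers` is installed so the enum can be reconstructed):

```python
import yaml

with open("distil-mistral/1715161591.571527/hparams.yml") as f:
    # unsafe_load is required because of the !!python/object/apply tag;
    # only use it on files you trust (these are your own training logs)
    hparams = yaml.unsafe_load(f)

print(hparams["lr_scheduler_type"])  # SchedulerType.LINEAR
print(hparams["temperature"])        # 2.0
```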
distil-mistral/1715175038.3339753/events.out.tfevents.1715175038.ip-26-0-163-127.494306.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff596db682c97ae007ce69f2d81a483b4aaef7990a3352da665a212af781542f
+ size 1268
distil-mistral/1715175038.3382986/hparams.yml ADDED
@@ -0,0 +1,18 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ logprob_threshold: -1.5
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_label_length: 4096
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
distil-mistral/1715175965.4032657/events.out.tfevents.1715175965.ip-26-0-167-9.1040300.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:114843fc2a3e0dda7fb196209d9a5e72530b8801b300455841d8450fa8cfe608
+ size 1268
distil-mistral/1715175965.40856/hparams.yml ADDED
@@ -0,0 +1,18 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ logprob_threshold: -1.5
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_label_length: 4096
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
distil-mistral/1715186628.352039/events.out.tfevents.1715186628.ip-26-0-168-30.3708820.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bd087787008bdb404a0402f1d86f94da145942f0e941f4b49b89e9e193aa8d8
+ size 1268
distil-mistral/1715186628.3564456/hparams.yml ADDED
@@ -0,0 +1,18 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ logprob_threshold: -1.5
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_label_length: 4096
+ max_steps: 200000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
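The logged hyper-parameters record a softmax `temperature` of 2.0 for the distillation objective. The loss computation itself falls outside the excerpt of `run_distillation.py` shown below, but a typical word-level KD term consistent with these settings looks like the following sketch (an illustrative assumption, not a verbatim extract from the script):

```python
import torch
import torch.nn.functional as F

def kl_loss(student_logits, teacher_logits, labels, temperature=2.0):
    """Token-level KL between temperature-annealed teacher and student distributions."""
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    kl = F.kl_div(student_log_probs, teacher_log_probs, log_target=True, reduction="none").sum(-1)
    mask = labels != -100  # ignore padding (and prompts when completions_only=True)
    # scale by T^2 so gradients keep a comparable magnitude across temperatures
    return (kl * mask).sum() / mask.sum() * temperature**2
```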
distil-mistral/events.out.tfevents.1715161582.ip-26-0-168-34.2221624.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a663aedff86fb9e4775d4658e2116d8ab90e310026f5a48aa6233bbd7b93220
+ size 88
distil-mistral/events.out.tfevents.1715175028.ip-26-0-163-127.494306.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5677f34be2c805c3714a6ba9ee771fb011bcaba6d0e9bd6a8d8e089ab05ceb42
+ size 1000
distil-mistral/events.out.tfevents.1715175944.ip-26-0-167-9.1040300.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8720e414b2153a47ac3ab8f9ca28f99b332ed9f2e0beac5b1ef3f5e68b7d8509
+ size 1304
distil-mistral/events.out.tfevents.1715186617.ip-26-0-168-30.3708820.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e63ebd063802579cffe4f5e27630d7b75eff24dfd33d15544531de93c194c8d2
+ size 62058
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 4096,
+ "transformers_version": "4.40.0.dev0"
+ }
run_distillation.py ADDED
@@ -0,0 +1,1623 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training language models for conditional language modelling tasks via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import shutil
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from functools import partial
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ Dataset,
42
+ DatasetDict,
43
+ IterableDataset,
44
+ IterableDatasetDict,
45
+ concatenate_datasets,
46
+ interleave_datasets,
47
+ load_dataset,
48
+ )
49
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
50
+ from peft import LoraConfig, get_peft_model
51
+ from torch.utils.data import DataLoader
52
+ from tqdm import tqdm
53
+ from transformers import (
54
+ AutoConfig,
55
+ AutoModelForCausalLM,
56
+ AutoTokenizer,
57
+ BatchEncoding,
58
+ BitsAndBytesConfig,
59
+ HfArgumentParser,
60
+ PreTrainedTokenizerBase,
61
+ Seq2SeqTrainingArguments,
62
+ get_scheduler,
63
+ set_seed, is_bitsandbytes_available,
64
+ )
65
+ from transformers.training_args import OptimizerNames
66
+ from transformers.utils import check_min_version
67
+ from transformers.utils.versions import require_version
68
+
69
+
70
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
71
+ check_min_version("4.34.0.dev0")
72
+
73
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
74
+
75
+ logger = get_logger(__name__)
76
+
77
+
78
+ @dataclass
79
+ class ModelArguments:
80
+ """
81
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
82
+ """
83
+
84
+ model_name_or_path: str = field(
85
+ metadata={"help": "Path to pretrained Whisper model or model identifier from huggingface.co/models"}
86
+ )
87
+ teacher_model_name_or_path: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
90
+ )
91
+ config_name: Optional[str] = field(
92
+ default=None,
93
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
94
+ )
95
+ tokenizer_name: Optional[str] = field(
96
+ default=None,
97
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
98
+ )
99
+ cache_dir: Optional[str] = field(
100
+ default=None,
101
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
102
+ )
103
+ use_fast_tokenizer: bool = field(
104
+ default=True,
105
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
+ )
107
+ model_revision: str = field(
108
+ default="main",
109
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
110
+ )
111
+ subfolder: str = field(
112
+ default="",
113
+ metadata={
114
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
115
+ "specify the folder name here."
116
+ },
117
+ )
118
+ token: str = field(
119
+ default=None,
120
+ metadata={
121
+ "help": (
122
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
123
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
124
+ )
125
+ },
126
+ )
127
+ attn_implementation: Optional[str] = field(
128
+ default=None,
129
+ metadata={
130
+ "help": (
131
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
132
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
133
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
134
+ "3. `flash_attn_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
135
+ )
136
+ },
137
+ )
138
+ load_teacher_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the teacher model."})
139
+ load_teacher_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the teacher model."})
140
+ load_student_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the student model."})
141
+ load_student_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the student model."})
142
+ bnb_4bit_quant_type: Optional[str] = field(
143
+ default="nf4", metadata={"help": "Quantization type if the teacher is quantized (fp4 or nf4)"}
144
+ )
145
+ use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether or not to use nested quantization."})
146
+ lora_r: Optional[int] = field(
147
+ default=16,
148
+ metadata={"help": "LoRA R value."},
149
+ )
150
+ lora_alpha: Optional[int] = field(
151
+ default=32,
152
+ metadata={"help": "LoRA alpha."},
153
+ )
154
+ lora_dropout: Optional[float] = field(
155
+ default=0.05,
156
+ metadata={"help": "LoRA dropout."},
157
+ )
158
+ lora_target_modules: Optional[List[str]] = field(
159
+ default=None,
160
+ metadata={"help": "LoRA target modules."},
161
+ )
162
+ lora_modules_to_save: Optional[List[str]] = field(
163
+ default=None,
164
+ metadata={"help": "Model layers to unfreeze & train"},
165
+ )
166
+ instruction_model: Optional[bool] = field(
167
+ default=None,
168
+ metadata={"help": "Whether or not the pre-trained model is instruction tuned"},
169
+ )
170
+
171
+
172
+ @dataclass
173
+ class DataTrainingArguments:
174
+ """
175
+ Arguments pertaining to what data we are going to input our model for training and eval.
176
+ """
177
+
178
+ train_dataset_name: List[str] = field(
179
+ default=None,
180
+ metadata={
181
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
182
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
183
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
184
+ },
185
+ )
186
+ train_dataset_config_name: Optional[List[str]] = field(
187
+ default=None,
188
+ metadata={
189
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
190
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
191
+ "match the order of the datasets."
192
+ },
193
+ )
194
+ train_dataset_samples: Optional[List[str]] = field(
195
+ default=None,
196
+ metadata={
197
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
198
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
199
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
200
+ "sample from every dataset is used once per epoch."
201
+ },
202
+ )
203
+ eval_dataset_name: Optional[List[str]] = field(
204
+ default=None,
205
+ metadata={
206
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
207
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
208
+ "ids by a '+' symbol."
209
+ },
210
+ )
211
+ eval_dataset_config_name: Optional[List[str]] = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
215
+ "training dataset config name if unspecified."
216
+ },
217
+ )
218
+ dataset_cache_dir: Optional[str] = field(
219
+ default=None,
220
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
221
+ )
222
+ overwrite_cache: bool = field(
223
+ default=False,
224
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
225
+ )
226
+ preprocessing_num_workers: Optional[int] = field(
227
+ default=None,
228
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
229
+ )
230
+ max_train_samples: Optional[int] = field(
231
+ default=None,
232
+ metadata={
233
+ "help": (
234
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
235
+ )
236
+ },
237
+ )
238
+ max_eval_samples: Optional[int] = field(
239
+ default=None,
240
+ metadata={
241
+ "help": (
242
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
243
+ )
244
+ },
245
+ )
246
+ text_column_name: Optional[List[str]] = field(
247
+ default=None,
248
+ metadata={"help": "The name of the dataset column containing the generated text data in the training set."},
249
+ )
250
+ prompt_column_name: Optional[List[str]] = field(
251
+ default=None,
252
+ metadata={"help": "The name of the dataset column containing the prompt data. Defaults to 'prompt'"},
253
+ )
254
+ eval_text_column_name: Optional[List[str]] = field(
255
+ default=None,
256
+ metadata={"help": "The name of the dataset column containing the generated text data in the evaluation set."},
257
+ )
258
+ eval_prompt_column_name: Optional[List[str]] = field(
259
+ default=None,
260
+ metadata={"help": "The name of the dataset column containing the prompt data in the evaluation set."},
261
+ )
262
+ max_label_length: Optional[int] = field(
263
+ default=4096,
264
+ metadata={"help": "Truncate target labels that are longer `max_label_length` tokens."},
265
+ )
266
+ logprob_threshold: Optional[float] = field(
267
+ default=None,
268
+ metadata={"help": "Filter training examples with avg log-probability less than `logprob_threshold`."},
269
+ )
270
+ pad_target_to_multiple_of: Optional[int] = field(
271
+ default=None,
272
+ metadata={
273
+ "help": (
274
+ "If set will pad the target sequence to a multiple of the provided value. This is important to "
275
+ "avoid triggering recompilations when using torch compile. If unspecified, will default to padding "
276
+ "the targets to max length."
277
+ )
278
+ },
279
+ )
280
+ preprocessing_only: bool = field(
281
+ default=False,
282
+ metadata={
283
+ "help": (
284
+ "Whether to only do data preprocessing and skip training. This is especially useful when data "
285
+ "preprocessing errors out in distributed training due to timeout. In this case, one should run the "
286
+ "preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets "
287
+ "can consequently be loaded in distributed training"
288
+ )
289
+ },
290
+ )
291
+ train_split_name: Optional[List[str]] = field(
292
+ default=lambda: ["train"],
293
+ metadata={
294
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
295
+ },
296
+ )
297
+ eval_split_name: Optional[List[str]] = field(
298
+ default=lambda: ["validation"],
299
+ metadata={
300
+ "help": (
301
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
302
+ )
303
+ },
304
+ )
305
+ streaming: bool = field(
306
+ default=False,
307
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
308
+ )
309
+ wandb_project: str = field(
310
+ default="distil-mistral",
311
+ metadata={"help": "The name of the wandb project."},
312
+ )
313
+
314
+
315
+ @dataclass
316
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
317
+ freeze_embeddings: Optional[bool] = field(
318
+ default=False, metadata={"help": "Whether to freeze the input and output embeddings of the student model."}
319
+ )
320
+ freeze_n_layers: Optional[int] = field(
321
+ default=None, metadata={"help": "Freeze the first n layers of the student model."}
322
+ )
323
+ temperature: Optional[float] = field(
324
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
325
+ )
326
+ kl_weight: Optional[float] = field(
327
+ default=1.0,
328
+ metadata={
329
+ "help": (
330
+ "Weighting assigned to the MSE loss in the KD formulation. MSE loss is "
331
+ "computed between the teacher-student hidden states and attentions."
332
+ )
333
+ },
334
+ )
335
+ output_router_logits: Optional[bool] = field(
336
+ default=False,
337
+ metadata={
338
+ "help": "Whether or not to return the router logits in the forward pass. Enabling this will "
339
+ "also configure the model to compute the auxiliary loss."
340
+ },
341
+ )
342
+ dtype: Optional[str] = field(
343
+ default="float32",
344
+ metadata={
345
+ "help": (
346
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
347
+ "`float16` or `bfloat16` (both half-precision)."
348
+ )
349
+ },
350
+ )
351
+ completions_only: Optional[bool] = field(
352
+ default=False,
353
+ metadata={
354
+ "help": "Whether to train only on the target completions, or the prompt + completions."
355
+ },
356
+ )
357
+
358
+
359
+ @dataclass
360
+ class DataCollatorCausalLMWithPadding:
361
+ """
362
+ Data collator that will dynamically pad the inputs received.
363
+ Args:
364
+ tokenizer ([`PreTrainedTokenizer`])
365
+ The tokenizer used for tokenizing the data.
366
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
367
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
368
+ See above for details.
369
+ max_target_length (:obj:`int`, `optional`):
370
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
371
+ completions_only (:obj:`bool`, `optional`):
372
+ Whether to train on the assistant responses (completions) only, or the combination of prompt + responses.
373
+ """
374
+
375
+ tokenizer: PreTrainedTokenizerBase
376
+ target_padding: Union[bool, str] = "max_length"
377
+ max_target_length: Optional[int] = None
378
+ completions_only: Optional[bool] = False
379
+
380
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> BatchEncoding:
381
+ # dataloader returns a list of features which we convert to a dict
382
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
383
+ label_lengths = [len(feature["labels"]) for feature in features]
384
+ prompt_lengths = [feature["prompt_length"] for feature in features]
385
+
386
+ batch = self.tokenizer.pad(
387
+ label_features,
388
+ max_length=self.max_target_length,
389
+ padding=self.target_padding,
390
+ return_tensors="pt",
391
+ )
392
+
393
+ labels_mask = batch["attention_mask"]
394
+
395
+ if self.completions_only:
396
+ # don't include prompts in loss calculation
397
+ for idx in range(len(prompt_lengths)):
398
+ padding_length = labels_mask.shape[1] - label_lengths[idx]
399
+ labels_mask[idx, padding_length : padding_length + prompt_lengths[idx]] = 0
400
+
401
+ # replace padding with -100 to ignore loss correctly
402
+ labels = batch["input_ids"].masked_fill(labels_mask.ne(1), -100)
403
+
404
+ batch["labels"] = labels
405
+
406
+ return batch
407
+
408
+
409
+ def log_metric(
410
+ accelerator,
411
+ metrics: Dict,
412
+ train_time: float,
413
+ step: int,
414
+ epoch: int,
415
+ learning_rate: float = None,
416
+ prefix: str = "train",
417
+ ):
418
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
419
+ log_metrics = {}
420
+ for k, v in metrics.items():
421
+ log_metrics[f"{prefix}/{k}"] = v
422
+ log_metrics[f"{prefix}/time"] = train_time
423
+ log_metrics[f"{prefix}/epoch"] = epoch
424
+ if learning_rate is not None:
425
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
426
+ accelerator.log(log_metrics, step=step)
427
+
428
+
429
+ def log_pred(
430
+ accelerator,
431
+ pred_str: List[str],
432
+ label_str: List[str],
433
+ step: int,
434
+ epoch: int,
435
+ evaluation_strategy: str,
436
+ prefix: str = "eval",
437
+ num_lines: int = 200000,
438
+ ):
439
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
440
+ if accelerator.is_main_process:
441
+ wandb_tracker = accelerator.get_tracker("wandb")
442
+ # pretty name for current step: step 50000 -> step 50k
443
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
444
+ prefix_pretty = prefix.replace("/", "-")
445
+
446
+ if evaluation_strategy == "epoch":
447
+ table_name = f"predictions/{prefix_pretty}-epoch-{epoch}"
448
+ else:
449
+ table_name = f"predictions/{prefix_pretty}-step-{cur_step_pretty}"
450
+
451
+ # convert str data to a wandb compatible format
452
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
453
+ # log as a table with the appropriate headers
454
+ wandb_tracker.log_table(
455
+ table_name=table_name,
456
+ columns=["Target", "Pred"],
457
+ data=str_data[:num_lines],
458
+ step=step,
459
+ )
460
+
461
+
462
+ def convert_dataset_str_to_list(
463
+ dataset_names,
464
+ dataset_config_names,
465
+ splits=None,
466
+ text_column_names=None,
467
+ prompt_column_names=None,
468
+ dataset_samples=None,
469
+ default_split="train",
470
+ ) -> List[Dict]:
471
+ """
472
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
473
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
474
+ function returns a list of dictionaries, one for each dataset.
475
+ """
476
+ if isinstance(dataset_names, str):
477
+ dataset_names = [dataset_names]
478
+ splits = [splits] if splits else None
479
+ text_column_names = [text_column_names] if text_column_names else None
480
+ prompt_column_names = [prompt_column_names] if prompt_column_names else None
481
+ if isinstance(dataset_config_names, str):
482
+ dataset_config_names = [dataset_config_names]
483
+
484
+ if len(dataset_names) == 1 and len(dataset_config_names) > 1:
485
+ dataset_names = len(dataset_config_names) * dataset_names
486
+
487
+ if isinstance(splits, list) and len(splits) == 1 and len(dataset_config_names) > 1:
488
+ splits = len(dataset_config_names) * splits
489
+
490
+ if isinstance(text_column_names, list) and len(text_column_names) == 1 and len(dataset_config_names) > 1:
491
+ text_column_names = len(dataset_config_names) * text_column_names
492
+
493
+ if isinstance(prompt_column_names, list) and len(prompt_column_names) == 1 and len(dataset_config_names) > 1:
494
+ prompt_column_names = len(dataset_config_names) * prompt_column_names
495
+
496
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
497
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
498
+ raise ValueError(
499
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
500
+ f" {len(dataset_config_names)} configs."
501
+ )
502
+
503
+ if splits is not None and len(splits) != len(dataset_names):
504
+ raise ValueError(
505
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
506
+ )
507
+
508
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
509
+ raise ValueError(
510
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
511
+ f" {len(text_column_names)} text column names."
512
+ )
513
+
514
+ if prompt_column_names is not None and len(prompt_column_names) != len(dataset_names):
515
+ raise ValueError(
516
+ f"Ensure one prompt column name is passed for each dataset, got {len(dataset_names)} datasets and"
517
+ f" {len(prompt_column_names)} prompt column names."
518
+ )
519
+
520
+ if dataset_samples is not None:
521
+ if len(dataset_samples) != len(dataset_names):
522
+ raise ValueError(
523
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
524
+ f"{len(dataset_samples)} samples."
525
+ )
526
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
527
+ else:
528
+ dataset_samples = [None] * len(dataset_names)
529
+
530
+ dataset_config_names = (
531
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
532
+ )
533
+ text_column_names = (
534
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
535
+ )
536
+ prompt_column_names = (
537
+ prompt_column_names if prompt_column_names is not None else [None for _ in range(len(dataset_names))]
538
+ )
539
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
540
+
541
+ dataset_names_dict = []
542
+ for i, ds_name in enumerate(dataset_names):
543
+ dataset_names_dict.append(
544
+ {
545
+ "name": ds_name,
546
+ "config": dataset_config_names[i],
547
+ "split": splits[i],
548
+ "text_column_name": text_column_names[i],
549
+ "prompt_column_name": prompt_column_names[i],
550
+ "samples": dataset_samples[i],
551
+ }
552
+ )
553
+ return dataset_names_dict
554
+
555
+
556
+ def load_multiple_datasets(
557
+ dataset_names: Union[List, str],
558
+ dataset_config_names: Union[List, str],
559
+ splits: Optional[Union[List, str]] = None,
560
+ text_column_names: Optional[List] = None,
561
+ prompt_column_names: Optional[List] = None,
562
+ stopping_strategy: Optional[str] = "first_exhausted",
563
+ dataset_samples: Optional[Union[List, np.array]] = None,
564
+ streaming: Optional[bool] = False,
565
+ seed: Optional[int] = None,
566
+ use_logprobs: Optional[bool] = False,
567
+ accelerator: Optional[Accelerator] = None,
568
+ **kwargs,
569
+ ) -> Union[Dataset, IterableDataset]:
570
+ dataset_names_dict = convert_dataset_str_to_list(
571
+ dataset_names, dataset_config_names, splits, text_column_names, prompt_column_names, dataset_samples
572
+ )
573
+
574
+ if dataset_samples is not None:
575
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
576
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
577
+ else:
578
+ probabilities = None
579
+
580
+ all_datasets = []
581
+ # iterate over the datasets we want to interleave
582
+ for dataset_dict in tqdm(
583
+ dataset_names_dict,
584
+ desc="Combining datasets...",
585
+ disable=not accelerator.is_main_process,
586
+ ):
587
+ dataset = load_dataset(
588
+ dataset_dict["name"],
589
+ dataset_dict["config"],
590
+ split=dataset_dict["split"],
591
+ streaming=streaming,
592
+ **kwargs,
593
+ )
594
+
595
+ columns_to_keep = {"text"}
596
+ dataset_features = dataset.features.keys()
597
+
598
+ if dataset_dict["text_column_name"] not in dataset_features:
599
+ raise ValueError(
600
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
601
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
602
+ f" correct text column - one of {', '.join(dataset_features)}."
603
+ )
604
+
605
+ # blanket renaming of all transcription columns to text
606
+ if dataset_dict["text_column_name"] != "text":
607
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
608
+
609
+ # blanket renaming of all prompt columns to prompt
610
+ if dataset_dict["prompt_column_name"] is not None:
611
+ if dataset_dict["prompt_column_name"] not in dataset_features:
612
+ raise ValueError(
613
+ f"Prompt column name {dataset_dict['prompt_column_name']} not found in dataset"
614
+ f" '{dataset_dict['name']}'. Make sure to set `--prompt_column_name` to the"
615
+ f" correct prompt column - one of {', '.join(dataset_features)}."
616
+ )
617
+ elif dataset_dict["prompt_column_name"] != "prompt":
618
+ dataset = dataset.rename_column(dataset_dict["prompt_column_name"], "prompt")
619
+ columns_to_keep.add("prompt")
620
+
621
+ if use_logprobs:
622
+ if "logprobs" not in dataset_features:
623
+ raise ValueError(
624
+ "If setting `logprob_threshold`, ensure that the column 'logprobs' is in the dataset "
625
+ f"'{dataset_dict['name']}'. Got the following columns: {', '.join(dataset_features)}."
626
+ )
627
+ columns_to_keep.add("logprobs")
628
+
629
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
630
+ all_datasets.append(dataset)
631
+
632
+ if len(all_datasets) == 1:
633
+ # we have a single dataset so just return it as is
634
+ return all_datasets[0]
635
+
636
+ if streaming:
637
+ interleaved_dataset = interleave_datasets(
638
+ all_datasets,
639
+ stopping_strategy=stopping_strategy,
640
+ probabilities=probabilities,
641
+ seed=seed,
642
+ )
643
+ else:
644
+ interleaved_dataset = concatenate_datasets(all_datasets)
645
+
646
+ # shuffle mixed dataset prior to potentially truncating it
647
+ interleaved_dataset = interleaved_dataset.shuffle(seed)
648
+ return interleaved_dataset
649
+
650
+
651
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
652
+ """Helper function to sort saved checkpoints from oldest to newest."""
653
+ ordering_and_checkpoint_path = []
654
+
655
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
656
+
657
+ for path in glob_checkpoints:
658
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
659
+ if regex_match is not None and regex_match.groups() is not None:
660
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
661
+
662
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
663
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
664
+ return checkpoints_sorted
665
+
666
+
667
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> Union[List, None]:
668
+ """Helper function to delete old checkpoints."""
669
+ if save_total_limit is None or save_total_limit <= 0:
670
+ return
671
+ # Check if we should delete older checkpoint(s)
672
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
673
+ if len(checkpoints_sorted) <= save_total_limit:
674
+ return
675
+
676
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
677
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
678
+ for checkpoint in checkpoints_to_be_deleted:
679
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
680
+ shutil.rmtree(checkpoint, ignore_errors=True)
681
+ checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
682
+ return checkpoints_to_be_deleted
683
+
684
+
685
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
686
+
687
+
688
+ def get_last_checkpoint(folder):
689
+ content = os.listdir(folder)
690
+ checkpoints = [
691
+ path
692
+ for path in content
693
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
694
+ ]
695
+ if len(checkpoints) == 0:
696
+ return
697
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
698
+
699
+
700
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
701
+ """
702
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
703
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
704
+ (e.g. if the module is frozen).
705
+ """
706
+ result = []
707
+ for name, child in model.named_children():
708
+ result += [
709
+ f"{name}.{n}"
710
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
711
+ if not (
712
+ isinstance(child, tuple(forbidden_layer_types))
713
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
714
+ )
715
+ ]
716
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
717
+ result += list(model._parameters.keys())
718
+ return result
719
+
720
+
721
+ def get_quantization_config(
722
+ model_args: ModelArguments, torch_dtype: torch.dtype
723
+ ) -> tuple[BitsAndBytesConfig | None, BitsAndBytesConfig | None]:
724
+ if model_args.load_teacher_in_4bit:
725
+ quantization_config_teacher = BitsAndBytesConfig(
726
+ load_in_4bit=True,
727
+ bnb_4bit_compute_dtype=torch_dtype,
728
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
729
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
730
+ )
731
+ elif model_args.load_teacher_in_8bit:
732
+ quantization_config_teacher = BitsAndBytesConfig(load_in_8bit=True)
733
+ else:
734
+ quantization_config_teacher = None
735
+
736
+ if model_args.load_student_in_4bit:
737
+ quantization_config_student = BitsAndBytesConfig(
738
+ load_in_4bit=True,
739
+ bnb_4bit_compute_dtype=torch_dtype,
740
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
741
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
742
+ )
743
+ elif model_args.load_student_in_8bit:
744
+ quantization_config_student = BitsAndBytesConfig(load_in_8bit=True)
745
+ else:
746
+ quantization_config_student = None
747
+
748
+ return quantization_config_teacher, quantization_config_student
749
+
750
+
751
+ def main():
752
+ # 1. Parse input arguments
753
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
754
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
755
+
756
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
757
+ # If we pass only one argument to the script and it's the path to a json file,
758
+ # let's parse it to get our arguments.
759
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
760
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
761
+ # If we pass only one argument to the script and it's the path to a yaml file,
762
+ # let's parse it to get our arguments.
763
+ model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
764
+ else:
765
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
766
+
767
+ # 2. Initialize the accelerator
768
+ # We will let the accelerator handle device placement for us in this example
769
+ # We simply have to specify the training precision and any trackers being used
770
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
771
+ # it to accelerate format
772
+ if training_args.dtype == "float16":
773
+ mixed_precision = "fp16"
774
+ teacher_dtype = torch.float16
775
+ elif training_args.dtype == "bfloat16":
776
+ mixed_precision = "bf16"
777
+ teacher_dtype = torch.bfloat16
778
+ else:
779
+ mixed_precision = "no"
780
+ teacher_dtype = torch.float32
781
+
782
+ accelerator = Accelerator(
783
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
784
+ mixed_precision=mixed_precision,
785
+ log_with=training_args.report_to,
786
+ project_dir=training_args.output_dir,
787
+ )
788
+
789
+ accelerator.init_trackers(
790
+ project_name=data_args.wandb_project,
791
+ config={
792
+ "learning_rate": training_args.learning_rate,
793
+ "model_name_or_path": model_args.model_name_or_path,
794
+ "teacher_name_or_path": model_args.teacher_model_name_or_path,
795
+ "num_train_epochs": training_args.num_train_epochs,
796
+ "max_steps": training_args.max_steps,
797
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
798
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
799
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
800
+ "mixed_precision": mixed_precision,
801
+ "lr_scheduler_type": training_args.lr_scheduler_type,
802
+ "warmup_steps": training_args.warmup_steps,
803
+ "weight_decay": training_args.weight_decay,
804
+ "adam_beta1": training_args.adam_beta1,
805
+ "adam_beta2": training_args.adam_beta2,
806
+ "temperature": training_args.temperature,
807
+ "logprob_threshold": data_args.logprob_threshold,
808
+ "max_label_length": data_args.max_label_length,
809
+ },
810
+ )
811
+
812
+ # 3. Set-up basic logging
813
+ # Create one log on every process with the configuration for debugging
814
+ logging.basicConfig(
815
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
816
+ datefmt="%m/%d/%Y %H:%M:%S",
817
+ level=logging.INFO,
818
+ )
819
+ # Log a small summary on each process
820
+ logger.warning(
821
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
822
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
823
+ )
824
+
825
+ # Set the verbosity to info of the Transformers logger (on main process only)
826
+ if accelerator.is_local_main_process:
827
+ datasets.utils.logging.set_verbosity_warning()
828
+ transformers.utils.logging.set_verbosity_info()
829
+ else:
830
+ datasets.utils.logging.set_verbosity_error()
831
+ transformers.utils.logging.set_verbosity_error()
832
+ logger.info("Training/evaluation parameters %s", training_args)
833
+
834
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
835
+ last_checkpoint = None
836
+ if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
837
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
838
+ if last_checkpoint is None and len(sorted_checkpoints(training_args.output_dir)) > 0:
839
+ raise ValueError(
840
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
841
+ "Use --overwrite_output_dir to overcome."
842
+ )
843
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
844
+ logger.info(
845
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
846
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
847
+ )
848
+
849
+ # 5. Handle the repository creation
850
+ if accelerator.is_main_process:
851
+ if training_args.output_dir is not None:
852
+ os.makedirs(training_args.output_dir, exist_ok=True)
853
+ if training_args.push_to_hub:
854
+ if training_args.hub_model_id is None:
855
+ repo_name = get_full_repo_name(
856
+ Path(training_args.output_dir).absolute().name,
857
+ token=training_args.hub_token,
858
+ )
859
+ else:
860
+ repo_name = training_args.hub_model_id
861
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
862
+
863
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "a+") as gitignore:
864
+ gitignore.seek(0)
+ if "wandb" not in gitignore.read():
865
+ gitignore.write("wandb\n")
866
+ accelerator.wait_for_everyone()
867
+
868
+ # 6. Load dataset - either streaming or non-streaming (offline)
869
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
870
+
871
+ # set seed for determinism
872
+ set_seed(training_args.seed)
873
+
874
+ raw_datasets["train"] = load_multiple_datasets(
875
+ data_args.train_dataset_name,
876
+ data_args.train_dataset_config_name,
877
+ splits=data_args.train_split_name,
878
+ text_column_names=data_args.text_column_name,
879
+ prompt_column_names=data_args.prompt_column_name,
880
+ streaming=data_args.streaming,
881
+ dataset_samples=data_args.train_dataset_samples,
882
+ seed=training_args.seed,
883
+ use_logprobs=data_args.logprob_threshold is not None,
884
+ accelerator=accelerator,
885
+ cache_dir=data_args.dataset_cache_dir,
886
+ token=model_args.token,
887
+ num_proc=data_args.preprocessing_num_workers,
888
+ )
889
+ raw_datasets_train_features = set(raw_datasets["train"].features.keys())
890
+
891
+ if training_args.do_eval:
892
+ dataset_names_dict = convert_dataset_str_to_list(
893
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
894
+ (
895
+ data_args.eval_dataset_config_name
896
+ if data_args.eval_dataset_config_name
897
+ else data_args.train_dataset_config_name
898
+ ),
899
+ splits=data_args.eval_split_name,
900
+ text_column_names=data_args.eval_text_column_name,
901
+ prompt_column_names=data_args.eval_prompt_column_name,
902
+ )
903
+ all_eval_splits = []
904
+ if len(dataset_names_dict) == 1:
905
+ # load a single eval set
906
+ dataset_dict = dataset_names_dict[0]
907
+ all_eval_splits.append("eval")
908
+ raw_datasets["eval"] = load_dataset(
909
+ dataset_dict["name"],
910
+ dataset_dict["config"],
911
+ split=dataset_dict["split"],
912
+ cache_dir=data_args.dataset_cache_dir,
913
+ token=model_args.token,
914
+ streaming=data_args.streaming,
915
+ )
916
+ if dataset_dict["text_column_name"] != "text":
917
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
918
+ if dataset_dict["prompt_column_name"] and dataset_dict["prompt_column_name"] != "prompt":
919
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_prompt_column_name, "prompt")
920
+ else:
921
+ # load multiple eval sets
922
+ for dataset_dict in dataset_names_dict:
923
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['config'].replace('.', '-')}"
924
+ all_eval_splits.append(pretty_name)
925
+ raw_datasets[pretty_name] = load_dataset(
926
+ dataset_dict["name"],
927
+ dataset_dict["config"],
928
+ split=dataset_dict["split"],
929
+ cache_dir=data_args.dataset_cache_dir,
930
+ token=model_args.token,
931
+ streaming=data_args.streaming,
932
+ )
933
+ # make column names consistent (text, prompt)
934
+ columns_to_keep = {"text"}
935
+ if dataset_dict["text_column_name"] != "text":
936
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
937
+ dataset_dict["text_column_name"], "text"
938
+ )
939
+ if dataset_dict["prompt_column_name"]:
940
+ if dataset_dict["prompt_column_name"] != "prompt":
941
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
942
+ dataset_dict["prompt_column_name"], "prompt"
943
+ )
944
+ columns_to_keep.add("prompt")
945
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
946
+ set(raw_datasets[pretty_name].features.keys()) - columns_to_keep
947
+ )
948
+
949
+ # 7. Load pretrained model, tokenizer, and feature extractor
950
+ config = AutoConfig.from_pretrained(
951
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
952
+ cache_dir=model_args.cache_dir,
953
+ revision=model_args.model_revision,
954
+ token=model_args.token,
955
+ )
956
+ if training_args.output_router_logits:
957
+ config.output_router_logits = True
958
+
959
+ tokenizer = AutoTokenizer.from_pretrained(
960
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
961
+ cache_dir=model_args.cache_dir,
962
+ use_fast=model_args.use_fast_tokenizer,
963
+ revision=model_args.model_revision,
964
+ token=model_args.token,
965
+ )
966
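+ # if the tokenizer defines no pad token (e.g. Mistral/Llama), re-use the EOS token for padding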
+ if tokenizer.pad_token_id is None:
967
+ tokenizer.pad_token = tokenizer.eos_token
968
+
969
+ quantization_config_teacher, quantization_config_student = get_quantization_config(
970
+ model_args, torch_dtype=teacher_dtype
971
+ )
972
+
973
+ if model_args.teacher_model_name_or_path:
974
+ # Teacher-student distillation
975
+ # The teacher model can safely be cast to the dtype of training since we don't
976
+ # update the params
977
+ teacher_model = AutoModelForCausalLM.from_pretrained(
978
+ model_args.teacher_model_name_or_path,
979
+ cache_dir=model_args.cache_dir,
980
+ token=model_args.token,
981
+ low_cpu_mem_usage=True,
982
+ torch_dtype=teacher_dtype,
983
+ attn_implementation=model_args.attn_implementation,
984
+ quantization_config=quantization_config_teacher,
985
+ ).eval()
986
+ else:
987
+ # Shrink and fine-tune
988
+ teacher_model = None
989
+
990
+ student_model = AutoModelForCausalLM.from_pretrained(
991
+ model_args.model_name_or_path,
992
+ config=config,
993
+ cache_dir=model_args.cache_dir,
994
+ revision=model_args.model_revision,
995
+ subfolder=model_args.subfolder,
996
+ token=model_args.token,
997
+ low_cpu_mem_usage=True,
998
+ attn_implementation=model_args.attn_implementation,
999
+ quantization_config=quantization_config_student,
1000
+ )
1001
+
1002
+ if quantization_config_student is not None:
1003
+ lora_config = LoraConfig(
1004
+ r=model_args.lora_r,
1005
+ lora_alpha=model_args.lora_alpha,
1006
+ target_modules=model_args.lora_target_modules,
1007
+ lora_dropout=model_args.lora_dropout,
1008
+ bias="none",
1009
+ task_type="CAUSAL_LM",
1010
+ )
1011
+ student_model = get_peft_model(student_model, lora_config)
1012
+
1013
+ if student_model.generation_config.bos_token_id is None or (teacher_model and teacher_model.generation_config.bos_token_id is None):
1014
+ student_error = f"Make sure that `generation_config.bos_token_id` is correctly defined. Got {student_model.generation_config.bos_token_id} for the student."
1015
+ teacher_error = f"Got {teacher_model.generation_config.bos_token_id} for the teacher." if teacher_model else None
1016
+ raise ValueError(student_error + teacher_error)
1017
+
1018
+ def set_trainable_parameters(module, requires_grad=False):
1019
+ for param in module.parameters():
1020
+ param.requires_grad = requires_grad
1021
+ module._requires_grad = requires_grad
1022
+
1023
+ forbidden_module = []
1024
+ # freeze student embeddings if necessary
1025
+ if training_args.freeze_embeddings:
1026
+ set_trainable_parameters(student_model.get_output_embeddings(), requires_grad=False)
1027
+ set_trainable_parameters(student_model.get_input_embeddings(), requires_grad=False)
1028
+ forbidden_module.extend([student_model.get_output_embeddings(), student_model.get_input_embeddings()])
1029
+
1030
+ if training_args.freeze_n_layers:
1031
+ for i in range(int(training_args.freeze_n_layers)):
1032
+ set_trainable_parameters(student_model.model.layers[i], requires_grad=False)
1033
+ forbidden_module.extend([student_model.model.layers[i]])
1034
+
1035
+ # enable gradient checkpointing if necessary
1036
+ if training_args.gradient_checkpointing:
1037
+ if training_args.freeze_embeddings or training_args.freeze_n_layers:
1038
+ raise ValueError(
1039
+ "Gradient checkpointing is not compatible with `--freeze_embeddings` or `--freeze_n_layers`. "
1040
+ "Either un-freeze these layers, or set `--gradient_checkpointing=False`."
1041
+ )
1042
+ student_model.gradient_checkpointing_enable()
1043
+
1044
+ student_model.generation_config.max_length = data_args.max_label_length
1045
+
1046
+ # 8. Save all pre-processed tokenizers/config/generation configs
1047
+ if accelerator.is_main_process:
1048
+ tokenizer.save_pretrained(training_args.output_dir)
1049
+ # save the config and generation config as well
1050
+ config.save_pretrained(training_args.output_dir)
1051
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1052
+
1053
+ accelerator.wait_for_everyone()
1054
+
1055
+
1056
+ # 10. Preprocessing the datasets: we need to combine the prompt and generations and tokenize the targets.
1057
+ # 10.1: Define the pre-processing constants
1058
+ max_label_length = (
1059
+ data_args.max_label_length if data_args.max_label_length is not None else config.max_length
1060
+ )
1061
+ num_workers = data_args.preprocessing_num_workers
1062
+ dataloader_num_workers = training_args.dataloader_num_workers
1063
+ prefetch_factor = training_args.dataloader_prefetch_factor
1064
+ eos_token_id = tokenizer.eos_token_id
1065
+ if model_args.instruction_model is not None:
1066
+ instruction_model = model_args.instruction_model
1067
+ else:
1068
+ instruction_model = "instruct" in model_args.model_name_or_path.lower()
1069
+ if instruction_model and "prompt" not in raw_datasets_train_features:
1070
+ raise ValueError(
1071
+ "Distilling an instruction model, but `--prompt_column_name` is set to None. "
1072
+ "Ensure `--prompt_column_name` is set according to the dataset features."
1073
+ )
1074
+
1075
+ # 10.2: filter based on maximum number of training/evaluation samples
1076
+ if data_args.max_train_samples is not None:
1077
+ raw_datasets["train"] = (
1078
+ raw_datasets["train"].take(data_args.max_train_samples)
1079
+ if data_args.streaming
1080
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1081
+ )
1082
+
1083
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1084
+ for eval_split in all_eval_splits:
1085
+ raw_datasets[eval_split] = (
1086
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1087
+ if data_args.streaming
1088
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1089
+ )
1090
+
1091
+ # 10.3: filter based on log-probs criteria
1092
+ if data_args.logprob_threshold is not None:
1093
+ logprob_threshold = data_args.logprob_threshold
1094
+
1095
+ def is_logprob_in_range(logprob):
1096
+ return logprob > logprob_threshold
1097
+
1098
+ filter_by_logprobs_fn = partial(
1099
+ raw_datasets["train"].filter, function=is_logprob_in_range, input_columns=["logprobs"]
1100
+ )
1101
+ with accelerator.main_process_first():
1102
+ raw_datasets["train"] = (
1103
+ filter_by_logprobs_fn(num_proc=num_workers, desc="filtering train dataset by logprobs")
1104
+ if not data_args.streaming
1105
+ else filter_by_logprobs_fn()
1106
+ )
1107
+
1108
+ # 10.4: pre-process training/evaluation datasets
1109
+ def prepare_dataset(example):
1110
+ prompt = example.get("prompt")
1111
+ target_text = prompt + example["text"] if prompt is not None else example["text"]
1112
+ example["labels"] = tokenizer(target_text).input_ids
1113
+ if example["labels"][-1] != eos_token_id:
1114
+ example["labels"] += [eos_token_id]
1115
+ example["prompt_length"] = len(tokenizer(prompt).input_ids) if prompt else 0
1116
+ return example
1117
+
1118
+ def prepare_instruction_dataset(example):
1119
+ messages = [
1120
+ {"role": "user", "content": example["prompt"]},
1121
+ {"role": "assistant", "content": example["text"]},
1122
+ ]
1123
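+ # apply_chat_template tokenizes by default, so this returns the token ids of the formatted conversation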
+ example["labels"] = tokenizer.apply_chat_template(messages)
1124
+ if example["labels"][-1] != eos_token_id:
1125
+ example["labels"] = example["labels"][:-1]
1126
+
1127
+ example["prompt_length"] = len(tokenizer.apply_chat_template([messages[0]]))
1128
+ return example
1129
+
1130
+ prepare_dataset = prepare_instruction_dataset if instruction_model else prepare_dataset
1131
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1132
+
1133
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1134
+ # we can use `num_workers` (which is much faster)
1135
+ # We gate the pre-processing function accordingly
1136
+ map_fn_train = partial(
1137
+ raw_datasets["train"].map,
1138
+ function=prepare_dataset,
1139
+ remove_columns=raw_datasets_train_features,
1140
+ )
1141
+ with accelerator.main_process_first():
1142
+ vectorized_datasets["train"] = (
1143
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1144
+ if not data_args.streaming
1145
+ else map_fn_train()
1146
+ )
1147
+
1148
+ if training_args.do_eval:
1149
+ for eval_split in all_eval_splits:
1150
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1151
+ map_fn_eval = partial(
1152
+ raw_datasets[eval_split].map, function=prepare_dataset, remove_columns=raw_datasets_eval_features
1153
+ )
1154
+ with accelerator.main_process_first():
1155
+ vectorized_datasets[eval_split] = (
1156
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1157
+ if not data_args.streaming
1158
+ else map_fn_eval()
1159
+ )
1160
+
1161
+ # 10.5: Filter training data with labels longer than `max_label_length`
1162
+ def is_labels_in_length_range(labels):
1163
+ return 0 < len(labels) <= max_label_length
1164
+
1165
+ filter_by_labels_fn = partial(
1166
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1167
+ )
1168
+ with accelerator.main_process_first():
1169
+ vectorized_datasets = (
1170
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1171
+ if not data_args.streaming
1172
+ else filter_by_labels_fn()
1173
+ )
1174
+
1175
+ # Pre-processing complete!
1176
+ # For large datasets it is advised to run the preprocessing on a
1177
+ # single machine first with `--preprocessing_only` since there will most likely
1178
+ # be a timeout when running the script in distributed mode.
1179
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1180
+ # cached dataset
1181
+ if data_args.preprocessing_only:
1182
+ if data_args.streaming:
1183
+ raise ValueError(
1184
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1185
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1186
+ "on the fly with streaming mode."
1187
+ )
1188
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1189
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1190
+ return
1191
+
1192
+ # 11. Define Evaluation Metrics
1193
+ def compute_metrics(preds, labels):
1194
+ # TODO(SG): better metrics for performance?
1195
+ # replace padded labels by the padding token
1196
+ for idx in range(len(labels)):
1197
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1198
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
1199
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1200
+ return pred_str, label_str
1201
+
1202
+ # 12. Define Training Schedule
1203
+ # 12.1: Store some constants
1204
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1205
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1206
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1207
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1208
+ temperature = training_args.temperature
1209
+
1210
+ # 12.2: Set max training steps
1211
+ if not data_args.streaming and training_args.max_steps < 0:
1212
+ num_epochs = int(training_args.num_train_epochs)
1213
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1214
+ total_train_steps = steps_per_epoch * num_epochs
1215
+ elif training_args.max_steps > 0:
1216
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1217
+ total_train_steps = int(training_args.max_steps)
1218
+ if not data_args.streaming:
1219
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1220
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1221
+ else:
1222
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1223
+ num_epochs = sys.maxsize
1224
+ steps_per_epoch = total_train_steps
1225
+ else:
1226
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1227
+
1228
+ # 12.3: Set evaluation steps
1229
+ if training_args.evaluation_strategy == "epoch":
1230
+ eval_steps = steps_per_epoch
1231
+ elif training_args.eval_steps is None:
1232
+ logger.info(
1233
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1234
+ )
1235
+ eval_steps = steps_per_epoch
1236
+ else:
1237
+ eval_steps = training_args.eval_steps
1238
+
1239
+ # 12.4: Set save steps
1240
+ if training_args.save_strategy == "epoch":
1241
+ save_steps = steps_per_epoch
1242
+ elif training_args.save_strategy == "steps":
1243
+ save_steps = training_args.save_steps
1244
+ else:
1245
+ save_steps = sys.maxsize
1246
+
1247
+ # 13. Define optimizer, LR scheduler, collator
1248
+ decay_parameters = get_parameter_names(
1249
+ student_model,
1250
+ [nn.LayerNorm],
1251
+ forbidden_module,
1252
+ )
1253
+
1254
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1255
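+ # apply weight decay to all parameters except biases and LayerNorm weights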
+ optimizer_grouped_parameters = [
1256
+ {
1257
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1258
+ "weight_decay": training_args.weight_decay,
1259
+ },
1260
+ {
1261
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1262
+ "weight_decay": 0.0,
1263
+ },
1264
+ ]
1265
+ if training_args.optim == OptimizerNames.ADAMW_TORCH:
1266
+ optim_cls = torch.optim.AdamW
1267
+ elif training_args.optim == OptimizerNames.ADAMW_BNB:
1268
+ if not is_bitsandbytes_available():
1269
+ raise ValueError(
1270
+ "bitsandbytes package required for Adam8bit. Install via: `pip install --upgrade bitsandbytes`"
1271
+ )
1272
+ import bitsandbytes as bnb
1273
+
1274
+ optim_cls = bnb.optim.Adam8bit
1275
+ else:
1276
+ raise ValueError(
1277
+ f"Got invalid `--optim` {training_args.optim}, should be one of `['adam_torch', 'adamw_bnb_8bit']`."
1278
+ )
1279
+
1280
+ optimizer = optim_cls(
1281
+ params=optimizer_grouped_parameters,
1282
+ lr=training_args.learning_rate,
1283
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1284
+ eps=training_args.adam_epsilon,
1285
+ )
1286
+
1287
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1288
+ lr_scheduler = get_scheduler(
1289
+ name=training_args.lr_scheduler_type,
1290
+ optimizer=optimizer,
1291
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1292
+ num_training_steps=total_train_steps * accelerator.num_processes,
1293
+ )
1294
+
1295
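+ # pad the labels in each batch up to max_label_length so every batch has a fixed shape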
+ data_collator = DataCollatorCausalLMWithPadding(
1296
+ tokenizer=tokenizer,
1297
+ target_padding="max_length",
1298
+ max_target_length=max_label_length,
1299
+ completions_only=training_args.completions_only,
1300
+ )
1301
+
1302
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1303
+ # so that we can still access the configs
1304
+ num_beams = (
1305
+ training_args.generation_num_beams
1306
+ if training_args.generation_num_beams is not None
1307
+ else getattr(student_model.generation_config, "num_beams", 1)
1308
+ )
1309
+
1310
+ # 15. Prepare everything with accelerate
1311
+ student_model, optimizer, lr_scheduler = accelerator.prepare(student_model, optimizer, lr_scheduler)
1312
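+ # the teacher has no trainable parameters, so prepare() is mainly used for device placement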
+ teacher_model = accelerator.prepare(teacher_model) if teacher_model else None
1313
+
1314
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1315
+ kl_loss = nn.KLDivLoss(reduction="none")
1316
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1317
+ # ignore padded tokens from divergence, i.e. where labels are not set to -100
1318
+ padding_mask = labels >= 0
1319
+ padding_mask = padding_mask.unsqueeze(-1)
1320
+ divergence = divergence * padding_mask
1321
+ # sum the KL over the vocabulary dimension and average over the non-padded tokens
1322
+ divergence = divergence.sum() / padding_mask.sum()
1323
+ return divergence
1324
+
1325
+ # Define gradient update step fn
1326
+ def train_step(batch):
1327
+ student_model.train()
1328
+ student_outputs = student_model(**batch)
1329
+
1330
+ # CE (data) loss
1331
+ ce_loss = student_outputs.loss
1332
+ metrics = {"ce_loss": ce_loss}
1333
+
1334
+ if teacher_model:
1335
+ with torch.no_grad():
1336
+ teacher_outputs = teacher_model(**batch)
1337
+ # rescale distribution by temperature to ensure gradients scale correctly
1338
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1339
+ # log softmax of student predictions for numerical stability
1340
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1341
+ # KL-divergence loss (scaled by temperature**2, as in standard knowledge distillation)
1342
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature ** 2
1343
+ # use the Distil-Whisper formulation (fixed weight on the CE loss, tunable KL weight)
1344
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1345
+ metrics["kl_loss"] = kl_loss
1346
+ else:
1347
+ loss = ce_loss
1348
+
1349
+ metrics["loss"] = loss
1350
+ return loss, metrics
1351
+
1352
+ # Define eval fn
1353
+ @torch.no_grad()
1354
+ def eval_step(batch):
1355
+ student_model.eval()
1356
+
1357
+ # CE (data) loss
1358
+ student_outputs = student_model(**batch)
1359
+ ce_loss = student_outputs.loss
1360
+ metrics = {"ce_loss": ce_loss}
1361
+
1362
+ if teacher_model:
1363
+ teacher_outputs = teacher_model(**batch)
1364
+ # log softmax / softmax for numerical stability
1365
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1366
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1367
+ # temperature is always 1 for eval
1368
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1369
+ # use the Distil-Whisper formulation (fixed weight on the CE loss, tunable KL weight)
1370
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1371
+ metrics["kl_loss"] = kl_loss
1372
+ else:
1373
+ loss = ce_loss
1374
+
1375
+ metrics["loss"] = loss
1376
+ return metrics
1377
+
1378
+ def generate_step(batch):
1379
+ output_ids = accelerator.unwrap_model(student_model).generate(
1380
+ **batch, max_length=max_label_length, num_beams=num_beams
1381
+ )
1382
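+ # pad generations across processes so they can be gathered to a common sequence length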
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1383
+ return output_ids
1384
+
1385
+ logger.info("***** Running training *****")
1386
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1387
+ if not data_args.streaming:
1388
+ logger.info(f" Num epochs = {num_epochs}")
1389
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1390
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1391
+ logger.info(
1392
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1393
+ )
1394
+ logger.info(f" Total optimization steps = {total_train_steps}")
1395
+
1396
+ # ======================== Training ================================
1397
+ train_time = 0
1398
+ train_start = time.time()
1399
+ steps_trained_progress_bar = tqdm(
1400
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1401
+ )
1402
+ continue_training = True
1403
+ epochs_trained = 0
1404
+ cur_step = 0
1405
+
1406
+ checkpoint = None
1407
+ if training_args.resume_from_checkpoint is not None:
1408
+ checkpoint = training_args.resume_from_checkpoint
1409
+ elif last_checkpoint is not None:
1410
+ checkpoint = last_checkpoint
1411
+
1412
+ if checkpoint is not None:
1413
+ accelerator.load_state(checkpoint)
1414
+ # Find num steps and epoch from saved state string pattern
1415
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1416
+ match = re.search(pattern, checkpoint)
1417
+ cur_step = int(match.group(1))
1418
+ epochs_trained = int(match.group(2))
1419
+
1420
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1421
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1422
+ logger.info(f" Continuing training from global step {cur_step}")
1423
+
1424
+ steps_trained_progress_bar.update(cur_step)
1425
+
1426
+ for epoch in range(0, epochs_trained):
1427
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1428
+
1429
+ if not data_args.streaming and training_args.max_steps < 0:
1430
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1431
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1432
+ else:
1433
+ # Currently we don't know how many steps we've taken in the current epoch
1434
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1435
+ # This is "good enough" for our purposes but not fully correct
1436
+ resume_step = None
1437
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1438
+ else:
1439
+ resume_step = None
1440
+
1441
+ for epoch in range(epochs_trained, num_epochs):
1442
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1443
+ train_dataloader = DataLoader(
1444
+ vectorized_datasets["train"],
1445
+ collate_fn=data_collator,
1446
+ batch_size=per_device_train_batch_size,
1447
+ num_workers=dataloader_num_workers,
1448
+ prefetch_factor=prefetch_factor,
1449
+ pin_memory=training_args.dataloader_pin_memory,
1450
+ )
1451
+ train_dataloader = accelerator.prepare(train_dataloader)
1452
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1453
+ train_dataloader.dataset.set_epoch(epoch)
1454
+
1455
+ if resume_step is not None:
1456
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1457
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1458
+ resume_step = None
1459
+
1460
+ for batch in train_dataloader:
1461
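+ # under accumulate(), gradients are only synchronized once every gradient_accumulation_steps batches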
+ with accelerator.accumulate(student_model):
1462
+ loss, train_metric = train_step(batch)
1463
+ accelerator.backward(loss)
1464
+ if accelerator.sync_gradients:
1465
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1466
+ optimizer.step()
1467
+ lr_scheduler.step()
1468
+ optimizer.zero_grad()
1469
+
1470
+ # Check if the accelerator has performed an optimization step behind the scenes
1471
+ if accelerator.sync_gradients:
1472
+ steps_trained_progress_bar.update(1)
1473
+ cur_step += 1
1474
+
1475
+ if cur_step % training_args.logging_steps == 0:
1476
+ steps_trained_progress_bar.write(
1477
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1478
+ f" {train_metric['loss']}, Learning Rate:"
1479
+ f" {lr_scheduler.get_last_lr()[0]})"
1480
+ )
1481
+ log_metric(
1482
+ accelerator,
1483
+ metrics=train_metric,
1484
+ learning_rate=lr_scheduler.get_last_lr()[0],
1485
+ train_time=train_time + time.time() - train_start,
1486
+ step=cur_step,
1487
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1488
+ prefix="train",
1489
+ )
1490
+
1491
+ # save checkpoint and weights after each save_steps and at the end of training
1492
+ if (cur_step % save_steps == 0) or cur_step == total_train_steps:
1493
+ accelerator.wait_for_everyone()
1494
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1495
+ accelerator.save_state(output_dir=intermediate_dir)
1496
+ unwrapped_model = accelerator.unwrap_model(student_model)
1497
+ unwrapped_model.save_pretrained(
1498
+ intermediate_dir,
1499
+ is_main_process=accelerator.is_main_process,
1500
+ save_function=accelerator.save,
1501
+ )
1502
+ if accelerator.is_main_process:
1503
+ checkpoint_to_be_deleted = rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1504
+ if training_args.push_to_hub:
1505
+ upload_folder(
1506
+ folder_path=training_args.output_dir,
1507
+ repo_id=repo_name,
1508
+ repo_type="model",
1509
+ commit_message=f"Saving train state of step {cur_step}",
1510
+ delete_patterns=checkpoint_to_be_deleted,
1511
+ )
1512
+
1513
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1514
+ train_time += time.time() - train_start
1515
+ student_model.eval()
1516
+ # ======================== Evaluating ==============================
1517
+ for eval_split in all_eval_splits:
1518
+ eval_metrics = []
1519
+ eval_preds = []
1520
+ eval_labels = []
1521
+ eval_start = time.time()
1522
+
1523
+ validation_dataloader = DataLoader(
1524
+ vectorized_datasets[eval_split],
1525
+ collate_fn=data_collator,
1526
+ batch_size=per_device_eval_batch_size,
1527
+ drop_last=False,
1528
+ num_workers=dataloader_num_workers,
1529
+ prefetch_factor=prefetch_factor,
1530
+ pin_memory=training_args.dataloader_pin_memory,
1531
+ )
1532
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1533
+
1534
+ for batch in tqdm(
1535
+ validation_dataloader,
1536
+ desc=f"Evaluating {eval_split}...",
1537
+ position=2,
1538
+ disable=not accelerator.is_local_main_process,
1539
+ ):
1540
+ # Model forward
1541
+ eval_metric = eval_step(batch)
1542
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1543
+ eval_metrics.append(eval_metric)
1544
+
1545
+ # generation
1546
+ if training_args.predict_with_generate:
1547
+ generated_ids = generate_step(batch)
1548
+ # Gather all predictions and targets
1549
+ generated_ids, labels = accelerator.gather_for_metrics(
1550
+ (generated_ids, batch["labels"])
1551
+ )
1552
+ eval_preds.extend(generated_ids)
1553
+ eval_labels.extend(labels)
1554
+
1555
+ eval_time = time.time() - eval_start
1556
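+ # gather_for_metrics returns 0-d tensors on a single process (stack) and 1-d gathered tensors on multiple processes (concatenate)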
+ stack = torch.stack if accelerator.num_processes == 1 else torch.concatenate
1557
+ # normalize eval metrics
1558
+ eval_metrics = {
1559
+ key: torch.mean(stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1560
+ }
1561
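+ # perplexity = exp(cross-entropy); the exponential can overflow while the loss is still large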
+ try:
1562
+ eval_metrics["perplexity"] = math.exp(eval_metrics["ce_loss"])
1563
+ except OverflowError:
1564
+ eval_metrics["perplexity"] = float("inf")
1565
+
1566
+ if training_args.predict_with_generate:
1567
+ pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1568
+ log_pred(
1569
+ accelerator,
1570
+ pred_str,
1571
+ label_str,
1572
+ step=cur_step,
1573
+ epoch=epoch,
1574
+ evaluation_strategy=training_args.evaluation_strategy,
1575
+ prefix=eval_split,
1576
+ )
1577
+
1578
+ # Print metrics and update progress bar
1579
+ logger_desc = " ".join([f"Eval {key}: {value} |" for key, value in eval_metrics.items()])
1580
+ steps_trained_progress_bar.write(
1581
+ f"Eval results for step ({cur_step} / {total_train_steps} | {logger_desc}"
1582
+ )
1583
+
1584
+ log_metric(
1585
+ accelerator,
1586
+ metrics=eval_metrics,
1587
+ train_time=eval_time,
1588
+ step=cur_step,
1589
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1590
+ prefix=eval_split,
1591
+ )
1592
+
1593
+ # flush the train metrics
1594
+ train_start = time.time()
1595
+
1596
+ # break condition
1597
+ if cur_step == total_train_steps:
1598
+ accelerator.wait_for_everyone()
1599
+ # un-wrap student model for save
1600
+ student_model = accelerator.unwrap_model(student_model)
1601
+ student_model.save_pretrained(
1602
+ training_args.output_dir,
1603
+ is_main_process=accelerator.is_main_process,
1604
+ save_function=accelerator.save,
1605
+ )
1606
+ if training_args.push_to_hub and accelerator.is_main_process:
1607
+ upload_folder(
1608
+ folder_path=training_args.output_dir,
1609
+ repo_id=repo_name,
1610
+ repo_type="model",
1611
+ commit_message=f"Saving final weights of step {cur_step}",
1612
+ )
1613
+ continue_training = False
1614
+ break
1615
+
1616
+ if not continue_training:
1617
+ break
1618
+
1619
+ accelerator.end_training()
1620
+
1621
+
1622
+ if __name__ == "__main__":
1623
+ main()
slurm_job.slurm ADDED
@@ -0,0 +1,75 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=distil-mistral
3
+ #SBATCH --nodes=1
4
+ # set 48h for job wall time limit
5
+ #SBATCH --time=48:00:00
6
+ #SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for distributed training!
7
+ #SBATCH --cpus-per-task=32
8
+ #SBATCH --gres=gpu:8
9
+ #SBATCH --exclusive
10
+ #SBATCH --partition=hopper-prod
11
+ #SBATCH --output=/fsx/sanchit/logs/%x-%j.out
12
+
13
+ set -x -e
14
+
15
+ # START EDIT
16
+ source ~/.bashrc
17
+ source /fsx/sanchit/miniconda3/bin/activate venv
18
+
19
+ LOG_PATH="/fsx/sanchit/logs/main_log.txt"
20
+ SAVE_DIR="/fsx/sanchit"
21
+ # END EDIT
22
+
23
+ echo "START TIME: $(date)"
24
+
25
+ GPUS_PER_NODE=8
26
+ NNODES=$SLURM_NNODES
27
+
28
+ # so processes know who to talk to
29
+ MASTER_ADDR=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1`
30
+
31
+ # From https://i.hsfzxjy.site/2021-03-10-obtain-a-random-unused-tcp-port-with-bash/
32
+ function unused_port() {
33
+ N=${1:-1}
34
+ comm -23 \
35
+ <(seq "1025" "65535" | sort) \
36
+ <(ss -Htan |
37
+ awk '{print $4}' |
38
+ cut -d':' -f2 |
39
+ sort -u) |
40
+ shuf |
41
+ head -n "$N"
42
+ }
43
+ MASTER_PORT=$(unused_port)
44
+
45
+ # export TORCH_CPP_LOG_LEVEL=INFO
46
+ # export TORCH_DISTRIBUTED_DEBUG=DETAIL
47
+
48
+ export LAUNCHER="python -u -m accelerate.commands.launch --config_file ./accelerate_config.yaml"
49
+
50
+ export PROGRAM="./run_distillation.py ./config_mistral.yaml"
51
+ export CMD="$LAUNCHER $PROGRAM"
52
+ echo $CMD
53
+
54
+ SRUN_ARGS=" \
55
+ --wait=60 \
56
+ --kill-on-bad-exit=1 \
57
+ "
58
+
59
+ # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
60
+ clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
61
+
62
+
63
+ # srun error handling:
64
+ # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
65
+ # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
66
+
67
+ # SRUN_ARGS=" \
68
+ # --wait=60 \
69
+ # --kill-on-bad-exit=1 \
70
+ # "
71
+ #
72
+ # # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
73
+ # clear; srun $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
74
+
75
+ echo "END TIME: $(date)"
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "legacy": true,
36
+ "model_max_length": 1000000000000000019884624838656,
37
+ "pad_token": "</s>",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }