sanchit-gandhi (HF staff) committed
Commit f78d5d9
1 Parent(s): 37b9647

Saving train state of step 5000
.gitignore ADDED
@@ -0,0 +1 @@
+ wandb
accelerate_config.yaml ADDED
@@ -0,0 +1,18 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
+
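This Accelerate config describes a single node with 8 GPUs and bf16 mixed precision. A minimal sketch of how such a config is typically consumed (the launch command and file names are assumed from the files in this commit, not quoted from it):

# Assumed launch command:
#   accelerate launch --config_file accelerate_config.yaml run_distillation.py config_mistral.yaml
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")  # matches mixed_precision: bf16 above
print(accelerator.num_processes, accelerator.process_index, accelerator.device)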
checkpoint-5000-epoch-0/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.40.0.dev0",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
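The student config above specifies a 6-layer Mistral decoder with hidden size 4096, grouped-query attention with 8 key/value heads, and untied embeddings. A rough parameter count can be read straight off these values (an illustrative estimate that ignores norm weights):

# Back-of-the-envelope parameter count from the config above (norms ignored).
hidden, inter, layers, vocab = 4096, 14336, 6, 32000
kv_dim = (hidden // 32) * 8                         # head_dim (128) * num_key_value_heads (8)
attn = 2 * hidden * hidden + 2 * hidden * kv_dim    # q/o projections + k/v projections
mlp = 3 * hidden * inter                            # gate/up/down projections
emb = 2 * vocab * hidden                            # untied input and output embeddings
total = layers * (attn + mlp) + emb
print(f"~{total / 1e9:.2f}B parameters")            # ~1.57B, consistent with the 1.5B model name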
checkpoint-5000-epoch-0/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 2048,
+ "transformers_version": "4.40.0.dev0"
+ }
checkpoint-5000-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a60136dacab3e895e45dacfc3dff9da7cf34bebacf0b830538268759b9c9b146
+ size 3141646744
checkpoint-5000-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65173ac419081e25f0d5f93ed77393cf05f5158325ee154a5cbb3e14b47ece07
+ size 4450837792
checkpoint-5000-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:123c80cdcdad2ee9018734d53265219075b77629c741e8f6e7f6581c7d6ed148
+ size 6283329590
checkpoint-5000-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aca503cc09e63ca033e29a437a20cc580a9c1db27fef2174e533f58ba275879
+ size 16100
checkpoint-5000-epoch-0/random_states_1.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31831c2134536b1e81ba1e763e72b2ff98a14a83774fcfb30d153a66dca7879c
+ size 16100
checkpoint-5000-epoch-0/random_states_2.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a628258539b4090ce50e9faf5fda4d613f523ca957f3e837c02d316e4b20122
+ size 16100
checkpoint-5000-epoch-0/random_states_3.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d594aa54f68e8eb41c3deb9753bf43474028f44edb92db1930ebdf967f708a7c
+ size 16100
checkpoint-5000-epoch-0/random_states_4.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28ca4240374ff4b93ad0537aca2f28bfc293153a29ee8069cf09d088ca30fee7
+ size 16100
checkpoint-5000-epoch-0/random_states_5.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d6f3577977e8c32eac49b1c5136c6718fcd9c66051b703ba6e305cca03a8fb0
+ size 16100
checkpoint-5000-epoch-0/random_states_6.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0ef1d86e60e6cedda41454cd08e0b3652ab6a6eb017b4eed0d6b84866ed7d46
+ size 16100
checkpoint-5000-epoch-0/random_states_7.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08d860c07ef8d57c8162394106fcd87c34e7924d859b28b4b292e9e792a96af2
+ size 16100
checkpoint-5000-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d0100621f3bae95fccd28713bef2f7c347b2970b15af6bcf2c067948dea8722
+ size 1064
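The files above are Git LFS pointers, so only the object hash and byte size live in the commit itself. The reported model.safetensors size lines up with the ~1.57B-parameter student stored in bfloat16 (a hedged check, assuming 2 bytes per parameter plus a small safetensors header):

params = 1_570_766_848        # estimate from the config-based count above
print(params * 2)             # ~3.14e9 bytes, close to the reported size 3141646744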
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_name_or_path": "sanchit-gandhi/Mistral-1.5B-Instruct-v0.2",
+ "architectures": [
+ "MistralForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 6,
+ "num_key_value_heads": 8,
+ "output_router_logits": true,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.40.0.dev0",
+ "use_cache": true,
+ "vocab_size": 32000
+ }
config_mistral.yaml ADDED
@@ -0,0 +1,47 @@
+ # Model arguments
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ teacher_model_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ tokenizer_name: mistralai/Mistral-7B-Instruct-v0.2
+ dtype: bfloat16
+ load_teacher_in_4bit: true
+
+ # Data arguments
+ train_dataset_name: HuggingFaceTB/cosmopedia
+ train_dataset_config_name:
+ - auto_math_text
+ - khanacademy
+ - openstax
+ - stanford
+ - stories
+ - web_samples_v1
+ - web_samples_v2
+ - wikihow
+ train_split_name: train[1000:]
+ eval_split_name: train[:1000]
+ max_steps: 100000
+ max_train_samples: 10000000
+
+ # Training arguments
+ do_train: true
+ do_eval: true
+ per_device_eval_batch_size: 8
+ per_device_train_batch_size: 8
+ learning_rate: 0.0001
+ warmup_steps: 500
+ gradient_checkpointing: true
+ dataloader_num_workers: 4
+ preprocessing_num_workers: 32
+ ddp_timeout: 7200
+ save_strategy: steps
+ save_steps: 5000
+ evaluation_strategy: steps
+ eval_steps: 5000
+ logging_steps: 25
+ output_router_logits: true
+ report_to: all
+ output_dir: ./
+ overwrite_output_dir: false
+ save_total_limit: 1
+ wandb_project: distil-mistral
+ push_to_hub: true
+
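run_distillation.py (added below) accepts a single YAML path and maps every key in this file onto one of its three argument dataclasses. A hedged sketch of that parsing step, assuming the script is importable from the repo root and its dependencies are installed:

from transformers import HfArgumentParser
from run_distillation import DataTrainingArguments, DistillationTrainingArguments, ModelArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
model_args, data_args, training_args = parser.parse_yaml_file("config_mistral.yaml")
print(model_args.teacher_model_name_or_path)  # mistralai/Mistral-7B-Instruct-v0.2
print(training_args.save_steps)               # 5000, the interval at which this checkpoint was saved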
distil-mistral/1713805949.3655071/events.out.tfevents.1713805949.ip-26-0-163-236.1916906.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9450d9de35b9a73e76d8b6c91c6e007b447e0e09e14cbce3ad5b663c323783e
+ size 1160
distil-mistral/1713805949.3694406/hparams.yml ADDED
@@ -0,0 +1,16 @@
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ global_batch_size: 64
+ gradient_accumulation_steps: 1
+ learning_rate: 0.0001
+ lr_scheduler_type: !!python/object/apply:transformers.trainer_utils.SchedulerType
+ - linear
+ max_steps: 100000
+ mixed_precision: bf16
+ model_name_or_path: sanchit-gandhi/Mistral-1.5B-Instruct-v0.2
+ num_train_epochs: 3.0
+ per_device_train_batch_size: 8
+ teacher_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
+ temperature: 2.0
+ warmup_steps: 500
+ weight_decay: 0.0
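The logged global_batch_size of 64 follows directly from the Accelerate config and the hyper-parameters above (a simple consistency check, using num_processes from accelerate_config.yaml):

# global_batch_size = num_processes * per_device_train_batch_size * gradient_accumulation_steps
num_processes, per_device_bs, grad_accum = 8, 8, 1
assert num_processes * per_device_bs * grad_accum == 64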
distil-mistral/events.out.tfevents.1713805939.ip-26-0-163-236.1916906.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74335b885f306b8613f0b40fc477165aca5c52f3da1a97a2c482312243274aa5
+ size 62058
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "max_length": 2048,
+ "transformers_version": "4.40.0.dev0"
+ }
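The config.json and generation_config.json at the repo root describe the exported student itself (note torch_dtype float32 here versus bfloat16 inside the checkpoint folder). A hypothetical loading sketch once weights are pushed; the repo id is a placeholder for the repository this commit belongs to, not a name taken from the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"  # placeholder for the Hub repo of this commit
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16)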
run_distillation.py ADDED
@@ -0,0 +1,1528 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training language models for conditional language modelling tasks via teacher-student distillation.
18
+ """
19
+ # You can also adapt this script for your own distillation tasks. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import shutil
26
+ import sys
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from functools import partial
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Union
32
+
33
+ import datasets
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from datasets import (
41
+ Dataset,
42
+ DatasetDict,
43
+ IterableDataset,
44
+ IterableDatasetDict,
45
+ concatenate_datasets,
46
+ interleave_datasets,
47
+ load_dataset,
48
+ )
49
+ from huggingface_hub import create_repo, get_full_repo_name, upload_folder
50
+ from peft import LoraConfig, get_peft_model
51
+ from torch.utils.data import DataLoader
52
+ from tqdm import tqdm
53
+ from transformers import (
54
+ AutoConfig,
55
+ AutoModelForCausalLM,
56
+ AutoTokenizer,
57
+ BatchEncoding,
58
+ BitsAndBytesConfig,
59
+ HfArgumentParser,
60
+ PreTrainedTokenizerBase,
61
+ Seq2SeqTrainingArguments,
62
+ get_scheduler,
63
+ set_seed,
64
+ )
65
+ from transformers.utils import check_min_version
66
+ from transformers.utils.versions import require_version
67
+
68
+
69
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
70
+ check_min_version("4.34.0.dev0")
71
+
72
+ require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to distill from.
81
+ """
82
+
83
+ model_name_or_path: str = field(
84
+ metadata={"help": "Path to pretrained student model or model identifier from huggingface.co/models"}
85
+ )
86
+ teacher_model_name_or_path: str = field(
87
+ metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
88
+ )
89
+ config_name: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
96
+ )
97
+ cache_dir: Optional[str] = field(
98
+ default=None,
99
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
100
+ )
101
+ use_fast_tokenizer: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
+ )
105
+ model_revision: str = field(
106
+ default="main",
107
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
108
+ )
109
+ subfolder: str = field(
110
+ default="",
111
+ metadata={
112
+ "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can "
113
+ "specify the folder name here."
114
+ },
115
+ )
116
+ token: str = field(
117
+ default=None,
118
+ metadata={
119
+ "help": (
120
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
121
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
122
+ )
123
+ },
124
+ )
125
+ attn_implementation: Optional[str] = field(
126
+ default=None,
127
+ metadata={
128
+ "help": (
129
+ "Which attention implementation to use in the encoder and decoder attention layers. Can be one of:\n"
130
+ "1. `eager` or `None`: default Transformers attention implementation.\n"
131
+ "2. `sdpa`: Flash Attention through PyTorch SDPA. Requires `torch>=2.1`. Recommended for hardware where Flash Attention 2 is not supported, e.g. Turing GPUs, (T4, RTX 2080).\n"
132
+ "3. `flash_attention_2`: Flash Attention 2 through the Flash Attention package https://github.com/Dao-AILab/flash-attention. **Always** recommended on supported hardware (Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100)."
133
+ )
134
+ },
135
+ )
136
+ load_teacher_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the teacher model."})
137
+ load_teacher_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the teacher model."})
138
+ load_student_in_8bit: bool = field(default=False, metadata={"help": "Use 8 bit precision for the student model."})
139
+ load_student_in_4bit: bool = field(default=False, metadata={"help": "Use 4 bit precision for the student model."})
140
+ bnb_4bit_quant_type: Optional[str] = field(
141
+ default="nf4", metadata={"help": "Quantization type if the teacher is quantized (fp4 or nf4)"}
142
+ )
143
+ use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether or not to use nested quantization."})
144
+ lora_r: Optional[int] = field(
145
+ default=16,
146
+ metadata={"help": "LoRA R value."},
147
+ )
148
+ lora_alpha: Optional[int] = field(
149
+ default=32,
150
+ metadata={"help": "LoRA alpha."},
151
+ )
152
+ lora_dropout: Optional[float] = field(
153
+ default=0.05,
154
+ metadata={"help": "LoRA dropout."},
155
+ )
156
+ lora_target_modules: Optional[List[str]] = field(
157
+ default=None,
158
+ metadata={"help": "LoRA target modules."},
159
+ )
160
+ lora_modules_to_save: Optional[List[str]] = field(
161
+ default=None,
162
+ metadata={"help": "Model layers to unfreeze & train"},
163
+ )
164
+ instruction_model: Optional[bool] = field(
165
+ default=None,
166
+ metadata={"help": "Whether or not the pre-trained model is instruction tuned"},
167
+ )
168
+
169
+
170
+ @dataclass
171
+ class DataTrainingArguments:
172
+ """
173
+ Arguments pertaining to what data we are going to input our model for training and eval.
174
+ """
175
+
176
+ train_dataset_name: List[str] = field(
177
+ default=None,
178
+ metadata={
179
+ "help": "The name of the training dataset to use (via the datasets library). Load and combine "
180
+ "multiple datasets by separating dataset ids by a '+' symbol. For example, to load LibriSpeech "
181
+ "and Common Voice, set `train_dataset_name='librispeech_asr+common_voice'`."
182
+ },
183
+ )
184
+ train_dataset_config_name: Optional[List[str]] = field(
185
+ default=None,
186
+ metadata={
187
+ "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
188
+ "multiple datasets by separating dataset configs by a '+' symbol. Note that the order of the configs should "
189
+ "match the order of the datasets."
190
+ },
191
+ )
192
+ train_dataset_samples: Optional[List[str]] = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "Number of samples in each dataset when loading multiple datasets with streaming mode. "
196
+ "Not required when using one dataset or non-streaming mode. The sample values provide the sampling "
197
+ "probability for each dataset. Setting them equal to the number of sample values ensures that every "
198
+ "sample from every dataset is used once per epoch."
199
+ },
200
+ )
201
+ eval_dataset_name: Optional[List[str]] = field(
202
+ default=None,
203
+ metadata={
204
+ "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
205
+ "dataset name if unspecified. Load multiple evaluation datasets by separating dataset "
206
+ "ids by a '+' symbol."
207
+ },
208
+ )
209
+ eval_dataset_config_name: Optional[List[str]] = field(
210
+ default=None,
211
+ metadata={
212
+ "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the "
213
+ "training dataset config name if unspecified."
214
+ },
215
+ )
216
+ dataset_cache_dir: Optional[str] = field(
217
+ default=None,
218
+ metadata={"help": "Path to cache directory for saving and loading datasets"},
219
+ )
220
+ overwrite_cache: bool = field(
221
+ default=False,
222
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
223
+ )
224
+ preprocessing_num_workers: Optional[int] = field(
225
+ default=None,
226
+ metadata={"help": "The number of processes to use for the preprocessing if using non-streaming mode."},
227
+ )
228
+ max_train_samples: Optional[int] = field(
229
+ default=None,
230
+ metadata={
231
+ "help": (
232
+ "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
233
+ )
234
+ },
235
+ )
236
+ max_eval_samples: Optional[int] = field(
237
+ default=None,
238
+ metadata={
239
+ "help": (
240
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
241
+ )
242
+ },
243
+ )
244
+ text_column_name: str = field(
245
+ default=None,
246
+ metadata={"help": "The name of the dataset column containing the generated text data in the training set."},
247
+ )
248
+ prompt_column_name: str = field(
249
+ default=None,
250
+ metadata={"help": "The name of the dataset column containing the prompt data. Defaults to 'prompt'"},
251
+ )
252
+ eval_text_column_name: str = field(
253
+ default=None,
254
+ metadata={"help": "The name of the dataset column containing the generated text data in the evaluation set."},
255
+ )
256
+ eval_prompt_column_name: str = field(
257
+ default=None,
258
+ metadata={"help": "The name of the dataset column containing the prompt data in the evaluation set."},
259
+ )
260
+ max_label_length: int = field(
261
+ default=2048,
262
+ metadata={"help": "Truncate target labels that are longer `max_label_length` tokens."},
263
+ )
264
+ pad_target_to_multiple_of: Optional[int] = field(
265
+ default=None,
266
+ metadata={
267
+ "help": (
268
+ "If set will pad the target sequence to a multiple of the provided value. This is important to "
269
+ "avoid triggering recompilations when using torch compile. If unspecified, will default to padding "
270
+ "the targets to max length."
271
+ )
272
+ },
273
+ )
274
+ preprocessing_only: bool = field(
275
+ default=False,
276
+ metadata={
277
+ "help": (
278
+ "Whether to only do data preprocessing and skip training. This is especially useful when data "
279
+ "preprocessing errors out in distributed training due to timeout. In this case, one should run the "
280
+ "preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets "
281
+ "can subsequently be loaded in distributed training"
282
+ )
283
+ },
284
+ )
285
+ train_split_name: Optional[List[str]] = field(
286
+ default=lambda: ["train"],
287
+ metadata={
288
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
289
+ },
290
+ )
291
+ eval_split_name: Optional[List[str]] = field(
292
+ default=lambda: ["validation"],
293
+ metadata={
294
+ "help": (
295
+ "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
296
+ )
297
+ },
298
+ )
299
+ streaming: bool = field(
300
+ default=False,
301
+ metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
302
+ )
303
+ wandb_project: str = field(
304
+ default="distil-mistral",
305
+ metadata={"help": "The name of the wandb project."},
306
+ )
307
+
308
+
309
+ @dataclass
310
+ class DistillationTrainingArguments(Seq2SeqTrainingArguments):
311
+ freeze_lm_head: Optional[bool] = field(
312
+ default=False, metadata={"help": "Whether to freeze the LM head of the student model."}
313
+ )
314
+ temperature: Optional[float] = field(
315
+ default=2.0, metadata={"help": "Temperature to anneal the logits when computing the softmax."}
316
+ )
317
+ kl_weight: Optional[float] = field(
318
+ default=1.0,
319
+ metadata={
320
+ "help": (
321
+ "Weighting assigned to the KL divergence loss in the KD formulation. The KL loss is "
322
+ "computed between the temperature-annealed teacher and student distributions."
323
+ )
324
+ },
325
+ )
326
+ output_router_logits: bool = field(
327
+ default=False,
328
+ metadata={
329
+ "help": "Whether or not to return the router logits in the forward pass. Enabling this will "
330
+ "also configure the model to compute the auxiliary loss."
331
+ },
332
+ )
333
+ dtype: Optional[str] = field(
334
+ default="float32",
335
+ metadata={
336
+ "help": (
337
+ "The data type (dtype) in which to run training. One of `float32` (full-precision), "
338
+ "`float16` or `bfloat16` (both half-precision)."
339
+ )
340
+ },
341
+ )
342
+
343
+
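For orientation, the temperature and kl_weight fields above parameterize a temperature-scaled KL distillation term. The snippet below is a generic illustration of such a loss, not the implementation used by this script (the actual loss computation lies beyond the portion of the file shown here):

import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, labels, temperature=2.0, kl_weight=1.0):
    # Cross-entropy on the labels plus temperature-annealed KL between the two distributions.
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1), ignore_index=-100)
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * temperature**2
    return ce + kl_weight * kl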
344
+ @dataclass
345
+ class DataCollatorCausalLMWithPadding:
346
+ """
347
+ Data collator that will dynamically pad the inputs received.
348
+ Args:
349
+ tokenizer ([`PreTrainedTokenizer`])
350
+ The tokenizer used for tokenizing the data.
351
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
352
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
353
+ See above for details.
354
+ max_target_length (:obj:`int`, `optional`):
355
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
356
+ """
357
+
358
+ tokenizer: PreTrainedTokenizerBase
359
+ target_padding: Union[bool, str] = "max_length"
360
+ max_target_length: Optional[int] = None
361
+
362
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> BatchEncoding:
363
+ # dataloader returns a list of features which we convert to a dict
364
+ label_features = {"input_ids": [feature["labels"] for feature in features]}
365
+ label_lengths = [len(feature["labels"]) for feature in features]
366
+ prompt_lengths = [feature["prompt_length"] for feature in features]
367
+
368
+ batch = self.tokenizer.pad(
369
+ label_features,
370
+ max_length=self.max_target_length,
371
+ padding=self.target_padding,
372
+ return_tensors="pt",
373
+ )
374
+
375
+ labels_mask = batch["attention_mask"]
376
+
377
+ # don't include prompts in loss calculation
378
+ for idx in range(len(prompt_lengths)):
379
+ padding_length = labels_mask.shape[1] - label_lengths[idx]
380
+ labels_mask[idx, padding_length : padding_length + prompt_lengths[idx]] = 0
381
+
382
+ # replace padding with -100 to ignore loss correctly
383
+ labels = batch["input_ids"].masked_fill(labels_mask.ne(1), -100)
384
+
385
+ batch["labels"] = labels
386
+
387
+ return batch
388
+
389
+
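Illustrative, standalone usage of the collator above (the tokenizer name is only an example; the index arithmetic in __call__ assumes pad tokens are placed before the sequence, i.e. a left-padding tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tok.pad_token = tok.eos_token
collator = DataCollatorCausalLMWithPadding(tokenizer=tok, target_padding="longest")
features = [{"labels": tok("Question? Answer.").input_ids, "prompt_length": 3}]
batch = collator(features)
# batch["labels"] equals batch["input_ids"] with the first 3 (prompt) positions and any padding set to -100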
390
+ def log_metric(
391
+ accelerator,
392
+ metrics: Dict,
393
+ train_time: float,
394
+ step: int,
395
+ epoch: int,
396
+ learning_rate: float = None,
397
+ prefix: str = "train",
398
+ ):
399
+ """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
400
+ log_metrics = {}
401
+ for k, v in metrics.items():
402
+ log_metrics[f"{prefix}/{k}"] = v
403
+ log_metrics[f"{prefix}/time"] = train_time
404
+ log_metrics[f"{prefix}/epoch"] = epoch
405
+ if learning_rate is not None:
406
+ log_metrics[f"{prefix}/learning_rate"] = learning_rate
407
+ accelerator.log(log_metrics, step=step)
408
+
409
+
410
+ def log_pred(
411
+ accelerator,
412
+ pred_str: List[str],
413
+ label_str: List[str],
414
+ step: int,
415
+ epoch: int,
416
+ evaluation_strategy: str,
417
+ prefix: str = "eval",
418
+ num_lines: int = 200000,
419
+ ):
420
+ """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
421
+ if accelerator.is_main_process:
422
+ wandb_tracker = accelerator.get_tracker("wandb")
423
+ # pretty name for current step: step 50000 -> step 50k
424
+ cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
425
+ prefix_pretty = prefix.replace("/", "-")
426
+
427
+ if evaluation_strategy == "epoch":
428
+ table_name = f"predictions/{prefix_pretty}-epoch-{epoch}"
429
+ else:
430
+ table_name = f"predictions/{prefix_pretty}-step-{cur_step_pretty}"
431
+
432
+ # convert str data to a wandb compatible format
433
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
434
+ # log as a table with the appropriate headers
435
+ wandb_tracker.log_table(
436
+ table_name=table_name,
437
+ columns=["Target", "Pred"],
438
+ data=str_data[:num_lines],
439
+ step=step,
440
+ )
441
+
442
+
443
+ def convert_dataset_str_to_list(
444
+ dataset_names,
445
+ dataset_config_names,
446
+ splits=None,
447
+ text_column_names=None,
448
+ prompt_column_names=None,
449
+ dataset_samples=None,
450
+ default_split="train",
451
+ ) -> List[Dict]:
452
+ """
453
+ Given three lists of dataset names, configs and splits, this function groups the corresponding
454
+ names/configs/splits. Each dataset is assigned a unique dictionary with these metadata values, and the
455
+ function returns a list of dictionaries, one for each dataset.
456
+ """
457
+ if isinstance(dataset_names, str):
458
+ dataset_names = [dataset_names]
459
+ splits = [splits] if splits else None
460
+ text_column_names = [text_column_names] if text_column_names else None
461
+ prompt_column_names = [prompt_column_names] if prompt_column_names else None
462
+ if isinstance(dataset_config_names, str):
463
+ dataset_config_names = [dataset_config_names]
464
+
465
+ if len(dataset_names) == 1 and len(dataset_config_names) > 1:
466
+ dataset_names = len(dataset_config_names) * dataset_names
467
+
468
+ if isinstance(splits, list) and len(splits) == 1 and len(dataset_config_names) > 1:
469
+ splits = len(dataset_config_names) * splits
470
+
471
+ # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
472
+ if dataset_config_names is not None and len(dataset_names) != len(dataset_config_names):
473
+ raise ValueError(
474
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
475
+ f" {len(dataset_config_names)} configs."
476
+ )
477
+
478
+ if splits is not None and len(splits) != len(dataset_names):
479
+ raise ValueError(
480
+ f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
481
+ )
482
+
483
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
484
+ raise ValueError(
485
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
486
+ f" {len(text_column_names)} text column names."
487
+ )
488
+
489
+ if prompt_column_names is not None and len(prompt_column_names) != len(dataset_names):
490
+ raise ValueError(
491
+ f"Ensure one prompt column name is passed for each dataset, got {len(dataset_names)} datasets and"
492
+ f" {len(prompt_column_names)} prompt column names."
493
+ )
494
+
495
+ if dataset_samples is not None:
496
+ if len(dataset_samples) != len(dataset_names):
497
+ raise ValueError(
498
+ f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
499
+ f"{len(dataset_samples)} samples."
500
+ )
501
+ dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
502
+ else:
503
+ dataset_samples = [None] * len(dataset_names)
504
+
505
+ dataset_config_names = (
506
+ dataset_config_names if dataset_config_names is not None else ["default" for _ in range(len(dataset_names))]
507
+ )
508
+ text_column_names = (
509
+ text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
510
+ )
511
+ prompt_column_names = (
512
+ prompt_column_names if prompt_column_names is not None else ["prompt" for _ in range(len(dataset_names))]
513
+ )
514
+ splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
515
+
516
+ dataset_names_dict = []
517
+ for i, ds_name in enumerate(dataset_names):
518
+ dataset_names_dict.append(
519
+ {
520
+ "name": ds_name,
521
+ "config": dataset_config_names[i],
522
+ "split": splits[i],
523
+ "text_column_name": text_column_names[i],
524
+ "prompt_column_name": prompt_column_names[i],
525
+ "samples": dataset_samples[i],
526
+ }
527
+ )
528
+ return dataset_names_dict
529
+
530
+
531
+ def load_multiple_datasets(
532
+ dataset_names: Union[List, str],
533
+ dataset_config_names: Union[List, str],
534
+ splits: Optional[Union[List, str]] = None,
535
+ text_column_names: Optional[List] = None,
536
+ prompt_column_names: Optional[List] = None,
537
+ stopping_strategy: Optional[str] = "first_exhausted",
538
+ dataset_samples: Optional[Union[List, np.array]] = None,
539
+ streaming: Optional[bool] = False,
540
+ seed: Optional[int] = None,
541
+ accelerator: Optional[Accelerator] = None,
542
+ **kwargs,
543
+ ) -> Union[Dataset, IterableDataset]:
544
+ dataset_names_dict = convert_dataset_str_to_list(
545
+ dataset_names, dataset_config_names, splits, text_column_names, prompt_column_names, dataset_samples
546
+ )
547
+
548
+ if dataset_samples is not None:
549
+ dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
550
+ probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
551
+ else:
552
+ probabilities = None
553
+
554
+ all_datasets = []
555
+ # iterate over the datasets we want to interleave
556
+ for dataset_dict in tqdm(
557
+ dataset_names_dict,
558
+ desc="Combining datasets...",
559
+ disable=not accelerator.is_main_process,
560
+ ):
561
+ dataset = load_dataset(
562
+ dataset_dict["name"],
563
+ dataset_dict["config"],
564
+ split=dataset_dict["split"],
565
+ streaming=streaming,
566
+ **kwargs,
567
+ )
568
+
569
+ columns_to_keep = {"text"}
570
+ dataset_features = dataset.features.keys()
571
+
572
+ if dataset_dict["text_column_name"] not in dataset_features:
573
+ raise ValueError(
574
+ f"Text column name {dataset_dict['text_column_name']} not found in dataset"
575
+ f" '{dataset_dict['name']}'. Make sure to set `--text_column_name` to the"
576
+ f" correct text column - one of {', '.join(dataset_features)}."
577
+ )
578
+
579
+ # blanket renaming of all transcription columns to text
580
+ if dataset_dict["text_column_name"] != "text":
581
+ dataset = dataset.rename_column(dataset_dict["text_column_name"], "text")
582
+
583
+ # blanket renaming of all prompt columns to prompt
584
+ if dataset_dict["prompt_column_name"] is not None:
585
+ if dataset_dict["prompt_column_name"] not in dataset_features:
586
+ raise ValueError(
587
+ f"Prompt column name {dataset_dict['prompt_column_name']} not found in dataset"
588
+ f" '{dataset_dict['name']}'. Make sure to set `--prompt_column_name` to the"
589
+ f" correct prompt column - one of {', '.join(dataset_features)}."
590
+ )
591
+ elif dataset_dict["prompt_column_name"] != "prompt":
592
+ dataset = dataset.rename_column(dataset_dict["prompt_column_name"], "prompt")
593
+ columns_to_keep.add("prompt")
594
+
595
+ dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
596
+ all_datasets.append(dataset)
597
+
598
+ if len(all_datasets) == 1:
599
+ # we have a single dataset so just return it as is
600
+ return all_datasets[0]
601
+
602
+ if streaming:
603
+ interleaved_dataset = interleave_datasets(
604
+ all_datasets,
605
+ stopping_strategy=stopping_strategy,
606
+ probabilities=probabilities,
607
+ seed=seed,
608
+ )
609
+ else:
610
+ interleaved_dataset = concatenate_datasets(all_datasets)
611
+
612
+ # shuffle mixed dataset prior to potentially truncating it
613
+ interleaved_dataset = interleaved_dataset.shuffle(seed)
614
+ return interleaved_dataset
615
+
616
+
617
+ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
618
+ """Helper function to sort saved checkpoints from oldest to newest."""
619
+ ordering_and_checkpoint_path = []
620
+
621
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
622
+
623
+ for path in glob_checkpoints:
624
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
625
+ if regex_match is not None and regex_match.groups() is not None:
626
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
627
+
628
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
629
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
630
+ return checkpoints_sorted
631
+
632
+
633
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> Union[List, None]:
634
+ """Helper function to delete old checkpoints."""
635
+ if save_total_limit is None or save_total_limit <= 0:
636
+ return
637
+ # Check if we should delete older checkpoint(s)
638
+ checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
639
+ if len(checkpoints_sorted) <= save_total_limit:
640
+ return
641
+
642
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
643
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
644
+ for checkpoint in checkpoints_to_be_deleted:
645
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
646
+ shutil.rmtree(checkpoint, ignore_errors=True)
647
+ checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
648
+ return checkpoints_to_be_deleted
649
+
650
+
651
+ _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
652
+
653
+
654
+ def get_last_checkpoint(folder):
655
+ content = os.listdir(folder)
656
+ checkpoints = [
657
+ path
658
+ for path in content
659
+ if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
660
+ ]
661
+ if len(checkpoints) == 0:
662
+ return
663
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
664
+
665
+
666
+ def get_parameter_names(model, forbidden_layer_types, forbidden_module=None):
667
+ """
668
+ Returns the names of the model parameters that are not inside a forbidden layer or forbidden module.
669
+ Can be used to get a subset of parameter names for decay masks, or to exclude parameters from an optimiser
670
+ (e.g. if the module is frozen).
671
+ """
672
+ result = []
673
+ for name, child in model.named_children():
674
+ result += [
675
+ f"{name}.{n}"
676
+ for n in get_parameter_names(child, forbidden_layer_types, forbidden_module)
677
+ if not (
678
+ isinstance(child, tuple(forbidden_layer_types))
679
+ or (child in tuple(forbidden_module) if forbidden_module is not None else False)
680
+ )
681
+ ]
682
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
683
+ result += list(model._parameters.keys())
684
+ return result
685
+
686
+
687
+ def get_quantization_config(
688
+ model_args: ModelArguments, torch_dtype: torch.dtype
689
+ ) -> tuple[BitsAndBytesConfig | None, BitsAndBytesConfig | None]:
690
+ if model_args.load_teacher_in_4bit:
691
+ quantization_config_teacher = BitsAndBytesConfig(
692
+ load_in_4bit=True,
693
+ bnb_4bit_compute_dtype=torch_dtype,
694
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
695
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
696
+ )
697
+ elif model_args.load_teacher_in_8bit:
698
+ quantization_config_teacher = BitsAndBytesConfig(load_in_8bit=True)
699
+ else:
700
+ quantization_config_teacher = None
701
+
702
+ if model_args.load_student_in_4bit:
703
+ quantization_config_student = BitsAndBytesConfig(
704
+ load_in_4bit=True,
705
+ bnb_4bit_compute_dtype=torch_dtype,
706
+ bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
707
+ bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
708
+ )
709
+ elif model_args.load_student_in_8bit:
710
+ quantization_config_student = BitsAndBytesConfig(load_in_8bit=True)
711
+ else:
712
+ quantization_config_student = None
713
+
714
+ return quantization_config_teacher, quantization_config_student
715
+
716
+
717
+ def main():
718
+ # 1. Parse input arguments
719
+ # We keep distinct sets of args, for cleaner separation of model/data/training related args
720
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DistillationTrainingArguments))
721
+
722
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
723
+ # If we pass only one argument to the script and it's the path to a json file,
724
+ # let's parse it to get our arguments.
725
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
726
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
727
+ # If we pass only one argument to the script and it's the path to a yaml file,
728
+ # let's parse it to get our arguments.
729
+ model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
730
+ else:
731
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
732
+
733
+ # 2. Initialize the accelerator
734
+ # We will let the accelerator handle device placement for us in this example
735
+ # We simply have to specify the training precision and any trackers being used
736
+ # We'll use the same dtype arguments as our JAX/Flax training script and convert
737
+ # it to accelerate format
738
+ if training_args.dtype == "float16":
739
+ mixed_precision = "fp16"
740
+ teacher_dtype = torch.float16
741
+ elif training_args.dtype == "bfloat16":
742
+ mixed_precision = "bf16"
743
+ teacher_dtype = torch.bfloat16
744
+ else:
745
+ mixed_precision = "no"
746
+ teacher_dtype = torch.float32
747
+
748
+ accelerator = Accelerator(
749
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
750
+ mixed_precision=mixed_precision,
751
+ log_with=training_args.report_to,
752
+ project_dir=training_args.output_dir,
753
+ )
754
+
755
+ accelerator.init_trackers(
756
+ project_name=data_args.wandb_project,
757
+ config={
758
+ "learning_rate": training_args.learning_rate,
759
+ "model_name_or_path": model_args.model_name_or_path,
760
+ "teacher_name_or_path": model_args.teacher_model_name_or_path,
761
+ "num_train_epochs": training_args.num_train_epochs,
762
+ "max_steps": training_args.max_steps,
763
+ "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
764
+ "per_device_train_batch_size": training_args.per_device_train_batch_size,
765
+ "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
766
+ "mixed_precision": mixed_precision,
767
+ "lr_scheduler_type": training_args.lr_scheduler_type,
768
+ "warmup_steps": training_args.warmup_steps,
769
+ "weight_decay": training_args.weight_decay,
770
+ "adam_beta1": training_args.adam_beta1,
771
+ "adam_beta2": training_args.adam_beta2,
772
+ "temperature": training_args.temperature,
773
+ },
774
+ )
775
+
776
+ # 3. Set-up basic logging
777
+ # Create one log on every process with the configuration for debugging
778
+ logging.basicConfig(
779
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
780
+ datefmt="%m/%d/%Y %H:%M:%S",
781
+ level=logging.INFO,
782
+ )
783
+ # Log a small summary on each process
784
+ logger.warning(
785
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
786
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
787
+ )
788
+
789
+ # Set the verbosity to info of the Transformers logger (on main process only)
790
+ if accelerator.is_local_main_process:
791
+ datasets.utils.logging.set_verbosity_warning()
792
+ transformers.utils.logging.set_verbosity_info()
793
+ else:
794
+ datasets.utils.logging.set_verbosity_error()
795
+ transformers.utils.logging.set_verbosity_error()
796
+ logger.info("Training/evaluation parameters %s", training_args)
797
+
798
+ # 4. Detecting last checkpoint and eventually continue from last checkpoint
799
+ last_checkpoint = None
800
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
801
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
802
+ if last_checkpoint is None and len(sorted_checkpoints(training_args.output_dir)) > 0:
803
+ raise ValueError(
804
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
805
+ "Use --overwrite_output_dir to overcome."
806
+ )
807
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
808
+ logger.info(
809
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
810
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
811
+ )
812
+
813
+ # 5. Handle the repository creation
814
+ if accelerator.is_main_process:
815
+ if training_args.output_dir is not None:
816
+ os.makedirs(training_args.output_dir, exist_ok=True)
817
+ if training_args.push_to_hub:
818
+ if training_args.hub_model_id is None:
819
+ repo_name = get_full_repo_name(
820
+ Path(training_args.output_dir).absolute().name,
821
+ token=training_args.hub_token,
822
+ )
823
+ else:
824
+ repo_name = training_args.hub_model_id
825
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
826
+
827
+ with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
828
+ if "wandb" not in gitignore:
829
+ gitignore.write("wandb\n")
830
+ accelerator.wait_for_everyone()
831
+
832
+ # 6. Load dataset - either streaming or non-streaming (offline)
833
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
834
+
835
+ # set seed for determinism
836
+ set_seed(training_args.seed)
837
+
838
+ if training_args.do_train:
839
+ raw_datasets["train"] = load_multiple_datasets(
840
+ data_args.train_dataset_name,
841
+ data_args.train_dataset_config_name,
842
+ splits=data_args.train_split_name,
843
+ text_column_names=data_args.text_column_name,
844
+ prompt_column_names=data_args.prompt_column_name,
845
+ streaming=data_args.streaming,
846
+ dataset_samples=data_args.train_dataset_samples,
847
+ seed=training_args.seed,
848
+ accelerator=accelerator,
849
+ cache_dir=data_args.dataset_cache_dir,
850
+ token=model_args.token,
851
+ num_proc=data_args.preprocessing_num_workers,
852
+ )
853
+ raw_datasets_train_features = set(raw_datasets["train"].features.keys())
854
+
855
+ if training_args.do_eval:
856
+ dataset_names_dict = convert_dataset_str_to_list(
857
+ data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
858
+ (
859
+ data_args.eval_dataset_config_name
860
+ if data_args.eval_dataset_config_name
861
+ else data_args.train_dataset_config_name
862
+ ),
863
+ splits=data_args.eval_split_name,
864
+ text_column_names=data_args.eval_text_column_name,
865
+ prompt_column_names=data_args.eval_prompt_column_name,
866
+ )
867
+ all_eval_splits = []
868
+ if len(dataset_names_dict) == 1:
869
+ # load a single eval set
870
+ dataset_dict = dataset_names_dict[0]
871
+ all_eval_splits.append("eval")
872
+ raw_datasets["eval"] = load_dataset(
873
+ dataset_dict["name"],
874
+ dataset_dict["config"],
875
+ split=dataset_dict["split"],
876
+ cache_dir=data_args.dataset_cache_dir,
877
+ token=model_args.token,
878
+ streaming=data_args.streaming,
879
+ )
880
+ if dataset_dict["text_column_name"] != "text":
881
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_text_column_name, "text")
882
+ if dataset_dict["prompt_column_name"] != "prompt":
883
+ raw_datasets["eval"] = raw_datasets["eval"].rename_column(data_args.eval_prompt_column_name, "prompt")
884
+ else:
885
+ # load multiple eval sets
886
+ for dataset_dict in dataset_names_dict:
887
+ pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['config'].replace('.', '-')}"
888
+ all_eval_splits.append(pretty_name)
889
+ raw_datasets[pretty_name] = load_dataset(
890
+ dataset_dict["name"],
891
+ dataset_dict["config"],
892
+ split=dataset_dict["split"],
893
+ cache_dir=data_args.dataset_cache_dir,
894
+ token=model_args.token,
895
+ streaming=data_args.streaming,
896
+ )
897
+ # make column names consistent (text, audio)
898
+ if dataset_dict["text_column_name"] != "text":
899
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
900
+ dataset_dict["text_column_name"], "text"
901
+ )
902
+ if dataset_dict["prompt_column_name"] != "prompt":
903
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
904
+ dataset_dict["prompt_column_name"], "prompt"
905
+ )
906
+ raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
907
+ set(raw_datasets[pretty_name].features.keys()) - {"text", "prompt"}
908
+ )
909
+
910
+ if not training_args.do_train and not training_args.do_eval:
911
+ raise ValueError(
912
+ "Cannot skip both training and evaluation. At least one of training or evaluation has to be performed."
913
+ )
914
+
915
+ # 7. Load pretrained model, tokenizer, and feature extractor
916
+ config = AutoConfig.from_pretrained(
917
+ (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
918
+ cache_dir=model_args.cache_dir,
919
+ revision=model_args.model_revision,
920
+ token=model_args.token,
921
+ )
922
+ if training_args.output_router_logits:
923
+ config.output_router_logits = True
924
+
925
+ tokenizer = AutoTokenizer.from_pretrained(
926
+ (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
927
+ cache_dir=model_args.cache_dir,
928
+ use_fast=model_args.use_fast_tokenizer,
929
+ revision=model_args.model_revision,
930
+ token=model_args.token,
931
+ )
932
+ if tokenizer.pad_token_id is None:
933
+ tokenizer.pad_token = tokenizer.eos_token
934
+
935
+ quantization_config_teacher, quantization_config_student = get_quantization_config(
936
+ model_args, torch_dtype=teacher_dtype
937
+ )
938
+
939
+ # The teacher model can safely be cast to the dtype of training since we don't
940
+ # update the params
941
+ teacher_model = AutoModelForCausalLM.from_pretrained(
942
+ model_args.teacher_model_name_or_path,
943
+ cache_dir=model_args.cache_dir,
944
+ token=model_args.token,
945
+ low_cpu_mem_usage=True,
946
+ torch_dtype=teacher_dtype,
947
+ attn_implementation=model_args.attn_implementation,
948
+ quantization_config=quantization_config_teacher,
949
+ )
950
+
951
+ student_model = AutoModelForCausalLM.from_pretrained(
952
+ model_args.model_name_or_path,
953
+ config=config,
954
+ cache_dir=model_args.cache_dir,
955
+ revision=model_args.model_revision,
956
+ subfolder=model_args.subfolder,
957
+ token=model_args.token,
958
+ torch_dtype=teacher_dtype,
959
+ low_cpu_mem_usage=True,
960
+ attn_implementation=model_args.attn_implementation,
961
+ quantization_config=quantization_config_student,
962
+ )
963
+
964
+ if quantization_config_student is not None:
965
+ lora_config = LoraConfig(
966
+ r=model_args.lora_r,
967
+ lora_alpha=model_args.lora_alpha,
968
+ target_modules=model_args.lora_target_modules,
969
+ lora_dropout=model_args.lora_dropout,
970
+ bias="none",
971
+ task_type="CAUSAL_LM",
972
+ )
973
+ student_model = get_peft_model(student_model, lora_config)
974
+
975
+ if student_model.generation_config.bos_token_id is None or teacher_model.generation_config.bos_token_id is None:
976
+ raise ValueError(
977
+ f"Make sure that `generation_config.bos_token_id` is correctly defined for both the "
978
+ f"student and teacher model. Got {student_model.generation_config.bos_token_id} for the "
979
+ f"student and {teacher_model.generation_config.bos_token_id} for the teacher."
980
+ )
981
+
982
+ # enable gradient checkpointing if necessary
983
+ if training_args.gradient_checkpointing:
984
+ student_model.gradient_checkpointing_enable()
985
+
986
+ def set_trainable_parameters(module, requires_grad=False):
987
+ for param in module.parameters():
988
+ param.requires_grad = requires_grad
989
+ module._requires_grad = requires_grad
990
+
991
+ # freeze student lm head if necessary
992
+ if training_args.freeze_lm_head:
993
+ set_trainable_parameters(student_model.lm_head, requires_grad=False)
994
+
995
+ student_model.generation_config.max_length = data_args.max_label_length
996
+
997
+ # 8. Save all pre-processed tokenizers/config/generation configs
998
+ if accelerator.is_main_process:
999
+ tokenizer.save_pretrained(training_args.output_dir)
1000
+ # save the config and generation config as well
1001
+ config.save_pretrained(training_args.output_dir)
1002
+ student_model.generation_config.save_pretrained(training_args.output_dir)
1003
+
1004
+ accelerator.wait_for_everyone()
1005
+
1006
+
1007
+ # 10. Preprocessing the datasets: we need to combine the prompt and generations and tokenize the targets.
1008
+ # 10.1: Define the pre-processing constants
1009
+ max_label_length = (
1010
+ data_args.max_label_length if data_args.max_label_length is not None else config.max_length
1011
+ )
1012
+ num_workers = data_args.preprocessing_num_workers
1013
+ dataloader_num_workers = training_args.dataloader_num_workers
1014
+ prefetch_factor = training_args.dataloader_prefetch_factor
1015
+ eos_token_id = tokenizer.eos_token_id
1016
+ if model_args.instruction_model is not None:
1017
+ instruction_model = model_args.instruction_model
1018
+ else:
1019
+ instruction_model = "instruct" in model_args.model_name_or_path.lower()
1020
+
1021
+ # 10.2: filter based on maximum number of training/evaluation samples
1022
+ if training_args.do_train and data_args.max_train_samples is not None:
1023
+ raw_datasets["train"] = (
1024
+ raw_datasets["train"].take(data_args.max_train_samples)
1025
+ if data_args.streaming
1026
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
1027
+ )
1028
+
1029
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1030
+ for eval_split in all_eval_splits:
1031
+ raw_datasets[eval_split] = (
1032
+ raw_datasets[eval_split].take(data_args.max_eval_samples)
1033
+ if data_args.streaming
1034
+ else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1035
+ )
1036
+
1037
+ # 10.3: pre-process training/evaluation datasets
1038
+ def prepare_dataset(example):
1039
+ example["labels"] = tokenizer(example["prompt"] + example["text"]).input_ids
1040
+ if example["labels"][-1] != eos_token_id:
1041
+ example["labels"] += [eos_token_id]
1042
+ example["prompt_length"] = len(tokenizer(example["prompt"]).input_ids)
1043
+ return example
1044
+
1045
+ def prepare_instruction_dataset(example):
1046
+ messages = [
1047
+ {"role": "user", "content": example["prompt"]},
1048
+ {"role": "assistant", "content": example["text"]},
1049
+ ]
1050
+ example["labels"] = tokenizer.apply_chat_template(messages)
1051
+ if example["labels"][-1] != eos_token_id:
1052
+ example["labels"] = example["labels"][:-1]
1053
+
1054
+ example["prompt_length"] = len(tokenizer.apply_chat_template([messages[0]]))
1055
+ return example
1056
+
1057
+ prepare_dataset = prepare_instruction_dataset if instruction_model else prepare_dataset
1058
+ vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1059
+ if training_args.do_train:
1060
+ # with streaming mode we can only have 1 worker, whereas with non-streaming
1061
+ # we can use `num_workers` (which is much faster)
1062
+ # We gate the pre-processing function accordingly
1063
+ map_fn_train = partial(
1064
+ raw_datasets["train"].map,
1065
+ function=prepare_dataset,
1066
+ remove_columns=raw_datasets_train_features,
1067
+ )
1068
+ with accelerator.main_process_first():
1069
+ vectorized_datasets["train"] = (
1070
+ map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1071
+ if not data_args.streaming
1072
+ else map_fn_train()
1073
+ )
1074
+ if training_args.do_eval:
1075
+ for eval_split in all_eval_splits:
1076
+ raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1077
+ map_fn_eval = partial(
1078
+ raw_datasets[eval_split].map, function=prepare_dataset, remove_columns=raw_datasets_eval_features
1079
+ )
1080
+ with accelerator.main_process_first():
1081
+ vectorized_datasets[eval_split] = (
1082
+ map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1083
+ if not data_args.streaming
1084
+ else map_fn_eval()
1085
+ )
1086
+
1087
+ # 10.4: Filter training data with labels longer than `max_label_length`
1088
+ def is_labels_in_length_range(labels):
1089
+ return 0 < len(labels) <= max_label_length
1090
+
1091
+ filter_by_labels_fn = partial(
1092
+ vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1093
+ )
1094
+ with accelerator.main_process_first():
1095
+ vectorized_datasets = (
1096
+ filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1097
+ if not data_args.streaming
1098
+ else filter_by_labels_fn()
1099
+ )
1100
+
1101
+ # Pre-processing complete!
1102
+ # For large datasets it is advised to run the preprocessing on a
1103
+ # single machine first with `--preprocessing_only` since there will most likely
1104
+ # be a timeout when running the script in distributed mode.
1105
+ # In a second step, `--preprocessing_only` can then be set to `False` to load the
1106
+ # cached dataset
1107
+ if data_args.preprocessing_only:
1108
+ if data_args.streaming:
1109
+ raise ValueError(
1110
+ "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1111
+ "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1112
+ "on the fly with streaming mode."
1113
+ )
1114
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1115
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1116
+ return
1117
+
1118
+ # 11. Define Evaluation Metrics
1119
+ def compute_metrics(preds, labels):
1120
+ # TODO(SG): better metrics for performance?
1121
+ # replace padded labels by the padding token
1122
+ for idx in range(len(labels)):
1123
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1124
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
1125
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1126
+ return pred_str, label_str
1127
+
1128
+ # 12. Define Training Schedule
1129
+ # 12.1: Store some constants
1130
+ per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1131
+ train_batch_size = per_device_train_batch_size * accelerator.num_processes
1132
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1133
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1134
+
1135
+ # 12.2: Set max training steps
1136
+ if not data_args.streaming and training_args.max_steps < 0:
1137
+ num_epochs = int(training_args.num_train_epochs)
1138
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1139
+ total_train_steps = steps_per_epoch * num_epochs
1140
+ elif training_args.max_steps > 0:
1141
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
1142
+ total_train_steps = int(training_args.max_steps)
1143
+ if not data_args.streaming:
1144
+ steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1145
+ num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1146
+ else:
1147
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1148
+ num_epochs = sys.maxsize
1149
+ steps_per_epoch = total_train_steps
1150
+ else:
1151
+ raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1152
+
1153
+ # 12.3: Set evaluation steps
1154
+ if training_args.evaluation_strategy == "epoch":
1155
+ eval_steps = steps_per_epoch
1156
+ elif training_args.eval_steps is None:
1157
+ logger.info(
1158
+ f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1159
+ )
1160
+ eval_steps = steps_per_epoch
1161
+ else:
1162
+ eval_steps = training_args.eval_steps
1163
+
1164
+ # 12.4: Set save steps
1165
+ if training_args.save_strategy == "epoch":
1166
+ save_steps = steps_per_epoch
1167
+ elif training_args.save_strategy == "steps":
1168
+ save_steps = training_args.save_steps
1169
+ else:
1170
+ save_steps = sys.maxsize
1171
+
1172
+ # 13. Define optimizer, LR scheduler, collator
1173
+ decay_parameters = get_parameter_names(
1174
+ student_model,
1175
+ [nn.LayerNorm],
1176
+ )
1177
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
1178
+ optimizer_grouped_parameters = [
1179
+ {
1180
+ "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1181
+ "weight_decay": training_args.weight_decay,
1182
+ },
1183
+ {
1184
+ "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1185
+ "weight_decay": 0.0,
1186
+ },
1187
+ ]
1188
+ optimizer = torch.optim.AdamW(
1189
+ params=optimizer_grouped_parameters,
1190
+ lr=training_args.learning_rate,
1191
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
1192
+ eps=training_args.adam_epsilon,
1193
+ )
1194
+
1195
+ # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1196
+ lr_scheduler = get_scheduler(
1197
+ name=training_args.lr_scheduler_type,
1198
+ optimizer=optimizer,
1199
+ num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1200
+ num_training_steps=total_train_steps * accelerator.num_processes,
1201
+ )
1202
+
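To make the comment above concrete, a toy calculation (the numbers are assumptions, not values from this run): since the prepared scheduler is advanced `num_processes` times per optimizer step, the warmup and total lengths passed to `get_scheduler` are multiplied by `num_processes` so the effective schedule still spans the intended number of optimizer steps.

num_processes = 8            # assumed: one 8-GPU node
warmup_steps = 500           # assumed CLI value
total_train_steps = 50_000   # assumed CLI value

scheduler_warmup = warmup_steps * num_processes       # 4_000 scheduler steps
scheduler_total = total_train_steps * num_processes   # 400_000 scheduler steps
# all processes step once per optimizer step, so warmup still completes
# after 500 optimizer steps and decay after 50_000
print(scheduler_warmup, scheduler_total)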
1203
+ data_collator = DataCollatorCausalLMWithPadding(
1204
+ tokenizer=tokenizer,
1205
+ target_padding="max_length",
1206
+ max_target_length=max_label_length,
1207
+ )
1208
+
1209
+ # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1210
+ # so that we can still access the configs
1211
+ num_beams = (
1212
+ training_args.generation_num_beams
1213
+ if training_args.generation_num_beams is not None
1214
+ else getattr(student_model.generation_config, "num_beams", 1)
1215
+ )
1216
+
1217
+ # 15. Prepare everything with accelerate
1218
+ student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1219
+ student_model, teacher_model, optimizer, lr_scheduler
1220
+ )
1221
+
1222
+ def kl_divergence(target_distribution, log_predicted_distribution, labels):
1223
+ kl_loss = nn.KLDivLoss(reduction="none")
1224
+ divergence = kl_loss(log_predicted_distribution, target_distribution)
1225
+ # ignore padded tokens in the divergence, i.e. mask out positions where labels are set to -100
1226
+ padding_mask = labels >= 0
1227
+ padding_mask = padding_mask.unsqueeze(-1)
1228
+ divergence = divergence * padding_mask
1229
+ # take the average over the mini-batch
1230
+ divergence = divergence.sum() / padding_mask.sum()
1231
+ return divergence
1232
+
1233
+ # Define gradient update step fn
1234
+ def train_step(
1235
+ batch,
1236
+ temperature=2.0,
1237
+ ):
1238
+ student_model.train()
1239
+ teacher_model.eval()
1240
+
1241
+ student_outputs = student_model(**batch)
1242
+ with torch.no_grad():
1243
+ teacher_outputs = teacher_model(**batch)
1244
+
1245
+ # CE (data) loss
1246
+ ce_loss = student_outputs.loss
1247
+ # rescale distribution by temperature to ensure gradients scale correctly
1248
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1249
+ # log softmax of student predictions for numerical stability
1250
+ student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1251
+ # KL-divergence loss (scaled by temperature)
1252
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1253
+
1254
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1255
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1256
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1257
+ return loss, metrics
1258
+
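Written out, the training objective above is (Distil-Whisper-style, with a fixed 0.8 weight on the cross-entropy term, `kl_weight` written as \lambda and temperature T):

\[
\mathcal{L} \;=\; 0.8\,\mathcal{L}_{\mathrm{CE}}
\;+\; \lambda\,T^{2}\,\mathrm{KL}\!\left(\operatorname{softmax}(z_{\mathrm{teacher}}/T)\,\big\|\,\operatorname{softmax}(z_{\mathrm{student}}/T)\right)
\]

where the KL term is averaged over non-padding positions (labels != -100) and T is fixed to 1 during evaluation.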
1259
+ # Define eval fn
1260
+ def eval_step(batch):
1261
+ student_model.eval()
1262
+ teacher_model.eval()
1263
+
1264
+ with torch.no_grad():
1265
+ student_outputs = student_model(**batch)
1266
+ teacher_outputs = teacher_model(**batch)
1267
+
1268
+ # CE (data) loss
1269
+ ce_loss = student_outputs.loss
1270
+
1271
+ # log softmax / softmax for numerical stability
1272
+ student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1273
+ teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1274
+ # temperature is always 1 for eval
1275
+ kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1276
+
1277
+ # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1278
+ loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1279
+ metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1280
+ return metrics
1281
+
1282
+ def generate_step(batch):
1283
+ student_model.eval()
1284
+ output_ids = accelerator.unwrap_model(student_model).generate(
1285
+ **batch, max_length=max_label_length, num_beams=num_beams
1286
+ )
1287
+ output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1288
+ return output_ids
1289
+
1290
+ logger.info("***** Running training *****")
1291
+ logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1292
+ if not data_args.streaming:
1293
+ logger.info(f" Num epochs = {num_epochs}")
1294
+ logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1295
+ logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1296
+ logger.info(
1297
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1298
+ )
1299
+ logger.info(f" Total optimization steps = {total_train_steps}")
1300
+
1301
+ # ======================== Training ================================
1302
+ train_time = 0
1303
+ train_start = time.time()
1304
+ steps_trained_progress_bar = tqdm(
1305
+ range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1306
+ )
1307
+ continue_training = True
1308
+ epochs_trained = 0
1309
+ cur_step = 0
1310
+
1311
+ checkpoint = None
1312
+ if training_args.resume_from_checkpoint is not None:
1313
+ checkpoint = training_args.resume_from_checkpoint
1314
+ elif last_checkpoint is not None:
1315
+ checkpoint = last_checkpoint
1316
+
1317
+ if checkpoint is not None:
1318
+ accelerator.load_state(checkpoint)
1319
+ # Find num steps and epoch from saved state string pattern
1320
+ pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1321
+ match = re.search(pattern, checkpoint)
1322
+ cur_step = int(match.group(1))
1323
+ epochs_trained = int(match.group(2))
1324
+
1325
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1326
+ logger.info(f" Continuing training from epoch {epochs_trained}")
1327
+ logger.info(f" Continuing training from global step {cur_step}")
1328
+
1329
+ steps_trained_progress_bar.update(cur_step)
1330
+
1331
+ for epoch in range(0, epochs_trained):
1332
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1333
+
1334
+ if not data_args.streaming and training_args.max_steps < 0:
1335
+ # we know exactly the number of steps per epoch, so can skip through the required number of batches
1336
+ resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1337
+ else:
1338
+ # Currently we don't know how many steps we've taken in the current epoch
1339
+ # So we just shuffle the dataset one extra time and start from a fresh epoch
1340
+ # This is "good enough" for our purposes but not fully correct
1341
+ resume_step = None
1342
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1343
+ else:
1344
+ resume_step = None
1345
+
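A small, self-contained sketch of the resume bookkeeping above, using the same `checkpoint-{step}-epoch-{epoch}` naming convention as the saved state in this repository (the path below is an example):

import re

checkpoint = "output/checkpoint-5000-epoch-0"   # example path
match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", checkpoint)
cur_step, epochs_trained = int(match.group(1)), int(match.group(2))
print(cur_step, epochs_trained)   # 5000 0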
1346
+ for epoch in range(epochs_trained, num_epochs):
1347
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1348
+ train_dataloader = DataLoader(
1349
+ vectorized_datasets["train"],
1350
+ collate_fn=data_collator,
1351
+ batch_size=per_device_train_batch_size,
1352
+ num_workers=dataloader_num_workers,
1353
+ prefetch_factor=prefetch_factor,
1354
+ pin_memory=training_args.dataloader_pin_memory,
1355
+ )
1356
+ train_dataloader = accelerator.prepare(train_dataloader)
1357
+ if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1358
+ train_dataloader.dataset.set_epoch(epoch)
1359
+
1360
+ if resume_step is not None:
1361
+ # Skip the first N batches in the dataloader when resuming from a checkpoint
1362
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1363
+ resume_step = None
1364
+
1365
+ for batch in train_dataloader:
1366
+ with accelerator.accumulate(student_model):
1367
+ loss, train_metric = train_step(batch, temperature=training_args.temperature)
1368
+ accelerator.backward(loss)
1369
+ if accelerator.sync_gradients:
1370
+ accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1371
+ optimizer.step()
1372
+ lr_scheduler.step()
1373
+ optimizer.zero_grad()
1374
+
1375
+ # Check if the accelerator has performed an optimization step behind the scenes
1376
+ if accelerator.sync_gradients:
1377
+ steps_trained_progress_bar.update(1)
1378
+ cur_step += 1
1379
+
1380
+ if cur_step % training_args.logging_steps == 0:
1381
+ steps_trained_progress_bar.write(
1382
+ f"Step... ({cur_step} / {total_train_steps} | Loss:"
1383
+ f" {train_metric['loss']}, Learning Rate:"
1384
+ f" {lr_scheduler.get_last_lr()[0]})"
1385
+ )
1386
+ log_metric(
1387
+ accelerator,
1388
+ metrics=train_metric,
1389
+ learning_rate=lr_scheduler.get_last_lr()[0],
1390
+ train_time=train_time + time.time() - train_start,
1391
+ step=cur_step,
1392
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1393
+ prefix="train",
1394
+ )
1395
+
1396
+ # save checkpoint and weights after each save_steps and at the end of training
1397
+ if (cur_step % save_steps == 0) or cur_step == total_train_steps:
1398
+ accelerator.wait_for_everyone()
1399
+ intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1400
+ accelerator.save_state(output_dir=intermediate_dir)
1401
+ unwrapped_model = accelerator.unwrap_model(student_model)
1402
+ unwrapped_model.save_pretrained(
1403
+ intermediate_dir,
1404
+ is_main_process=accelerator.is_main_process,
1405
+ save_function=accelerator.save,
1406
+ )
1407
+ if accelerator.is_main_process:
1408
+ checkpoint_to_be_deleted = rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1409
+ if training_args.push_to_hub:
1410
+ upload_folder(
1411
+ folder_path=training_args.output_dir,
1412
+ repo_id=repo_name,
1413
+ repo_type="model",
1414
+ commit_message=f"Saving train state of step {cur_step}",
1415
+ delete_patterns=checkpoint_to_be_deleted,
1416
+ )
1417
+
1418
+ if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1419
+ train_time += time.time() - train_start
1420
+ student_model.eval()
1421
+ # ======================== Evaluating ==============================
1422
+ for eval_split in all_eval_splits:
1423
+ eval_metrics = []
1424
+ eval_preds = []
1425
+ eval_labels = []
1426
+ eval_start = time.time()
1427
+
1428
+ validation_dataloader = DataLoader(
1429
+ vectorized_datasets[eval_split],
1430
+ collate_fn=data_collator,
1431
+ batch_size=per_device_eval_batch_size,
1432
+ drop_last=False,
1433
+ num_workers=dataloader_num_workers,
1434
+ prefetch_factor=prefetch_factor,
1435
+ pin_memory=training_args.dataloader_pin_memory,
1436
+ )
1437
+ validation_dataloader = accelerator.prepare(validation_dataloader)
1438
+
1439
+ for batch in tqdm(
1440
+ validation_dataloader,
1441
+ desc=f"Evaluating {eval_split}...",
1442
+ position=2,
1443
+ disable=not accelerator.is_local_main_process,
1444
+ ):
1445
+ # Model forward
1446
+ eval_metric = eval_step(batch)
1447
+ eval_metric = accelerator.gather_for_metrics(eval_metric)
1448
+ eval_metrics.append(eval_metric)
1449
+
1450
+ # generation
1451
+ if training_args.predict_with_generate:
1452
+ generated_ids = generate_step(batch)
1453
+ # Gather all predictions and targets
1454
+ generated_ids, labels = accelerator.gather_for_metrics(
1455
+ (generated_ids, batch["labels"])
1456
+ )
1457
+ eval_preds.extend(generated_ids)
1458
+ eval_labels.extend(labels)
1459
+
1460
+ eval_time = time.time() - eval_start
1461
+ stack = torch.stack if accelerator.num_processes == 1 else torch.concatenate
1462
+ # normalize eval metrics
1463
+ eval_metrics = {
1464
+ key: torch.mean(stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1465
+ }
1466
+ try:
1467
+ eval_metrics["perplexity"] = math.exp(eval_metrics["ce_loss"])
1468
+ except OverflowError:
1469
+ eval_metrics["perplexity"] = float("inf")
1470
+
1471
+ if training_args.predict_with_generate:
1472
+ pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1473
+ log_pred(
1474
+ accelerator,
1475
+ pred_str,
1476
+ label_str,
1477
+ step=cur_step,
1478
+ epoch=epoch,
1479
+ evaluation_strategy=training_args.evaluation_strategy,
1480
+ prefix=eval_split,
1481
+ )
1482
+
1483
+ # Print metrics and update progress bar
1484
+ logger_desc = " ".join([f"Eval {key}: {value} |" for key, value in eval_metrics.items()])
1485
+ steps_trained_progress_bar.write(
1486
+ f"Eval results for step ({cur_step} / {total_train_steps} | {logger_desc}"
1487
+ )
1488
+
1489
+ log_metric(
1490
+ accelerator,
1491
+ metrics=eval_metrics,
1492
+ train_time=eval_time,
1493
+ step=cur_step,
1494
+ epoch=epoch if data_args.streaming else epoch + (cur_step - epoch * steps_per_epoch) / steps_per_epoch,
1495
+ prefix=eval_split,
1496
+ )
1497
+
1498
+ # flush the train metrics
1499
+ train_start = time.time()
1500
+
1501
+ # break condition
1502
+ if cur_step == total_train_steps:
1503
+ accelerator.wait_for_everyone()
1504
+ # un-wrap student model for save
1505
+ student_model = accelerator.unwrap_model(student_model)
1506
+ student_model.save_pretrained(
1507
+ training_args.output_dir,
1508
+ is_main_process=accelerator.is_main_process,
1509
+ save_function=accelerator.save,
1510
+ )
1511
+ if training_args.push_to_hub and accelerator.is_main_process:
1512
+ upload_folder(
1513
+ folder_path=training_args.output_dir,
1514
+ repo_id=repo_name,
1515
+ repo_type="model",
1516
+ commit_message=f"Saving final weights of step {cur_step}",
1517
+ )
1518
+ continue_training = False
1519
+ break
1520
+
1521
+ if not continue_training:
1522
+ break
1523
+
1524
+ accelerator.end_training()
1525
+
1526
+
1527
+ if __name__ == "__main__":
1528
+ main()
slurm_job.slurm ADDED
@@ -0,0 +1,75 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=distil-mistral
3
+ #SBATCH --nodes=1
4
+ # set 48h for job wall time limit
5
+ #SBATCH --time=48:00:00
6
+ #SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for the distributed launcher!
7
+ #SBATCH --cpus-per-task=32
8
+ #SBATCH --gres=gpu:8
9
+ #SBATCH --exclusive
10
+ #SBATCH --partition=hopper-prod
11
+ #SBATCH --output=/fsx/sanchit/logs/%x-%j.out
12
+
13
+ set -x -e
14
+
15
+ source ~/.bashrc
16
+ source /fsx/sanchit/miniconda3/bin/activate venv
17
+
18
+ echo "START TIME: $(date)"
19
+
20
+
21
+ LOG_PATH="/fsx/sanchit/logs/main_log.txt"
22
+ SAVE_DIR="/fsx/sanchit"
23
+
24
+ GPUS_PER_NODE=8
25
+ NNODES=$SLURM_NNODES
26
+
27
+ # so processes know who to talk to
28
+ MASTER_ADDR=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1`
29
+
30
+ # From https://i.hsfzxjy.site/2021-03-10-obtain-a-random-unused-tcp-port-with-bash/
31
+ function unused_port() {
32
+ N=${1:-1}
33
+ comm -23 \
34
+ <(seq "1025" "65535" | sort) \
35
+ <(ss -Htan |
36
+ awk '{print $4}' |
37
+ cut -d':' -f2 |
38
+ sort -u) |
39
+ shuf |
40
+ head -n "$N"
41
+ }
42
+ MASTER_PORT=$(unused_port)
43
+
44
+ # export TORCH_CPP_LOG_LEVEL=INFO
45
+ # export TORCH_DISTRIBUTED_DEBUG=DETAIL
46
+
47
+ export LAUNCHER="python -u -m accelerate.commands.launch --config_file ./accelerate_config.yaml"
48
+
49
+ export PROGRAM="./run_distillation.py ./config_mistral.yaml"
50
+ export CMD="$LAUNCHER $PROGRAM"
51
+ echo $CMD
52
+
53
+ SRUN_ARGS=" \
54
+ --wait=60 \
55
+ --kill-on-bad-exit=1 \
56
+ "
57
+
58
+ # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
59
+ clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
60
+
61
+
62
+ # srun error handling:
63
+ # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
64
+ # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
65
+
66
+ # SRUN_ARGS=" \
67
+ # --wait=60 \
68
+ # --kill-on-bad-exit=1 \
69
+ # "
70
+ #
71
+ # # py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
72
+ # clear; srun $SRUN_ARGS --jobid $SLURM_JOBID bash -c "$CMD" 2>&1 | tee -a $SAVE_DIR/logs/main_log.txt
73
+
74
+ echo "END TIME: $(date)"
75
+
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "</s>",
35
+ "legacy": true,
36
+ "model_max_length": 1000000000000000019884624838656,
37
+ "pad_token": "</s>",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }
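For completeness, a quick sketch of how the `chat_template` above renders a conversation; the repo path is a placeholder for wherever these tokenizer files are stored locally:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this/repo")   # placeholder path
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]
print(tok.apply_chat_template(messages, tokenize=False))
# following the template above: "<s>[INST] Hello [/INST]Hi there!</s>"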