sin2piusc committed on
Commit
454fc44
1 Parent(s): 90fc927

Upload whisper-trainer.ipynb

Files changed (1)
  1. whisper-trainer.ipynb +1440 -0
whisper-trainer.ipynb ADDED
@@ -0,0 +1,1440 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import torch\n",
11
+ "import transformers\n",
12
+ "import evaluate\n",
13
+ "import string\n",
14
+ "import re\n",
15
+ "import warnings\n",
16
+ "import tensorboard\n",
17
+ "import datetime\n",
18
+ "import neologdn\n",
19
+ "import datasets\n",
20
+ "import MeCab\n",
21
+ "import pandas as pd\n",
22
+ "import soundfile as sf\n",
23
+ "\n",
24
+ "from evaluate import load\n",
25
+ "from torch.utils.data import DataLoader\n",
26
+ "from tqdm import tqdm\n",
27
+ "import numpy as np\n",
28
+ "import gc\n",
29
+ "from multiprocessing import Pool\n",
30
+ "\n",
31
+ "from dataclasses import dataclass\n",
32
+ "from typing import List, Optional, Any, Dict, List, Union\n",
33
+ "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n",
34
+ "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
35
+ "#from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor\n",
36
+ "from lomo_optim import Lomo\n",
37
+ "from lomo_optim import AdaLomo\n",
38
+ "\n",
39
+ "from datasets import (\n",
40
+ " Audio,\n",
41
+ " interleave_datasets,\n",
42
+ " concatenate_datasets,\n",
43
+ " IterableDataset,\n",
44
+ " load_dataset,\n",
45
+ " IterableDatasetDict,\n",
46
+ " Features,\n",
47
+ " Value,\n",
48
+ " disable_caching,\n",
49
+ " enable_caching,\n",
50
+ " DatasetDict,\n",
51
+ " DownloadConfig,\n",
52
+ " load_from_disk,\n",
53
+ " Dataset,\n",
54
+ ")\n",
55
+ "\n",
56
+ "from peft import (\n",
57
+ " PeftModel,\n",
58
+ " PeftConfig,\n",
59
+ " prepare_model_for_kbit_training,\n",
60
+ " LoraConfig,\n",
61
+ " get_peft_model,\n",
62
+ " replace_lora_weights_loftq,\n",
63
+ " AdaLoraConfig,\n",
64
+ " LoHaModel, \n",
65
+ " LoHaConfig,\n",
66
+ " LoKrModel, \n",
67
+ " LoKrConfig,\n",
68
+ ")\n",
69
+ "from transformers import (\n",
70
+ " WhisperForConditionalGeneration,\n",
71
+ " WhisperProcessor,\n",
72
+ " Seq2SeqTrainer,\n",
73
+ " TrainerCallback,\n",
74
+ " Seq2SeqTrainingArguments,\n",
75
+ " TrainerState,\n",
76
+ " TrainerControl,\n",
77
+ " TrainingArguments,\n",
78
+ " BitsAndBytesConfig,\n",
79
+ " WhisperTokenizer,\n",
80
+ " WhisperFeatureExtractor,\n",
81
+ " PushToHubCallback,\n",
82
+ " AutoTokenizer,\n",
83
+ " WhisperConfig,\n",
84
+ ")"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
94
+ "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
95
+ "\n",
96
+ "model_name_or_path =\"\"\n",
97
+ "dataset = \"\"\n",
98
+ "\n",
99
+ "cache_dir=\"\"\n",
100
+ "output_dir=\"\" \n",
101
+ "language = \"\"\n",
102
+ "language_abbr = \"\"\n",
103
+ "task = \"\"\n",
104
+ "\n",
105
+ "warnings.filterwarnings('ignore', 'Unable to register * factory' , Warning) \n",
106
+ "#ransformers.utils.logging.set_verbosity_info()\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "####\n",
116
+ "norm_everything = False\n",
117
+ "do_remove_special_characters = False \n",
118
+ "do_normalize_basic = False #hf basic \n",
119
+ "do_normalize_jp = False #mecab japanese\n",
120
+ "do_audio_filter = True\n",
121
+ "use_peft = True\n",
122
+ "use_adalora = False\n",
123
+ "use_loha = False\n",
124
+ "use_lokr = False\n",
125
+ "\n",
126
+ "special_characters = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
127
+ "metric = evaluate.load(\"cer\")\n",
128
+ "normalizer = BasicTextNormalizer()\n",
129
+ "wakati = MeCab.Tagger(\"-Owakati\")"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\n",
139
+ " model_name_or_path,\n",
140
+ " do_normalize = False,\n",
141
+ " # device=\"cuda\",\n",
142
+ " # sampling_rate=16000,\n",
143
+ " # return_attention_mask=True,\n",
144
+ " # truncation=True,\n",
145
+ " # n_fft=512,\n",
146
+ " # n_mels=512,\n",
147
+ " # chunk_length=60,\n",
148
+ " # hop_length=320,\n",
149
+ " # pad_mode=\"reflect\",\n",
150
+ " # power=2.0,\n",
151
+ " # norm=\"slaney\",\n",
152
+ " # mel_scale=\"slaney\",\n",
153
+ " )\n",
154
+ "tokenizer = WhisperTokenizer.from_pretrained(\n",
155
+ " model_name_or_path,\n",
156
+ " language=language,\n",
157
+ " task=task,\n",
158
+ " )\n",
159
+ "processor = WhisperProcessor.from_pretrained(\n",
160
+ " model_name_or_path,\n",
161
+ " tokenizer=tokenizer,\n",
162
+ " feature_extractor=feature_extractor,\n",
163
+ " language=language,\n",
164
+ " task=task,\n",
165
+ " )"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "\n",
175
+ "special_characters = '[,\\���\\。\\.\\「\\」\\…\\?\\・\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
176
+ "metric = evaluate.load(\"cer\")\n",
177
+ "normalizer = BasicTextNormalizer()\n",
178
+ "wakati = MeCab.Tagger(\"-Owakati\")\n",
179
+ "\n",
180
+ "def load_streaming_dataset(dataset_name, dataset_config_name, split=\"train\", **kwargs):\n",
181
+ "\n",
182
+ " if \"+\" in split:\n",
183
+ " dataset_splits = [\n",
184
+ " load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)\n",
185
+ " for split_name in split.split(\"+\")\n",
186
+ " ]\n",
187
+ " interleaved_dataset = interleave_datasets(dataset_splits)\n",
188
+ " return interleaved_dataset\n",
189
+ " else:\n",
190
+ " dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)\n",
191
+ " return dataset\n",
192
+ "\n",
193
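 + "# Interleaves several streaming datasets into one IterableDataset, resampling audio to a common rate, renaming each text column to \"sentence\" and dropping all other columns.\n",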
+ "def load_multiple_streaming_datasets(\n",
194
+ " dataset_names: List,\n",
195
+ " dataset_config_names: List,\n",
196
+ " splits: Optional[List] = None,\n",
197
+ " text_column_names: Optional[List] = None,\n",
198
+ " sampling_rate: Optional[int] = 16000,\n",
199
+ " stopping_strategy: Optional[str] = \"all_exhausted\",\n",
200
+ " **kwargs\n",
201
+ ") -> IterableDataset:\n",
202
+ "\n",
203
+ " if len(dataset_names) != len(dataset_config_names):\n",
204
+ " raise ValueError(\n",
205
+ " f\"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and\"\n",
206
+ " f\" {len(dataset_config_names)} configs.\"\n",
207
+ " )\n",
208
+ "\n",
209
+ " if splits is not None and len(splits) != len(dataset_names):\n",
210
+ " raise ValueError(\n",
211
+ " f\"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits.\"\n",
212
+ " )\n",
213
+ "\n",
214
+ " if text_column_names is not None and len(text_column_names) != len(dataset_names):\n",
215
+ " raise ValueError(\n",
216
+ " f\"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and\"\n",
217
+ " f\" {len(text_column_names)} text column names.\"\n",
218
+ " )\n",
219
+ "\n",
220
+ " splits = splits if splits is not None else [\"train\" for i in range(len(dataset_names))]\n",
221
+ " text_column_names = (\n",
222
+ " text_column_names if text_column_names is not None else [\"text\" for i in range(len(dataset_names))]\n",
223
+ " )\n",
224
+ "\n",
225
+ " all_datasets = []\n",
226
+ " for i, dataset_name in enumerate(dataset_names):\n",
227
+ " dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)\n",
228
+ " dataset = dataset.cast_column(\"audio\", Audio(sampling_rate))\n",
229
+ " if text_column_names[i] != \"sentence\":\n",
230
+ " dataset = dataset.rename_column(text_column_names[i], \"sentence\")\n",
231
+ " dataset = dataset.remove_columns(set(dataset.features.keys()) - set([\"audio\", \"sentence\"]))\n",
232
+ " all_datasets.append(dataset)\n",
233
+ "\n",
234
+ " interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)\n",
235
+ " return interleaved_dataset\n",
236
+ "\n",
237
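 + "# SavePeftModelCallback: at each checkpoint, saves only the PEFT adapter weights and removes the full pytorch_model.bin to keep checkpoints small.\n",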
+ "class SavePeftModelCallback(TrainerCallback):\n",
238
+ " def on_save(\n",
239
+ " self,\n",
240
+ " args: TrainingArguments,\n",
241
+ " state: TrainerState,\n",
242
+ " control: TrainerControl,\n",
243
+ " **kwargs,\n",
244
+ " ):\n",
245
+ " checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
246
+ " peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
247
+ " kwargs[\"model\"].save_pretrained(peft_model_path)#, path_initial_model_for_weight_conversion=peft_model_path)\n",
248
+ " pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
249
+ " if os.path.exists(pytorch_model_path):\n",
250
+ " os.remove(pytorch_model_path)\n",
251
+ " return control\n",
252
+ " \n",
253
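 + "# ShuffleCallback: bumps the streaming dataset's epoch counter at the start of each epoch so a different shuffle order is used every epoch.\n",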
+ "class ShuffleCallback(TrainerCallback):\n",
254
+ " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n",
255
+ " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n",
256
+ " pass \n",
257
+ " elif isinstance(train_dataloader.dataset, IterableDataset):\n",
258
+ " # train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)\n",
259
+ " if int(os.environ[\"WORLD_SIZE\"]) == 1: \n",
260
+ " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)\n",
261
+ " else:\n",
262
+ " train_dataloader.dataset.set_epoch(train_dataloader.dataset.epoch + 1)\n",
263
+ "\n",
264
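 + "# Data collator: pads the log-mel input features and the tokenized labels separately; padded label positions are replaced with -100 so they are ignored by the loss.\n",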
+ "@dataclass\n",
265
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
266
+ " processor: Any\n",
267
+ "\n",
268
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
269
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
270
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
271
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
272
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
273
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
274
+ " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
275
+ " labels = labels[:, 1:]\n",
276
+ " batch[\"labels\"] = labels\n",
277
+ " return batch\n",
278
+ " \n",
279
+ "def make_inputs_require_grad(module, input, output):\n",
280
+ " output.requires_grad_(True)\n",
281
+ "\n",
282
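 + "# prepare_dataset: computes log-mel input features, records the clip length in seconds (used later for filtering), and tokenizes the transcript into label ids.\n",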
+ "def prepare_dataset(batch):\n",
283
+ " audio = batch[\"audio\"]\n",
284
+ " #batch[\"input_features\"] = batch[\"input_features\"].to(dtype=torch.bfloat16)\n",
285
+ " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
286
+ " batch[\"audio_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n",
287
+ " \n",
288
+ " # if do_norm:\n",
289
+ " # batch[\"sentence\"] = neologdn.normalize(batch[\"sentence\"]).strip()\n",
290
+ " # batch[\"sentence\"] = normalizer(batch[\"sentence\"]).strip()\n",
291
+ " # batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
292
+ " # batch[\"sentence\"] = re.sub(special_characters,'', batch[\"sentence\"]).strip()\n",
293
+ " \n",
294
+ " batch[\"labels\"] = processor.tokenizer(batch[\"sentence\"]).input_ids\n",
295
+ " return batch\n",
296
+ "\n",
297
+ "def augmented_speech(batch, augment):\n",
298
+ " samples = np.array(batch[\"speech\"])\n",
299
+ " batch[\"speech\"] = augment(samples=samples, sample_rate=16000)\n",
300
+ " batch[\"sampling_rate\"] = 16000\n",
301
+ " batch[\"target_text\"] = batch[\"target_text\"]\n",
302
+ " return batch\n",
303
+ "\n",
304
+ "from torch.utils.data import Dataset\n",
305
+ " \n",
306
+ "class ds(Dataset):\n",
307
+ " def __init__(self, X, y): #convert into PyTorch tensors and remember them\n",
308
+ " self.X = torch.tensor(X, dtype=torch.float32)\n",
309
+ " self.y = torch.tensor(y, dtype=torch.float32)\n",
310
+ " \n",
311
+ " def __len__(self): #this should return the size of the dataset\n",
312
+ " return len(self.X)\n",
313
+ " \n",
314
+ " def __getitem__(self, idx): #this should return one sample from the dataset\n",
315
+ " features = self.X[idx]\n",
316
+ " target = self.y[idx]\n",
317
+ " return features, target\n",
318
+ " \n",
319
+ "def normalize_transcriptions(batch):\n",
320
+ " transcription = batch[\"sentence\"]\n",
321
+ " if do_lower_case:\n",
322
+ " transcription = transcription.lower()\n",
323
+ " if do_remove_punctuation:\n",
324
+ " transcription = normalizer(transcription).strip()\n",
325
+ " if do_remove_special_characters:\n",
326
+ " transcription = re.sub(special_characters,'', transcription).strip()\n",
327
+ " if do_normalize_jp_neo:\n",
328
+ " transcription = neologdn.normalize(transcription).strip()\n",
329
+ " if do_normalize_basic:\n",
330
+ " transcription = normalizer(transcription).strip()\n",
331
+ " if do_normalize_jp:\n",
332
+ " transcription = wakati.parse(transcription).strip()\n",
333
+ " transcription = fullwidth_to_halfwidth(transcription) \n",
334
+ " batch[\"sentence\"] = transcription\n",
335
+ " return batch\n",
336
+ "\n",
337
+ "def norm_everything(batch):\n",
338
+ " batch[\"sentence\"] = neologdn.normalize(batch[\"sentence\"]).strip()\n",
339
+ " batch[\"sentence\"] = normalizer(batch[\"sentence\"]).strip()\n",
340
+ " batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
341
+ " batch[\"sentence\"] = re.sub(special_characters,'', batch[\"sentence\"]).strip()\n",
342
+ " return batch\n",
343
+ "\n",
344
+ "def filter_length(audio_length):\n",
345
+ " return audio_length > min_audio_length and audio_length < max_audio_length\n",
346
+ "\n",
347
+ "def filter_labels(labels):\n",
348
+ " return min_label_length < len(labels) < max_label_length #len(labels) < max_label_length \n",
349
+ "\n",
350
+ "wakati = MeCab.Tagger(\"-Owakati\")\n",
351
+ "FULLWIDTH_TO_HALFWIDTH = str.maketrans(\n",
352
+ " ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!゛#$%&()*+、ー。/:;〈=〉?@[]^_‘{|}~',\n",
353
+ " ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&()*+,-./:;<=>?@[]^_`{|}~',\n",
354
+ " )\n",
355
+ "\n",
356
+ "def fullwidth_to_halfwidth(s):\n",
357
+ " s = s.translate(FULLWIDTH_TO_HALFWIDTH)\n",
358
+ " return wakati.parse(s)\n",
359
+ "\n",
360
+ "wer_metric = evaluate.load(\"wer\")\n",
361
+ "cer_metric = evaluate.load(\"cer\")\n",
362
+ "\n",
363
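 + "# compute_metrics: reports WER/CER on the raw decoded text (ortho), WER/CER after BasicTextNormalizer, and CER on MeCab-tokenized text.\n",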
+ "def compute_metrics(pred):\n",
364
+ " \n",
365
+ " pred_ids = pred.predictions\n",
366
+ " label_ids = pred.label_ids\n",
367
+ " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n",
368
+ " pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)\n",
369
+ " label_str = processor.batch_decode(label_ids, skip_special_tokens=True)\n",
370
+ " \n",
371
+ " pred_str_norm_jp = [wakati.parse(pred) for pred in pred_str] #mecab normalizer\n",
372
+ " label_str_norm_jp = [wakati.parse(label) for label in label_str] #mecab normalizer\n",
373
+ " pred_str_norm_jp = [\n",
374
+ " pred_str_norm_jp[i] for i in range(len(pred_str_norm_jp)) if len(label_str_norm_jp[i]) > 0\n",
375
+ " ]\n",
376
+ " label_str_norm_jp = [\n",
377
+ " label_str_norm_jp[i]\n",
378
+ " for i in range(len(label_str_norm_jp))\n",
379
+ " if len(label_str_norm_jp[i]) > 0\n",
380
+ " ]\n",
381
+ " \n",
382
+ " pred_str_norm = [normalizer(pred) for pred in pred_str] #BasicTextNormalizer\n",
383
+ " label_str_norm = [normalizer(label) for label in label_str] #BasicTextNormalizer\n",
384
+ " pred_str_norm = [\n",
385
+ " pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0\n",
386
+ " ]\n",
387
+ " label_str_norm = [\n",
388
+ " label_str_norm[i]\n",
389
+ " for i in range(len(label_str_norm))\n",
390
+ " if len(label_str_norm[i]) > 0\n",
391
+ " ]\n",
392
+ "\n",
393
+ " wer_ortho = 100 * wer_metric.compute(predictions=pred_str, references=label_str) #No Normalizer\n",
394
+ " cer_ortho = 100 * cer_metric.compute(predictions=pred_str, references=label_str) #No Normalizer\n",
395
+ " wer = 100 * wer_metric.compute(predictions=pred_str_norm, references=label_str_norm) #BasicTextNormalizer\n",
396
+ " cer = 100 * cer_metric.compute(predictions=pred_str_norm, references=label_str_norm) #BasicTextNormalizer\n",
397
+ " cer_mecab = 100 * cer_metric.compute(predictions=pred_str_norm_jp, references=label_str_norm_jp) #mecab normalizer\n",
398
+ " \n",
399
+ " return {\"wer_ortho\": wer_ortho, \"wer\": wer, \"cer_ortho\": cer_ortho, \"cer\": cer, \"cer_mecab\": cer_mecab} "
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "execution_count": null,
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "wer_metric = evaluate.load(\"wer\")\n",
409
+ "cer_metric = evaluate.load(\"cer\")\n",
410
+ "\n",
411
+ "def compute_metrics(pred):\n",
412
+ " \n",
413
+ " pred_ids = pred.predictions\n",
414
+ " label_ids = pred.label_ids\n",
415
+ " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n",
416
+ " pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)\n",
417
+ " label_str = processor.batch_decode(label_ids, skip_special_tokens=True)\n",
418
+ " \n",
419
+ " pred_str_norm_jp = [wakati.parse(pred) for pred in pred_str] #mecab normalizer\n",
420
+ " label_str_norm_jp = [wakati.parse(label) for label in label_str] #mecab normalizer\n",
421
+ " pred_str_norm_jp = [\n",
422
+ " pred_str_norm_jp[i] for i in range(len(pred_str_norm_jp)) if len(label_str_norm_jp[i]) > 0\n",
423
+ " ]\n",
424
+ " label_str_norm_jp = [\n",
425
+ " label_str_norm_jp[i]\n",
426
+ " for i in range(len(label_str_norm_jp))\n",
427
+ " if len(label_str_norm_jp[i]) > 0\n",
428
+ " ]\n",
429
+ " \n",
430
+ " pred_str_norm = [normalizer(pred) for pred in pred_str] #BasicTextNormalizer\n",
431
+ " label_str_norm = [normalizer(label) for label in label_str] #BasicTextNormalizer\n",
432
+ " pred_str_norm = [\n",
433
+ " pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0\n",
434
+ " ]\n",
435
+ " label_str_norm = [\n",
436
+ " label_str_norm[i]\n",
437
+ " for i in range(len(label_str_norm))\n",
438
+ " if len(label_str_norm[i]) > 0\n",
439
+ " ]\n",
440
+ "\n",
441
+ " wer_ortho = 100 * wer_metric.compute(predictions=pred_str, references=label_str) #No Normalizer\n",
442
+ " cer_ortho = 100 * cer_metric.compute(predictions=pred_str, references=label_str) #No Normalizer\n",
443
+ " wer = 100 * wer_metric.compute(predictions=pred_str_norm, references=label_str_norm) #BasicTextNormalizer\n",
444
+ " cer = 100 * cer_metric.compute(predictions=pred_str_norm, references=label_str_norm) #BasicTextNormalizer\n",
445
+ " cer_mecab = 100 * cer_metric.compute(predictions=pred_str_norm_jp, references=label_str_norm_jp) #mecab normalizer\n",
446
+ " \n",
447
+ " return {\"wer_ortho\": wer_ortho, \"wer\": wer, \"cer_ortho\": cer_ortho, \"cer\": cer, \"cer_mecab\": cer_mecab} \n"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "bnb_config = BitsAndBytesConfig(\n",
457
+ " load_in_4bit=False,\n",
458
+ " load_in_8bit=False,\n",
459
+ " bnb_4bit_quant_type=\"nf4\",\n",
460
+ " bnb_4bit_use_double_quant=False,\n",
461
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
462
+ " )"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "metadata": {},
469
+ "outputs": [],
470
+ "source": [
471
+ "# model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)\n",
472
+ "# state_dict = model.state_dict() # slice first 1/2 embeddings (=15 seconds input audio)\n",
473
+ "# state_dict[\"model.encoder.embed_positions.weight\"] = state_dict[\"model.encoder.embed_positions.weight\"][:1500, :]\n",
474
+ "\n",
475
+ "# config = WhisperConfig.from_pretrained(\n",
476
+ "# model_name_or_path,\n",
477
+ "# #max_source_positions=1500,\n",
478
+ "# device_map=\"auto\",\n",
479
+ "# torch_dtype=\"auto\",#torch.bfloat16,#\"auto\",#torch.bfloat16,\n",
480
+ "# activation_function=\"gelu\",\n",
481
+ "# apply_spec_augment = True,\n",
482
+ "# add_cross_attention = True,\n",
483
+ "# use_cache = False,\n",
484
+ "# dropout = 0.1,\n",
485
+ "# )\n",
486
+ "# model = WhisperForConditionalGeneration(config)\n",
487
+ "\n",
488
+ "\n",
489
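 + "# Load the base Whisper model with SpecAugment and dropout enabled for regularization; use_cache is disabled because gradient checkpointing is used during training.\n",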
+ "model = WhisperForConditionalGeneration.from_pretrained(\n",
490
+ " model_name_or_path,\n",
491
+ " device_map=\"auto\",\n",
492
+ " torch_dtype=\"auto\",#torch.bfloat16,#\"auto\",#torch.bfloat16,\n",
493
+ " activation_function=\"gelu\",\n",
494
+ " apply_spec_augment = True,\n",
495
+ " add_cross_attention = True,\n",
496
+ " use_cache = False,\n",
497
+ " dropout = 0.1,\n",
498
+ " # encoder_attention_heads=16,\n",
499
+ " # decoder_attention_heads=16,\n",
500
+ " )\n",
501
+ "\n",
502
+ "model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)\n",
503
+ "#model.config.suppress_tokens = []\n",
504
+ "# model.config.forced_decoder_ids = None\n",
505
+ "# model.config.encoder_attention_heads = 16\n",
506
+ "# model.config.decoder_attention_heads = 16\n",
507
+ "\n",
508
+ "# model.config.suppress_tokens = []\n",
509
+ "# model.config.freeze_feature_encoder = True\n",
510
+ "# model.freeze_encoder()\n",
511
+ "# model.config.forced_decoder_ids = None\n",
512
+ "# model.generation_config.language = \"<|ja|>\"\n",
513
+ "# model.generation_config.task = \"transcribe\"\n",
514
+ "\n",
515
+ "# model.config.mask_time_prob=0.01\n",
516
+ "# model.config.mask_time_length=2\n",
517
+ "# model.config.mask_time_min_masks=2\n",
518
+ "# model.config.mask_feature_prob=0.01\n",
519
+ "# model.config.mask_feature_length=5\n",
520
+ "# model.config.mask_feature_min_masks=0\n",
521
+ "# model.config.median_filter_width=7\n",
522
+ "# model.config.attention_dropout = 0.01\n",
523
+ "# model.config.hidden_dropout = 0.1\n",
524
+ "# model.config.encoder_attention_heads = 24\n",
525
+ "# model.config.decoder_attention_heads = 12\n",
526
+ "# model.config.attention_dropout = 0.05\n",
527
+ "\n"
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": null,
533
+ "metadata": {},
534
+ "outputs": [],
535
+ "source": [
536
+ "if use_peft:\n",
537
+ " \n",
538
+ " model = prepare_model_for_kbit_training(model) #quantization_config = QuantoConfig(weights=\"int8\")\n",
539
+ " \n",
540
+ " if use_adalora:\n",
541
+ " config = AdaLoraConfig(\n",
542
+ " peft_type=\"ADALORA\", \n",
543
+ " task_type=\"automatic-speech-recognition\",\n",
544
+ " init_r=16,\n",
545
+ " target_r=32,\n",
546
+ " beta1=0.75,\n",
547
+ " beta2=0.85,\n",
548
+ " tinit=0.0,\n",
549
+ " tfinal=0.0,\n",
550
+ " deltaT=0.0,\n",
551
+ " lora_alpha=64,\n",
552
+ " lora_dropout=0.01,\n",
553
+ " target_modules=\"all-linear\", # [\"k_proj\", \"q_proj\", \"v_proj\", \"out_proj\", \"fc1\", \"fc2\"],\n",
554
+ " orth_reg_weight=0.01,\n",
555
+ " ) \n",
556
+ " # elif use_loha:\n",
557
+ " # config = LoHaConfig(\n",
558
+ " # peft_type=\"loha\",\n",
559
+ " # task_type=\"automatic-speech-recognition\",\n",
560
+ " # r=32,\n",
561
+ " # lora_alpha=32,\n",
562
+ " # target_modules=\"all-linear\", # [\"k_proj\", \"q_proj\", \"v_proj\", \"out_proj\", \"fc1\", \"fc2\"],\n",
563
+ " # rank_dropout=0.0,\n",
564
+ " # module_dropout=0.0,\n",
565
+ " # init_weights=True,\n",
566
+ " # use_effective_conv2d=True,\n",
567
+ " # )\n",
568
+ " # elif use_lokr:\n",
569
+ " # config = LoKrConfig(\n",
570
+ " # task_type=\"automatic-speech-recognition\",\n",
571
+ " # r=32,\n",
572
+ " # lora_alpha=32,\n",
573
+ " # target_modules=\"all-linear\", # [\"k_proj\", \"q_proj\", \"v_proj\", \"out_proj\", \"fc1\", \"fc2\"],\n",
574
+ " # rank_dropout=0.0,\n",
575
+ " # module_dropout=0.0,\n",
576
+ " # init_weights=True,\n",
577
+ " # use_effective_conv2d=True,\n",
578
+ " # )\n",
579
+ " else:\n",
580
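 + "        # Default adapter: LoRA on all linear layers with rank-stabilized scaling (use_rslora) and PiSSA initialization.\n",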
+ " config = LoraConfig(\n",
581
+ " task_type=\"automatic-speech-recognition\",\n",
582
+ " r=32,\n",
583
+ " lora_alpha=64,\n",
584
+ " target_modules=\"all-linear\",#[\"q_proj\", \"v_proj\", \"k_proj\"],\n",
585
+ " lora_dropout=0.1,\n",
586
+ " bias=\"none\",\n",
587
+ " # use_dora=True,\n",
588
+ " use_rslora=True,\n",
589
+ " init_lora_weights=\"pissa\",#_niter_16\"\n",
590
+ " )\n",
591
+ " \n",
592
+ " model = get_peft_model(model, config)\n",
593
+ " model.print_trainable_parameters()"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": null,
599
+ "metadata": {},
600
+ "outputs": [],
601
+ "source": [
602
+ "dataset_names = [\"\", \"\", \"\"] # example: [\"google/fleurs\", \"mozilla/common_voice_16\", \"sin2piusc/jsut_ver1.1\"]\n",
603
+ "dataset_config_names = [\"\", \"\", \"\"] # example: [\"default\", \"jp\", \"en\"]\n",
604
+ "splits = [\"\", \"\", \"\"] # example: [\"train\", \"train\", \"train\"]\n",
605
+ "text_column_names = [\"\", \"\", \"\"] # example: [\"transcription\", \"sentence\", \"sentence\"]\n",
606
+ "\n",
607
+ "ds = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, stopping_strategy=\"all_exhausted\", sampling_rate=16000, trust_remote_code=True)\n",
608
+ "\n",
609
+ "# if norm_everything:\n",
610
+ "# vectorized_dataset = ds.map(norm_everything)\n",
611
+ "\n",
612
+ "# ds = load_from_disk(dataset)\n",
613
+ "# vectorized_dataset = ds.map(prepare_dataset)\n"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": null,
619
+ "metadata": {},
620
+ "outputs": [],
621
+ "source": [
622
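 + "# Filter: keep clips between 1 and 15 seconds and label sequences between 6 tokens and the model's maximum length.\n",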
+ "max_audio_length = 15.0\n",
623
+ "min_audio_length = 1.0\n",
624
+ "max_label_length = model.config.max_length\n",
625
+ "min_label_length = 6 \n",
626
+ "\n",
627
+ "def filter_length(audio_length):\n",
628
+ " return audio_length > min_audio_length and audio_length < max_audio_length\n",
629
+ "\n",
630
+ "def filter_labels(labels):\n",
631
+ " return min_label_length < len(labels) < max_label_length\n",
632
+ "\n",
633
+ "if do_audio_filter:\n",
634
+ " vectorized_dataset = (vectorized_dataset\n",
635
+ " .filter(filter_length, input_columns=[\"audio_length\"])\n",
636
+ " .filter(filter_labels, input_columns=[\"labels\"])\n",
637
+ " )\n",
638
+ "\n",
639
+ "vectorized_dataset = (\n",
640
+ " vectorized_dataset\n",
641
+ " .remove_columns(\"audio_length\")\n",
642
+ " .remove_columns(\"sentence\")\n",
643
+ " .remove_columns(\"audio\")\n",
644
+ " )\n",
645
+ "\n",
646
+ "# vectorized_dataset = vectorized_dataset.shuffle(seed=42)\n",
647
+ "# vectorized_dataset_test = vectorized_dataset.take(500)\n"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": null,
653
+ "metadata": {},
654
+ "outputs": [],
655
+ "source": [
656
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\n",
657
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
658
+ "torch.backends.cudnn.allow_tf32 = True\n",
659
+ "checkpointing_args = {\"use_reentrant\": False} # ,\"preserve_rng_state\": False, \"determinism_check\": \"none\"}"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "execution_count": null,
665
+ "metadata": {},
666
+ "outputs": [],
667
+ "source": [
668
+ "model.save_pretrained(output_dir + \"/pretrained/\")\n",
669
+ "processor.save_pretrained(output_dir + \"/processor/\")\n",
670
+ "feature_extractor.save_pretrained(output_dir + \"/feature_extractor/\")\n",
671
+ "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path).save_pretrained(output_dir + \"/tokenizer/\")"
672
+ ]
673
+ },
674
+ {
675
+ "cell_type": "code",
676
+ "execution_count": null,
677
+ "metadata": {},
678
+ "outputs": [],
679
+ "source": [
680
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\n",
681
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
682
+ "torch.backends.cudnn.allow_tf32 = True\n",
683
+ "checkpointing_args = {\"use_reentrant\": False} \n",
684
+ "\n",
685
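 + "# Training configuration: small per-device batches with gradient accumulation (effective batch size 16), TF32 matmuls, Adafactor, and CER-based checkpoint selection.\n",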
+ "training_args = Seq2SeqTrainingArguments(\n",
686
+ " output_dir=output_dir,\n",
687
+ " overwrite_output_dir = False,\n",
688
+ " per_device_train_batch_size=2,\n",
689
+ " gradient_accumulation_steps=8,\n",
690
+ " eval_accumulation_steps=1,\n",
691
+ " per_device_eval_batch_size=2,\n",
692
+ " learning_rate=1.25e-5,\n",
693
+ " warmup_steps=200,\n",
694
+ " max_steps=1000,\n",
695
+ " gradient_checkpointing=True,\n",
696
+ " tf32=True, # bf16=True,#tf32=True,#bf16=True,# bf16_full_eval=True,#fp16_full_eval=False,\n",
697
+ " eval_strategy=\"steps\", # generation_max_length=150,\n",
698
+ " save_steps=100,\n",
699
+ " eval_steps=100,\n",
700
+ " logging_steps=50,\n",
701
+ " logging_dir=(output_dir + \"/logs\"),\n",
702
+ " logging_strategy=\"steps\",\n",
703
+ " logging_first_step=False,\n",
704
+ " log_level=\"critical\",\n",
705
+ " report_to=[\"tensorboard\"],\n",
706
+ " push_to_hub=False,\n",
707
+ " half_precision_backend=\"auto\",\n",
708
+ " hub_token=\"\",\n",
709
+ " remove_unused_columns=False,\n",
710
+ " label_names=[\"labels\"],\n",
711
+ " hub_private_repo=True,\n",
712
+ " optim=\"adafactor\", # optim=\"adafactor\", \n",
713
+ " weight_decay=0.05,\n",
714
+ " metric_for_best_model=\"cer\",\n",
715
+ " save_total_limit=5,\n",
716
+ " load_best_model_at_end=True,\n",
717
+ " predict_with_generate=True,\n",
718
+ " greater_is_better=True,\n",
719
+ " gradient_checkpointing_kwargs=checkpointing_args,\n",
720
+ " do_predict=True,\n",
721
+ " generation_max_length=128,\n",
722
+ " # dataloader_drop_last=True,\n",
723
+ " # dataloader_num_workers=4,\n",
724
+ " # dataloader_pin_memory=True,\n",
725
+ " # dataloader_persistent_workers=True,\n",
726
+ " restore_callback_states_from_checkpoint=True,\n",
727
+ " # max_grad_norm=0.99,\n",
728
+ " eval_on_start=False,\n",
729
+ " auto_find_batch_size=True,\n",
730
+ " ignore_data_skip=True,\n",
731
+ ")\n"
732
+ ]
733
+ },
734
+ {
735
+ "cell_type": "code",
736
+ "execution_count": null,
737
+ "metadata": {},
738
+ "outputs": [],
739
+ "source": [
740
+ "trainer = Seq2SeqTrainer(\n",
741
+ " args=training_args,\n",
742
+ " model=model,\n",
743
+ " train_dataset=vectorized_dataset,#[\"train\"],\n",
744
+ " eval_dataset=vectorized_dataset_test,#[\"test\"],\n",
745
+ " data_collator=data_collator,\n",
746
+ " tokenizer=processor.feature_extractor,\n",
747
+ " callbacks=[SavePeftModelCallback(),ShuffleCallback()],\n",
748
+ " compute_metrics=compute_metrics, \n",
749
+ " )\n",
750
+ "\n",
751
+ "trainer.train()#trainer.evaluate()#trainer.train(resume_from_checkpoint=True)"
752
+ ]
753
+ },
754
+ {
755
+ "cell_type": "code",
756
+ "execution_count": null,
757
+ "metadata": {},
758
+ "outputs": [],
759
+ "source": [
760
+ "#last evaluation\n",
761
+ "\n",
762
+ "eval_dataloader = DataLoader(vectorized_dataset[\"test\"], batch_size=1, collate_fn=data_collator)\n",
763
+ "model.eval()\n",
764
+ "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
765
+ " with torch.amp.autocast('cuda'):\n",
766
+ " with torch.no_grad():\n",
767
+ " generated_tokens = (\n",
768
+ " model.generate(\n",
769
+ " #language = \"japanese\",\n",
770
+ " input_features=batch[\"input_features\"].to(\"cuda\"),\n",
771
+ " decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
772
+ " max_new_tokens=255,\n",
773
+ " )\n",
774
+ " .cpu()\n",
775
+ " .numpy()\n",
776
+ " )\n",
777
+ " labels = batch[\"labels\"].cpu().numpy()\n",
778
+ " labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)\n",
779
+ " decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
780
+ " decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
781
+ " metric.add_batch(\n",
782
+ " predictions=decoded_preds,\n",
783
+ " references=decoded_labels,\n",
784
+ " )\n",
785
+ " del generated_tokens, labels, batch\n",
786
+ " gc.collect()\n",
787
+ "cer = 100 * metric.compute()\n",
788
+ "print(f\"{cer=}\")"
789
+ ]
790
+ },
791
+ {
792
+ "cell_type": "code",
793
+ "execution_count": null,
794
+ "metadata": {},
795
+ "outputs": [],
796
+ "source": [
797
+ "trainer.push_to_hub()\n",
798
+ "trainer.save_model()\n",
799
+ "trainer.save_state()"
800
+ ]
801
+ },
802
+ {
803
+ "cell_type": "code",
804
+ "execution_count": null,
805
+ "metadata": {},
806
+ "outputs": [],
807
+ "source": [
808
+ "# ADAMW_HF = \"adamw_hf\"\n",
809
+ "# ADAMW_TORCH = \"adamw_torch\"\n",
810
+ "# ADAMW_TORCH_FUSED = \"adamw_torch_fused\"\n",
811
+ "# ADAMW_TORCH_XLA = \"adamw_torch_xla\"\n",
812
+ "# ADAMW_TORCH_NPU_FUSED = \"adamw_torch_npu_fused\"\n",
813
+ "# ADAMW_APEX_FUSED = \"adamw_apex_fused\"\n",
814
+ "# ADAFACTOR = \"adafactor\"\n",
815
+ "# ADAMW_ANYPRECISION = \"adamw_anyprecision\"\n",
816
+ "# SGD = \"sgd\"\n",
817
+ "# ADAGRAD = \"adagrad\"\n",
818
+ "# ADAMW_BNB = \"adamw_bnb_8bit\"\n",
819
+ "# ADAMW_8BIT = \"adamw_8bit\" # just an alias for adamw_bnb_8bit\n",
820
+ "# LION_8BIT = \"lion_8bit\"\n",
821
+ "# LION = \"lion_32bit\"\n",
822
+ "# PAGED_ADAMW = \"paged_adamw_32bit\"\n",
823
+ "# PAGED_ADAMW_8BIT = \"paged_adamw_8bit\"\n",
824
+ "# PAGED_LION = \"paged_lion_32bit\"\n",
825
+ "# PAGED_LION_8BIT = \"paged_lion_8bit\"\n",
826
+ "# RMSPROP = \"rmsprop\"\n",
827
+ "# RMSPROP_BNB = \"rmsprop_bnb\"\n",
828
+ "# RMSPROP_8BIT = \"rmsprop_bnb_8bit\"\n",
829
+ "# RMSPROP_32BIT = \"rmsprop_bnb_32bit\"\n",
830
+ "# GALORE_ADAMW = \"galore_adamw\"\n",
831
+ "# GALORE_ADAMW_8BIT = \"galore_adamw_8bit\"\n",
832
+ "# GALORE_ADAFACTOR = \"galore_adafactor\"\n",
833
+ "# GALORE_ADAMW_LAYERWISE = \"galore_adamw_layerwise\"\n",
834
+ "# GALORE_ADAMW_8BIT_LAYERWISE = \"galore_adamw_8bit_layerwise\"\n",
835
+ "# GALORE_ADAFACTOR_LAYERWISE = \"galore_adafactor_layerwise\"\n",
836
+ "# LOMO = \"lomo\"\n",
837
+ "# ADALOMO = \"adalomo\"\n",
838
+ "# TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop\n",
839
+ "# itself**.\n",
840
+ "\n",
841
+ "# Using [`HfArgumentParser`] we can turn this class into\n",
842
+ "# [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the\n",
843
+ "# command line.\n",
844
+ "\n",
845
+ "# Parameters:\n",
846
+ "# output_dir (`str`):\n",
847
+ "# The output directory where the model predictions and checkpoints will be written.\n",
848
+ "# overwrite_output_dir (`bool`, *optional*, defaults to `False`):\n",
849
+ "# If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`\n",
850
+ "# points to a checkpoint directory.\n",
851
+ "# do_train (`bool`, *optional*, defaults to `False`):\n",
852
+ "# Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used\n",
853
+ "# by your training/evaluation scripts instead. See the [example\n",
854
+ "# scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.\n",
855
+ "# do_eval (`bool`, *optional*):\n",
856
+ "# Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is\n",
857
+ "# different from `\"no\"`. This argument is not directly used by [`Trainer`], it's intended to be used by your\n",
858
+ "# training/evaluation scripts instead. See the [example\n",
859
+ "# scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.\n",
860
+ "# do_predict (`bool`, *optional*, defaults to `False`):\n",
861
+ "# Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's\n",
862
+ "# intended to be used by your training/evaluation scripts instead. See the [example\n",
863
+ "# scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.\n",
864
+ "# eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `\"no\"`):\n",
865
+ "# The evaluation strategy to adopt during training. Possible values are:\n",
866
+ "\n",
867
+ "# - `\"no\"`: No evaluation is done during training.\n",
868
+ "# - `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.\n",
869
+ "# - `\"epoch\"`: Evaluation is done at the end of each epoch.\n",
870
+ "\n",
871
+ "# prediction_loss_only (`bool`, *optional*, defaults to `False`):\n",
872
+ "# When performing evaluation and generating predictions, only returns the loss.\n",
873
+ "# per_device_train_batch_size (`int`, *optional*, defaults to 8):\n",
874
+ "# The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.\n",
875
+ "# per_device_eval_batch_size (`int`, *optional*, defaults to 8):\n",
876
+ "# The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.\n",
877
+ "# gradient_accumulation_steps (`int`, *optional*, defaults to 1):\n",
878
+ "# Number of updates steps to accumulate the gradients for, before performing a backward/update pass.\n",
879
+ "\n",
880
+ "# <Tip warning={true}>\n",
881
+ "\n",
882
+ "# When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,\n",
883
+ "# evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.\n",
884
+ "\n",
885
+ "# </Tip>\n",
886
+ "\n",
887
+ "# eval_accumulation_steps (`int`, *optional*):\n",
888
+ "# Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If\n",
889
+ "# left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but\n",
890
+ "# requires more memory).\n",
891
+ "# eval_delay (`float`, *optional*):\n",
892
+ "# Number of epochs or steps to wait for before the first evaluation can be performed, depending on the\n",
893
+ "# eval_strategy.\n",
894
+ "# torch_empty_cache_steps (`int`, *optional*):\n",
895
+ "# Number of steps to wait before calling `torch.<device>.empty_cache()`. If left unset or set to None, cache will not be emptied.\n",
896
+ "\n",
897
+ "# <Tip>\n",
898
+ "\n",
899
+ "# This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372).\n",
900
+ "\n",
901
+ "# </Tip>\n",
902
+ "\n",
903
+ "# learning_rate (`float`, *optional*, defaults to 5e-5):\n",
904
+ "# The initial learning rate for [`AdamW`] optimizer.\n",
905
+ "# weight_decay (`float`, *optional*, defaults to 0):\n",
906
+ "# The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`]\n",
907
+ "# optimizer.\n",
908
+ "# adam_beta1 (`float`, *optional*, defaults to 0.9):\n",
909
+ "# The beta1 hyperparameter for the [`AdamW`] optimizer.\n",
910
+ "# adam_beta2 (`float`, *optional*, defaults to 0.999):\n",
911
+ "# The beta2 hyperparameter for the [`AdamW`] optimizer.\n",
912
+ "# adam_epsilon (`float`, *optional*, defaults to 1e-8):\n",
913
+ "# The epsilon hyperparameter for the [`AdamW`] optimizer.\n",
914
+ "# max_grad_norm (`float`, *optional*, defaults to 1.0):\n",
915
+ "# Maximum gradient norm (for gradient clipping).\n",
916
+ "# num_train_epochs(`float`, *optional*, defaults to 3.0):\n",
917
+ "# Total number of training epochs to perform (if not an integer, will perform the decimal part percents of\n",
918
+ "# the last epoch before stopping training).\n",
919
+ "# max_steps (`int`, *optional*, defaults to -1):\n",
920
+ "# If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.\n",
921
+ "# For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until\n",
922
+ "# `max_steps` is reached.\n",
923
+ "# lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `\"linear\"`):\n",
924
+ "# The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.\n",
925
+ "# lr_scheduler_kwargs ('dict', *optional*, defaults to {}):\n",
926
+ "# The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.\n",
927
+ "# warmup_ratio (`float`, *optional*, defaults to 0.0):\n",
928
+ "# Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.\n",
929
+ "# warmup_steps (`int`, *optional*, defaults to 0):\n",
930
+ "# Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.\n",
931
+ "# log_level (`str`, *optional*, defaults to `passive`):\n",
932
+ "# Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',\n",
933
+ "# 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the\n",
934
+ "# current log level for the Transformers library (which will be `\"warning\"` by default).\n",
935
+ "# log_level_replica (`str`, *optional*, defaults to `\"warning\"`):\n",
936
+ "# Logger log level to use on replicas. Same choices as `log_level`\"\n",
937
+ "# log_on_each_node (`bool`, *optional*, defaults to `True`):\n",
938
+ "# In multinode distributed training, whether to log using `log_level` once per node, or only on the main\n",
939
+ "# node.\n",
940
+ "# logging_dir (`str`, *optional*):\n",
941
+ "# [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\n",
942
+ "# *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\n",
943
+ "# logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `\"steps\"`):\n",
944
+ "# The logging strategy to adopt during training. Possible values are:\n",
945
+ "\n",
946
+ "# - `\"no\"`: No logging is done during training.\n",
947
+ "# - `\"epoch\"`: Logging is done at the end of each epoch.\n",
948
+ "# - `\"steps\"`: Logging is done every `logging_steps`.\n",
949
+ "\n",
950
+ "# logging_first_step (`bool`, *optional*, defaults to `False`):\n",
951
+ "# Whether to log the first `global_step` or not.\n",
952
+ "# logging_steps (`int` or `float`, *optional*, defaults to 500):\n",
953
+ "# Number of update steps between two logs if `logging_strategy=\"steps\"`. Should be an integer or a float in\n",
954
+ "# range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.\n",
955
+ "# logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):\n",
956
+ "# Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`\n",
957
+ "# or `inf` is filtered and the average loss of the current logging window is taken instead.\n",
958
+ "\n",
959
+ "# <Tip>\n",
960
+ "\n",
961
+ "# `logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the\n",
962
+ "# gradient is computed or applied to the model.\n",
963
+ "\n",
964
+ "# </Tip>\n",
965
+ "\n",
966
+ "# save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `\"steps\"`):\n",
967
+ "# The checkpoint save strategy to adopt during training. Possible values are:\n",
968
+ "\n",
969
+ "# - `\"no\"`: No save is done during training.\n",
970
+ "# - `\"epoch\"`: Save is done at the end of each epoch.\n",
971
+ "# - `\"steps\"`: Save is done every `save_steps`.\n",
972
+ "\n",
973
+ "# If `\"epoch\"` or `\"steps\"` is chosen, saving will also be performed at the\n",
974
+ "# very end of training, always.\n",
975
+ "# save_steps (`int` or `float`, *optional*, defaults to 500):\n",
976
+ "# Number of updates steps before two checkpoint saves if `save_strategy=\"steps\"`. Should be an integer or a\n",
977
+ "# float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.\n",
978
+ "# save_total_limit (`int`, *optional*):\n",
979
+ "# If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in\n",
980
+ "# `output_dir`. When `load_best_model_at_end` is enabled, the \"best\" checkpoint according to\n",
981
+ "# `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for\n",
982
+ "# `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained\n",
983
+ "# alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two\n",
984
+ "# checkpoints are saved: the last one and the best one (if they are different).\n",
985
+ "# save_safetensors (`bool`, *optional*, defaults to `True`):\n",
986
+ "# Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of\n",
987
+ "# default `torch.load` and `torch.save`.\n",
988
+ "# save_on_each_node (`bool`, *optional*, defaults to `False`):\n",
989
+ "# When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on\n",
990
+ "# the main one.\n",
991
+ "\n",
992
+ "# This should not be activated when the different nodes use the same storage as the files will be saved with\n",
993
+ "# the same names for each node.\n",
994
+ "# save_only_model (`bool`, *optional*, defaults to `False`):\n",
995
+ "# When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.\n",
996
+ "# Note that when this is true, you won't be able to resume training from checkpoint.\n",
997
+ "# This enables you to save storage by not storing the optimizer, scheduler & rng state.\n",
998
+ "# You can only load the model using `from_pretrained` with this option set to `True`.\n",
999
+ "# restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):\n",
1000
+ "# Whether to restore the callback states from the checkpoint. If `True`, will override\n",
1001
+ "# callbacks passed to the `Trainer` if they exist in the checkpoint.\"\n",
1002
+ "# use_cpu (`bool`, *optional*, defaults to `False`):\n",
1003
+ "# Whether or not to use cpu. If set to False, we will use cuda or mps device if available.\n",
1004
+ "# seed (`int`, *optional*, defaults to 42):\n",
1005
+ "# Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the\n",
1006
+ "# [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.\n",
1007
+ "# data_seed (`int`, *optional*):\n",
1008
+ "# Random seed to be used with data samplers. If not set, random generators for data sampling will use the\n",
1009
+ "# same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model\n",
1010
+ "# seed.\n",
1011
+ "# jit_mode_eval (`bool`, *optional*, defaults to `False`):\n",
1012
+ "# Whether or not to use PyTorch jit trace for inference.\n",
1013
+ "# use_ipex (`bool`, *optional*, defaults to `False`):\n",
1014
+ "# Use Intel extension for PyTorch when it is available. [IPEX\n",
1015
+ "# installation](https://github.com/intel/intel-extension-for-pytorch).\n",
1016
+ "# bf16 (`bool`, *optional*, defaults to `False`):\n",
1017
+ "# Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher\n",
1018
+ "# NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.\n",
1019
+ "# fp16 (`bool`, *optional*, defaults to `False`):\n",
1020
+ "# Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.\n",
1021
+ "# fp16_opt_level (`str`, *optional*, defaults to 'O1'):\n",
1022
+ "# For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on\n",
1023
+ "# the [Apex documentation](https://nvidia.github.io/apex/amp).\n",
1024
+ "# fp16_backend (`str`, *optional*, defaults to `\"auto\"`):\n",
1025
+ "# This argument is deprecated. Use `half_precision_backend` instead.\n",
1026
+ "# half_precision_backend (`str`, *optional*, defaults to `\"auto\"`):\n",
1027
+ "# The backend to use for mixed precision training. Must be one of `\"auto\", \"apex\", \"cpu_amp\"`. `\"auto\"` will\n",
1028
+ "# use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the\n",
1029
+ "# requested backend.\n",
1030
+ "# bf16_full_eval (`bool`, *optional*, defaults to `False`):\n",
1031
+ "# Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm\n",
1032
+ "# metric values. This is an experimental API and it may change.\n",
1033
+ "# fp16_full_eval (`bool`, *optional*, defaults to `False`):\n",
1034
+ "# Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm\n",
1035
+ "# metric values.\n",
1036
+ "# tf32 (`bool`, *optional*):\n",
1037
+ "# Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends\n",
1038
+ "# on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to\n",
1039
+ "# the [TF32](https://huggingface.co/docs/transformers/performance#tf32) documentation. This is an\n",
1040
+ "# experimental API and it may change.\n",
1041
+ "# local_rank (`int`, *optional*, defaults to -1):\n",
1042
+ "# Rank of the process during distributed training.\n",
1043
+ "# ddp_backend (`str`, *optional*):\n",
1044
+ "# The backend to use for distributed training. Must be one of `\"nccl\"`, `\"mpi\"`, `\"ccl\"`, `\"gloo\"`, `\"hccl\"`.\n",
1045
+ "# tpu_num_cores (`int`, *optional*):\n",
1046
+ "# When training on TPU, the number of TPU cores (automatically passed by launcher script).\n",
1047
+ "# dataloader_drop_last (`bool`, *optional*, defaults to `False`):\n",
1048
+ "# Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)\n",
1049
+ "# or not.\n",
1050
+ "# eval_steps (`int` or `float`, *optional*):\n",
1051
+ "# Number of update steps between two evaluations if `eval_strategy=\"steps\"`. Will default to the same\n",
1052
+ "# value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1,\n",
1053
+ "# will be interpreted as ratio of total training steps.\n",
1054
+ "# dataloader_num_workers (`int`, *optional*, defaults to 0):\n",
1055
+ "# Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the\n",
1056
+ "# main process.\n",
1057
+ "# past_index (`int`, *optional*, defaults to -1):\n",
1058
+ "# Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of\n",
1059
+ "# the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will\n",
1060
+ "# use the corresponding output (usually index 2) as the past state and feed it to the model at the next\n",
1061
+ "# training step under the keyword argument `mems`.\n",
1062
+ "# run_name (`str`, *optional*, defaults to `output_dir`):\n",
1063
+ "# A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),\n",
1064
+ "# [mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will\n",
1065
+ "# be the same as `output_dir`.\n",
1066
+ "# disable_tqdm (`bool`, *optional*):\n",
1067
+ "# Whether or not to disable the tqdm progress bars and table of metrics produced by\n",
1068
+ "# [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is\n",
1069
+ "# set to warn or lower (default), `False` otherwise.\n",
1070
+ "# remove_unused_columns (`bool`, *optional*, defaults to `True`):\n",
1071
+ "# Whether or not to automatically remove the columns unused by the model forward method.\n",
1072
+ "# label_names (`List[str]`, *optional*):\n",
1073
+ "# The list of keys in your dictionary of inputs that correspond to the labels.\n",
1074
+ "\n",
1075
+ "# Will eventually default to the list of argument names accepted by the model that contain the word \"label\",\n",
1076
+ "# except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the\n",
1077
+ "# `[\"start_positions\", \"end_positions\"]` keys.\n",
1078
+ "# load_best_model_at_end (`bool`, *optional*, defaults to `False`):\n",
1079
+ "# Whether or not to load the best model found during training at the end of training. When this option is\n",
1080
+ "# enabled, the best checkpoint will always be saved. See\n",
1081
+ "# [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit)\n",
1082
+ "# for more.\n",
1083
+ "\n",
1084
+ "# <Tip>\n",
1085
+ "\n",
1086
+ "# When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in\n",
1087
+ "# the case it is \"steps\", `save_steps` must be a round multiple of `eval_steps`.\n",
1088
+ "\n",
1089
+ "# </Tip>\n",
1090
+ "\n",
1091
+ "# metric_for_best_model (`str`, *optional*):\n",
1092
+ "# Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different\n",
1093
+ "# models. Must be the name of a metric returned by the evaluation with or without the prefix `\"eval_\"`. Will\n",
1094
+ "# default to `\"loss\"` if unspecified and `load_best_model_at_end=True` (to use the evaluation loss).\n",
1095
+ "\n",
1096
+ "# If you set this value, `greater_is_better` will default to `True`. Don't forget to set it to `False` if\n",
1097
+ "# your metric is better when lower.\n",
1098
+ "# greater_is_better (`bool`, *optional*):\n",
1099
+ "# Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models\n",
1100
+ "# should have a greater metric or not. Will default to:\n",
1101
+ "\n",
1102
+ "# - `True` if `metric_for_best_model` is set to a value that doesn't end in `\"loss\"`.\n",
1103
+ "# - `False` if `metric_for_best_model` is not set, or set to a value that ends in `\"loss\"`.\n",
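+ "\n",
+ "# Illustrative sketch (kept commented out): wiring the checkpoint-selection options above together for a\n",
+ "# Whisper fine-tune. Assumes a compute_metrics function that returns a \"wer\" key; names and values are placeholders.\n",
+ "# args = transformers.Seq2SeqTrainingArguments(\n",
+ "#     output_dir=\"./whisper-out\",\n",
+ "#     eval_strategy=\"steps\", eval_steps=500,\n",
+ "#     save_strategy=\"steps\", save_steps=1000,            # a round multiple of eval_steps\n",
+ "#     load_best_model_at_end=True,\n",
+ "#     metric_for_best_model=\"wer\", greater_is_better=False,  # lower WER is better\n",
+ "# )\n",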
+ "# ignore_data_skip (`bool`, *optional*, defaults to `False`):\n",
+ "# When resuming training, whether or not to skip the epochs and batches to get the data loading at the same\n",
+ "# stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step\n",
+ "# can take a long time) but will not yield the same results as the interrupted training would have.\n",
+ "# fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):\n",
+ "# Use PyTorch Distributed Parallel Training (in distributed training only).\n",
1110
+ "\n",
+ "# A list of options along the following:\n",
+ "\n",
+ "# - `\"full_shard\"`: Shard parameters, gradients and optimizer states.\n",
+ "# - `\"shard_grad_op\"`: Shard optimizer states and gradients.\n",
+ "# - `\"hybrid_shard\"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.\n",
+ "# - `\"hybrid_shard_zero2\"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.\n",
+ "# - `\"offload\"`: Offload parameters and gradients to CPUs (only compatible with `\"full_shard\"` and\n",
+ "# `\"shard_grad_op\"`).\n",
+ "# - `\"auto_wrap\"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.\n",
+ "# fsdp_config (`str` or `dict`, *optional*):\n",
+ "# Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of\n",
1122
+ "# fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.\n",
+ "\n",
+ "# A List of config and its options:\n",
+ "# - min_num_params (`int`, *optional*, defaults to `0`):\n",
+ "# FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is\n",
+ "# passed).\n",
+ "# - transformer_layer_cls_to_wrap (`List[str]`, *optional*):\n",
+ "# List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,\n",
+ "# `T5Block` .... (useful only when `fsdp` flag is passed).\n",
+ "# - backward_prefetch (`str`, *optional*)\n",
+ "# FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when\n",
+ "# `fsdp` field is passed).\n",
+ "\n",
+ "# A list of options along the following:\n",
+ "\n",
+ "# - `\"backward_pre\"` : Prefetches the next set of parameters before the current set of parameter's\n",
+ "# gradient\n",
+ "# computation.\n",
+ "# - `\"backward_post\"` : This prefetches the next set of parameters after the current set of\n",
+ "# parameter’s\n",
+ "# gradient computation.\n",
+ "# - forward_prefetch (`bool`, *optional*, defaults to `False`)\n",
+ "# FSDP's forward prefetch mode (useful only when `fsdp` field is passed).\n",
+ "# If `\"True\"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the\n",
+ "# forward pass.\n",
+ "# - limit_all_gathers (`bool`, *optional*, defaults to `False`)\n",
+ "# FSDP's limit_all_gathers (useful only when `fsdp` field is passed).\n",
+ "# If `\"True\"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight\n",
+ "# all-gathers.\n",
+ "# - use_orig_params (`bool`, *optional*, defaults to `True`)\n",
+ "# If `\"True\"`, allows non-uniform `requires_grad` during init, which means support for interspersed\n",
+ "# frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please\n",
1154
+ "# refer this\n",
1155
+ "# [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019\n",
1156
+ "# - sync_module_states (`bool`, *optional*, defaults to `True`)\n",
+ "# If `\"True\"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to\n",
+ "# ensure they are the same across all ranks after initialization\n",
+ "# - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)\n",
+ "# If `\"True\"`, only the first process loads the pretrained model checkpoint while all other processes\n",
+ "# have empty weights. When this setting as `\"True\"`, `sync_module_states` also must to be `\"True\"`,\n",
1162
+ "# otherwise all the processes except the main process would have random weights leading to unexpected\n",
+ "# behaviour during training.\n",
+ "# - activation_checkpointing (`bool`, *optional*, defaults to `False`):\n",
+ "# If `\"True\"`, activation checkpointing is a technique to reduce memory usage by clearing activations of\n",
1166
+ "# certain layers and recomputing them during a backward pass. Effectively, this trades extra\n",
+ "# computation time for reduced memory usage.\n",
+ "# - xla (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature\n",
+ "# and its API may evolve in the future.\n",
+ "# - xla_fsdp_settings (`dict`, *optional*)\n",
+ "# The value is a dictionary which stores the XLA FSDP wrapping parameters.\n",
+ "\n",
+ "# For a complete list of options, please see [here](\n",
+ "# https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).\n",
+ "# - xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):\n",
+ "# Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be\n",
+ "# used when the xla flag is set to true, and an auto wrapping policy is specified through\n",
+ "# fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.\n",
+ "\n",
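+ "# Illustrative sketch (kept commented out): enabling FSDP with the options above; the layer class names are\n",
+ "# assumptions for a Whisper model and the other values are placeholders.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     fsdp=\"full_shard auto_wrap\",\n",
+ "#     fsdp_config={\"transformer_layer_cls_to_wrap\": [\"WhisperEncoderLayer\", \"WhisperDecoderLayer\"],\n",
+ "#                  \"backward_prefetch\": \"backward_pre\"},\n",
+ "# )\n",
+ "\n",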
+ "# deepspeed (`str` or `dict`, *optional*):\n",
+ "# Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may\n",
+ "# evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,\n",
+ "# `ds_config.json`) or an already loaded json file as a `dict`\"\n",
+ "\n",
+ "# <Tip warning={true}>\n",
+ "# If enabling any Zero-init, make sure that your model is not initialized until\n",
+ "# *after* initializing the `TrainingArguments`, else it will not be applied.\n",
+ "# </Tip>\n",
+ "\n",
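+ "# Illustrative sketch (kept commented out): pointing the Trainer at a DeepSpeed JSON config; the path is a placeholder.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     deepspeed=\"ds_config.json\",\n",
+ "# )\n",
+ "\n",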
+ "# accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):\n",
+ "# Config to be used with the internal `Accelerator` implementation. The value is either a location of\n",
+ "# accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,\n",
+ "# or an instance of [`~trainer_pt_utils.AcceleratorConfig`].\n",
+ "\n",
+ "# A list of config and its options:\n",
+ "# - split_batches (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If\n",
+ "# `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a\n",
+ "# round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set\n",
+ "# in your script multiplied by the number of processes.\n",
+ "# - dispatch_batches (`bool`, *optional*):\n",
+ "# If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process\n",
+ "# and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose\n",
+ "# underlying dataset is an `IterableDataset`, `False` otherwise.\n",
+ "# - even_batches (`bool`, *optional*, defaults to `True`):\n",
+ "# If set to `True`, in cases where the total batch size across all processes does not exactly divide the\n",
+ "# dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among\n",
+ "# all workers.\n",
+ "# - use_seedable_sampler (`bool`, *optional*, defaults to `True`):\n",
+ "# Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures\n",
1212
+ "# training results are fully reproducable using a different sampling technique. While seed-to-seed results\n",
1213
+ "# may differ, on average the differences are neglible when using multiple different seeds to compare. Should\n",
1214
+ "# also be ran with [`~utils.set_seed`] for the best results.\n",
1215
+ "# - use_configured_state (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.\n",
+ "# If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues\n",
+ "# with hyperparameter tuning.\n",
+ "\n",
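+ "# Illustrative sketch (kept commented out): passing an accelerator_config dict with the options listed above;\n",
+ "# values are placeholders.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     accelerator_config={\"split_batches\": False, \"even_batches\": True, \"use_seedable_sampler\": True},\n",
+ "# )\n",
+ "\n",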
+ "# label_smoothing_factor (`float`, *optional*, defaults to 0.0):\n",
+ "# The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded\n",
+ "# labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +\n",
+ "# label_smoothing_factor/num_labels` respectively.\n",
+ "# debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `\"\"`):\n",
+ "# Enable one or more debug features. This is an experimental feature.\n",
+ "\n",
+ "# Possible options are:\n",
+ "\n",
+ "# - `\"underflow_overflow\"`: detects overflow in model's input/outputs and reports the last frames that led to\n",
+ "# the event\n",
+ "# - `\"tpu_metrics_debug\"`: print debug metrics on TPU\n",
+ "\n",
+ "# The options should be separated by whitespaces.\n",
1234
+ "# optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `\"adamw_torch\"`):\n",
+ "# The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or\n",
+ "# adafactor.\n",
+ "# optim_args (`str`, *optional*):\n",
+ "# Optional arguments that are supplied to AnyPrecisionAdamW.\n",
+ "# group_by_length (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not to group together samples of roughly the same length in the training dataset (to minimize\n",
+ "# padding applied and be more efficient). Only useful if applying dynamic padding.\n",
+ "# length_column_name (`str`, *optional*, defaults to `\"length\"`):\n",
+ "# Column name for precomputed lengths. If the column exists, grouping by length will use these values rather\n",
+ "# than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an\n",
+ "# instance of `Dataset`.\n",
+ "# report_to (`str` or `List[str]`, *optional*, defaults to `\"all\"`):\n",
+ "# The list of integrations to report the results and logs to. Supported platforms are `\"azure_ml\"`,\n",
+ "# `\"clearml\"`, `\"codecarbon\"`, `\"comet_ml\"`, `\"dagshub\"`, `\"dvclive\"`, `\"flyte\"`, `\"mlflow\"`, `\"neptune\"`,\n",
+ "# `\"tensorboard\"`, and `\"wandb\"`. Use `\"all\"` to report to all integrations installed, `\"none\"` for no\n",
+ "# integrations.\n",
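+ "\n",
+ "# Illustrative sketch (kept commented out): optimizer choice, length grouping and logging backends together;\n",
+ "# values are placeholders.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     optim=\"adamw_torch\",\n",
+ "#     group_by_length=True, length_column_name=\"length\",\n",
+ "#     report_to=[\"tensorboard\"],\n",
+ "# )\n",
+ "\n",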
+ "# ddp_find_unused_parameters (`bool`, *optional*):\n",
+ "# When using distributed training, the value of the flag `find_unused_parameters` passed to\n",
+ "# `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.\n",
+ "# ddp_bucket_cap_mb (`int`, *optional*):\n",
+ "# When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.\n",
+ "# ddp_broadcast_buffers (`bool`, *optional*):\n",
+ "# When using distributed training, the value of the flag `broadcast_buffers` passed to\n",
+ "# `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.\n",
+ "# dataloader_pin_memory (`bool`, *optional*, defaults to `True`):\n",
+ "# Whether you want to pin memory in data loaders or not. Will default to `True`.\n",
+ "# dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):\n",
+ "# If True, the data loader will not shut down the worker processes after a dataset has been consumed once.\n",
+ "# This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will\n",
1264
+ "# increase RAM usage. Will default to `False`.\n",
+ "# dataloader_prefetch_factor (`int`, *optional*):\n",
+ "# Number of batches loaded in advance by each worker.\n",
+ "# 2 means there will be a total of 2 * num_workers batches prefetched across all workers.\n",
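+ "\n",
+ "# Illustrative sketch (kept commented out): the dataloader knobs described above; values are placeholders\n",
+ "# (dataloader_prefetch_factor only applies when dataloader_num_workers > 0).\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     dataloader_num_workers=4,\n",
+ "#     dataloader_pin_memory=True,\n",
+ "#     dataloader_persistent_workers=True,\n",
+ "#     dataloader_prefetch_factor=2,\n",
+ "# )\n",
+ "\n",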
+ "# skip_memory_metrics (`bool`, *optional*, defaults to `True`):\n",
+ "# Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows\n",
+ "# down the training and evaluation speed.\n",
+ "# push_to_hub (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not to push the model to the Hub every time the model is saved. If this is activated,\n",
+ "# `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content\n",
1274
+ "# will be pushed each time a save is triggered (depending on your `save_strategy`). Calling\n",
+ "# [`~Trainer.save_model`] will also trigger a push.\n",
+ "\n",
+ "# <Tip warning={true}>\n",
+ "\n",
+ "# If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be\n",
+ "# pushed.\n",
+ "\n",
+ "# </Tip>\n",
+ "\n",
+ "# resume_from_checkpoint (`str`, *optional*):\n",
+ "# The path to a folder with a valid checkpoint for your model. This argument is not directly used by\n",
+ "# [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example\n",
+ "# scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.\n",
+ "# hub_model_id (`str`, *optional*):\n",
+ "# The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in\n",
+ "# which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,\n",
+ "# for instance `\"user_name/model\"`, which allows you to push to an organization you are a member of with\n",
+ "# `\"organization_name/model\"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the\n",
+ "# name of `output_dir`.\n",
+ "\n",
+ "# Will default to the name of `output_dir`.\n",
+ "# hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `\"every_save\"`):\n",
+ "# Defines the scope of what is pushed to the Hub and when. Possible values are:\n",
+ "\n",
+ "# - `\"end\"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a\n",
+ "# draft of a model card when the [`~Trainer.save_model`] method is called.\n",
+ "# - `\"every_save\"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and\n",
+ "# a draft of a model card each time there is a model save. The pushes are asynchronous to not block\n",
+ "# training, and in case the save are very frequent, a new push is only attempted if the previous one is\n",
1304
+ "# finished. A last push is made with the final model at the end of training.\n",
+ "# - `\"checkpoint\"`: like `\"every_save\"` but the latest checkpoint is also pushed in a subfolder named\n",
+ "# last-checkpoint, allowing you to resume training easily with\n",
+ "# `trainer.train(resume_from_checkpoint=\"last-checkpoint\")`.\n",
+ "# - `\"all_checkpoints\"`: like `\"checkpoint\"` but all checkpoints are pushed like they appear in the output\n",
+ "# folder (so you will get one checkpoint folder per folder in your final repository)\n",
+ "\n",
+ "# hub_token (`str`, *optional*):\n",
+ "# The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with\n",
+ "# `huggingface-cli login`.\n",
+ "# hub_private_repo (`bool`, *optional*, defaults to `False`):\n",
+ "# If True, the Hub repo will be set to private.\n",
+ "# hub_always_push (`bool`, *optional*, defaults to `False`):\n",
+ "# Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.\n",
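+ "\n",
+ "# Illustrative sketch (kept commented out): pushing checkpoints to the Hub with the hub_* options above;\n",
+ "# the repository name is a placeholder.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./whisper-out\",\n",
+ "#     push_to_hub=True,\n",
+ "#     hub_model_id=\"your-username/whisper-small-ja\",\n",
+ "#     hub_strategy=\"checkpoint\",   # also pushes last-checkpoint/ so training can be resumed\n",
+ "# )\n",
+ "\n",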
+ "# gradient_checkpointing (`bool`, *optional*, defaults to `False`):\n",
+ "# If True, use gradient checkpointing to save memory at the expense of slower backward pass.\n",
+ "# gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):\n",
+ "# Key word arguments to be passed to the `gradient_checkpointing_enable` method.\n",
1322
+ "# include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics\n",
+ "# that need inputs, predictions and references for scoring calculation in Metric class.\n",
+ "# eval_do_concat_batches (`bool`, *optional*, defaults to `True`):\n",
+ "# Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,\n",
+ "# will instead store them as lists, with each batch kept separate.\n",
+ "# auto_find_batch_size (`bool`, *optional*, defaults to `False`)\n",
+ "# Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding\n",
+ "# CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)\n",
+ "# full_determinism (`bool`, *optional*, defaults to `False`)\n",
+ "# If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in\n",
+ "# distributed training. Important: this will negatively impact the performance, so only use it for debugging.\n",
+ "# torchdynamo (`str`, *optional*):\n",
+ "# If set, the backend compiler for TorchDynamo. Possible choices are `\"eager\"`, `\"aot_eager\"`, `\"inductor\"`,\n",
+ "# `\"nvfuser\"`, `\"aot_nvfuser\"`, `\"aot_cudagraphs\"`, `\"ofi\"`, `\"fx2trt\"`, `\"onnxrt\"` and `\"ipex\"`.\n",
+ "# ray_scope (`str`, *optional*, defaults to `\"last\"`):\n",
+ "# The scope to use when doing hyperparameter search with Ray. By default, `\"last\"` will be used. Ray will\n",
+ "# then use the last checkpoint of all trials, compare those, and select the best one. However, other options\n",
+ "# are also available. See the [Ray documentation](\n",
+ "# https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for\n",
+ "# more options.\n",
+ "# ddp_timeout (`int`, *optional*, defaults to 1800):\n",
+ "# The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when\n",
+ "# performing slow operations in distributed runnings. Please refer the [PyTorch documentation]\n",
1346
+ "# (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more\n",
+ "# information.\n",
+ "# use_mps_device (`bool`, *optional*, defaults to `False`):\n",
+ "# This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device.\n",
1350
+ "# torch_compile (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether or not to compile the model using PyTorch 2.0\n",
+ "# [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).\n",
+ "\n",
+ "# This will use the best defaults for the [`torch.compile`\n",
+ "# API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile).\n",
+ "# You can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we\n",
+ "# don't guarantee any of them will work as the support is progressively rolled in in PyTorch.\n",
1358
+ "\n",
+ "# This flag and the whole compile API is experimental and subject to change in future releases.\n",
+ "# torch_compile_backend (`str`, *optional*):\n",
+ "# The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.\n",
+ "\n",
+ "# Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.\n",
+ "\n",
+ "# This flag is experimental and subject to change in future releases.\n",
+ "# torch_compile_mode (`str`, *optional*):\n",
+ "# The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.\n",
+ "\n",
+ "# Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.\n",
+ "\n",
+ "# This flag is experimental and subject to change in future releases.\n",
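+ "\n",
+ "# Illustrative sketch (kept commented out): opting into torch.compile; mode/backend support varies by PyTorch version.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     torch_compile=True,\n",
+ "#     torch_compile_mode=\"reduce-overhead\",   # setting a mode also implies torch_compile=True\n",
+ "# )\n",
+ "\n",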
+ "# split_batches (`bool`, *optional*):\n",
+ "# Whether or not the accelerator should split the batches yielded by the dataloaders across the devices\n",
+ "# during distributed training. If\n",
+ "\n",
+ "# set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it\n",
+ "# must be a\n",
+ "\n",
+ "# round multiple of the number of processes you are using (such as GPUs).\n",
+ "# include_tokens_per_second (`bool`, *optional*):\n",
+ "# Whether or not to compute the number of tokens per second per device for training speed metrics.\n",
+ "\n",
+ "# This will iterate over the entire training dataloader once beforehand,\n",
+ "\n",
+ "# and will slow down the entire process.\n",
+ "\n",
+ "# include_num_input_tokens_seen (`bool`, *optional*):\n",
+ "# Whether or not to track the number of input tokens seen throughout training.\n",
+ "\n",
+ "# May be slower in distributed training as gather operations must be called.\n",
+ "\n",
+ "# neftune_noise_alpha (`Optional[float]`):\n",
+ "# If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance\n",
+ "# for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the\n",
+ "# [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also\n",
1396
+ "# `PeftModel` from peft.\n",
+ "# optim_target_modules (`Union[str, List[str]]`, *optional*):\n",
+ "# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm\n",
1399
+ "# https://arxiv.org/abs/2403.03507\n",
+ "# See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe\n",
1401
+ "# optimizer, e.g. one of: \"galore_adamw\", \"galore_adamw_8bit\", \"galore_adafactor\" and make sure that the target modules are `nn.Linear` modules\n",
+ "# only.\n",
+ "\n",
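+ "# Illustrative sketch (kept commented out): pairing a GaLore optimizer with target modules, per the note above;\n",
+ "# the module names are placeholders for `nn.Linear` projections.\n",
+ "# args = transformers.TrainingArguments(\n",
+ "#     output_dir=\"./out\",\n",
+ "#     optim=\"galore_adamw\",\n",
+ "#     optim_target_modules=[\"q_proj\", \"v_proj\"],\n",
+ "# )\n",
+ "\n",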
+ "# batch_eval_metrics (`Optional[bool]`, defaults to `False`):\n",
+ "# If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics\n",
+ "# rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function\n",
+ "# that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global\n",
+ "# summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.\n",
+ "\n",
+ "# eval_on_start (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.\n",
1412
+ "\n",
+ "# eval_use_gather_object (`bool`, *optional*, defaults to `False`):\n",
+ "# Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices.\n",
1415
+ "# \"\"\""
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }