Vedant Vyas committed on
Commit 446d0be
1 Parent(s): 2fca968
Files changed (3)
  1. readme.md +1 -0
  2. requirements.txt +8 -0
  3. run_translation.py +660 -0
readme.md ADDED
@@ -0,0 +1 @@
+ ## Readme
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate >= 0.12.0
+ datasets >= 1.8.0
+ sentencepiece != 0.1.92
+ protobuf
+ sacrebleu >= 1.4.12
+ py7zr
+ torch >= 1.3
+ evaluate
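Installing the pins above is the usual `pip install -r requirements.txt`; as a minimal sketch (assuming the list is saved as requirements.txt in the working directory, which the commit does not itself specify), the same install can be driven from Python:

```python
# Install the pinned dependencies with the running interpreter's own pip
# (equivalent to: pip install -r requirements.txt).
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
```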
run_translation.py ADDED
@@ -0,0 +1,660 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for sequence to sequence.
+ """
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import datasets
+ import numpy as np
+ from datasets import load_dataset
+
+ import evaluate
+ import transformers
+ from transformers import (
+     AutoConfig,
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer,
+     DataCollatorForSeq2Seq,
+     HfArgumentParser,
+     M2M100Tokenizer,
+     MBart50Tokenizer,
+     MBart50TokenizerFast,
+     MBartTokenizer,
+     MBartTokenizerFast,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     default_data_collator,
+     set_seed,
+ )
+ from transformers.trainer_utils import get_last_checkpoint
+ from transformers.utils import check_min_version, send_example_telemetry
+ from transformers.utils.versions import require_version
+
+
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+ check_min_version("4.26.0.dev0")
+
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
+
+ logger = logging.getLogger(__name__)
+
+ # A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
+ MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+
+     model_name_or_path: str = field(
+         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+     )
+     model_revision: str = field(
+         default="main",
+         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+     )
+     use_auth_token: bool = field(
+         default=False,
+         metadata={
+             "help": (
+                 "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+                 "with private models)."
+             )
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
+     target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
+         },
+     )
+     test_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     max_source_length: Optional[int] = field(
+         default=1024,
+         metadata={
+             "help": (
+                 "The maximum total input sequence length after tokenization. Sequences longer "
+                 "than this will be truncated, sequences shorter will be padded."
+             )
+         },
+     )
+     max_target_length: Optional[int] = field(
+         default=128,
+         metadata={
+             "help": (
+                 "The maximum total sequence length for target text after tokenization. Sequences longer "
+                 "than this will be truncated, sequences shorter will be padded."
+             )
+         },
+     )
+     val_max_target_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+                 "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+                 "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+                 "during ``evaluate`` and ``predict``."
+             )
+         },
+     )
+     pad_to_max_length: bool = field(
+         default=False,
+         metadata={
+             "help": (
+                 "Whether to pad all samples to model maximum sentence length. "
+                 "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+                 "efficient on GPU but very bad for TPU."
+             )
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of training examples to this "
+                 "value if set."
+             )
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                 "value if set."
+             )
+         },
+     )
+     max_predict_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+                 "value if set."
+             )
+         },
+     )
+     num_beams: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+                 "which is used during ``evaluate`` and ``predict``."
+             )
+         },
+     )
+     ignore_pad_token_for_loss: bool = field(
+         default=True,
+         metadata={
+             "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+         },
+     )
+     source_prefix: Optional[str] = field(
+         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+     )
+     forced_bos_token: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
+                 " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
+                 " be the target language token.(Usually it is the target language token)"
+             )
+         },
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         elif self.source_lang is None or self.target_lang is None:
+             raise ValueError("Need to specify the source language and the target language.")
+
+         # accepting both json and jsonl file extensions, as
+         # many jsonlines files actually have a .json extension
+         valid_extensions = ["json", "jsonl"]
+
+         if self.train_file is not None:
+             extension = self.train_file.split(".")[-1]
+             assert extension in valid_extensions, "`train_file` should be a jsonlines file."
+         if self.validation_file is not None:
+             extension = self.validation_file.split(".")[-1]
+             assert extension in valid_extensions, "`validation_file` should be a jsonlines file."
+         if self.val_max_target_length is None:
+             self.val_max_target_length = self.max_target_length
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+     # information sent is the one passed as arguments along with your Python/PyTorch versions.
+     send_example_telemetry("run_translation", model_args, data_args)
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     datasets.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     # Log on each process the small summary:
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     if data_args.source_prefix is None and model_args.model_name_or_path in [
+         "t5-small",
+         "t5-base",
+         "t5-large",
+         "t5-3b",
+         "t5-11b",
+     ]:
+         logger.warning(
+             "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
+             "`--source_prefix 'translate English to German: ' `"
+         )
+
+     # Detecting last checkpoint.
+     last_checkpoint = None
+     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+             raise ValueError(
+                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                 "Use --overwrite_output_dir to overcome."
+             )
+         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+             logger.info(
+                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+             )
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Get the datasets: you can either provide your own JSON training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
+     # source and target languages (unless you adapt what follows).
+     #
+     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+     # download the dataset.
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         raw_datasets = load_dataset(
+             data_args.dataset_name,
+             data_args.dataset_config_name,
+             cache_dir=model_args.cache_dir,
+             use_auth_token=True if model_args.use_auth_token else None,
+         )
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+             extension = data_args.train_file.split(".")[-1]
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+             extension = data_args.validation_file.split(".")[-1]
+         if data_args.test_file is not None:
+             data_files["test"] = data_args.test_file
+             extension = data_args.test_file.split(".")[-1]
+         raw_datasets = load_dataset(
+             extension,
+             data_files=data_files,
+             cache_dir=model_args.cache_dir,
+             use_auth_token=True if model_args.use_auth_token else None,
+         )
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained model and tokenizer
+     #
+     # Distributed training:
+     # The .from_pretrained methods guarantee that only one local process can concurrently
+     # download model & vocab.
+     config = AutoConfig.from_pretrained(
+         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         use_fast=model_args.use_fast_tokenizer,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     model = AutoModelForSeq2SeqLM.from_pretrained(
+         model_args.model_name_or_path,
+         from_tf=bool(".ckpt" in model_args.model_name_or_path),
+         config=config,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+     # on a small vocab and want a smaller embedding size, remove this test.
+     embedding_size = model.get_input_embeddings().weight.shape[0]
+     if len(tokenizer) > embedding_size:
+         model.resize_token_embeddings(len(tokenizer))
+
+     # Set decoder_start_token_id
+     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
+         if isinstance(tokenizer, MBartTokenizer):
+             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
+         else:
+             model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
+
+     if model.config.decoder_start_token_id is None:
+         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if training_args.do_train:
+         column_names = raw_datasets["train"].column_names
+     elif training_args.do_eval:
+         column_names = raw_datasets["validation"].column_names
+     elif training_args.do_predict:
+         column_names = raw_datasets["test"].column_names
+     else:
+         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+         return
+
+     # For translation we set the codes of our source and target languages (only useful for mBART, the others will
+     # ignore those attributes).
+     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+         assert data_args.target_lang is not None and data_args.source_lang is not None, (
+             f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
+             "--target_lang arguments."
+         )
+
+         tokenizer.src_lang = data_args.source_lang
+         tokenizer.tgt_lang = data_args.target_lang
+
+         # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+         # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
+         forced_bos_token_id = (
+             tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+         )
+         model.config.forced_bos_token_id = forced_bos_token_id
+
+     # Get the language codes for input/target.
+     source_lang = data_args.source_lang.split("_")[0]
+     target_lang = data_args.target_lang.split("_")[0]
+
+     # Temporarily set max_target_length for training.
+     max_target_length = data_args.max_target_length
+     padding = "max_length" if data_args.pad_to_max_length else False
+
+     if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
+         logger.warning(
+             "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
+             f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
+         )
+
+     def preprocess_function(examples):
+         inputs = [ex[source_lang] for ex in examples["translation"]]
+         targets = [ex[target_lang] for ex in examples["translation"]]
+         inputs = [prefix + inp for inp in inputs]
+         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
+
+         # Tokenize targets with the `text_target` keyword argument
+         labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
+
+         # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+         # padding in the loss.
+         if padding == "max_length" and data_args.ignore_pad_token_for_loss:
+             labels["input_ids"] = [
+                 [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+             ]
+
+         model_inputs["labels"] = labels["input_ids"]
+         return model_inputs
+
+     if training_args.do_train:
+         if "train" not in raw_datasets:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = raw_datasets["train"]
+         if data_args.max_train_samples is not None:
+             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+             train_dataset = train_dataset.select(range(max_train_samples))
+         with training_args.main_process_first(desc="train dataset map pre-processing"):
+             train_dataset = train_dataset.map(
+                 preprocess_function,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on train dataset",
+             )
+
+     if training_args.do_eval:
+         max_target_length = data_args.val_max_target_length
+         if "validation" not in raw_datasets:
+             raise ValueError("--do_eval requires a validation dataset")
+         eval_dataset = raw_datasets["validation"]
+         if data_args.max_eval_samples is not None:
+             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+             eval_dataset = eval_dataset.select(range(max_eval_samples))
+         with training_args.main_process_first(desc="validation dataset map pre-processing"):
+             eval_dataset = eval_dataset.map(
+                 preprocess_function,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on validation dataset",
+             )
+
+     if training_args.do_predict:
+         max_target_length = data_args.val_max_target_length
+         if "test" not in raw_datasets:
+             raise ValueError("--do_predict requires a test dataset")
+         predict_dataset = raw_datasets["test"]
+         if data_args.max_predict_samples is not None:
+             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+             predict_dataset = predict_dataset.select(range(max_predict_samples))
+         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+             predict_dataset = predict_dataset.map(
+                 preprocess_function,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on prediction dataset",
+             )
+
+     # Data collator
+     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+     if data_args.pad_to_max_length:
+         data_collator = default_data_collator
+     else:
+         data_collator = DataCollatorForSeq2Seq(
+             tokenizer,
+             model=model,
+             label_pad_token_id=label_pad_token_id,
+             pad_to_multiple_of=8 if training_args.fp16 else None,
+         )
+
+     # Metric
+     metric = evaluate.load("sacrebleu")
+
+     def postprocess_text(preds, labels):
+         preds = [pred.strip() for pred in preds]
+         labels = [[label.strip()] for label in labels]
+
+         return preds, labels
+
+     def compute_metrics(eval_preds):
+         preds, labels = eval_preds
+         if isinstance(preds, tuple):
+             preds = preds[0]
+         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+         if data_args.ignore_pad_token_for_loss:
+             # Replace -100 in the labels as we can't decode them.
+             labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         # Some simple post-processing
+         decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+
+         result = metric.compute(predictions=decoded_preds, references=decoded_labels)
+         result = {"bleu": result["score"]}
+
+         prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+         result["gen_len"] = np.mean(prediction_lens)
+         result = {k: round(v, 4) for k, v in result.items()}
+         return result
+
+     # Initialize our Trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset if training_args.do_train else None,
+         eval_dataset=eval_dataset if training_args.do_eval else None,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+     )
+
+     # Training
+     if training_args.do_train:
+         checkpoint = None
+         if training_args.resume_from_checkpoint is not None:
+             checkpoint = training_args.resume_from_checkpoint
+         elif last_checkpoint is not None:
+             checkpoint = last_checkpoint
+         train_result = trainer.train(resume_from_checkpoint=checkpoint)
+         trainer.save_model()  # Saves the tokenizer too for easy upload
+
+         metrics = train_result.metrics
+         max_train_samples = (
+             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+         )
+         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+         trainer.log_metrics("train", metrics)
+         trainer.save_metrics("train", metrics)
+         trainer.save_state()
+
+     # Evaluation
+     results = {}
+     max_length = (
+         training_args.generation_max_length
+         if training_args.generation_max_length is not None
+         else data_args.val_max_target_length
+     )
+     num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+     if training_args.do_eval:
+         logger.info("*** Evaluate ***")
+
+         metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
+         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+     if training_args.do_predict:
+         logger.info("*** Predict ***")
+
+         predict_results = trainer.predict(
+             predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
+         )
+         metrics = predict_results.metrics
+         max_predict_samples = (
+             data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+         )
+         metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
+
+         trainer.log_metrics("predict", metrics)
+         trainer.save_metrics("predict", metrics)
+
+         if trainer.is_world_process_zero():
+             if training_args.predict_with_generate:
+                 predictions = tokenizer.batch_decode(
+                     predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                 )
+                 predictions = [pred.strip() for pred in predictions]
+                 output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
+                 with open(output_prediction_file, "w", encoding="utf-8") as writer:
+                     writer.write("\n".join(predictions))
+
+     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"}
+     if data_args.dataset_name is not None:
+         kwargs["dataset_tags"] = data_args.dataset_name
+         if data_args.dataset_config_name is not None:
+             kwargs["dataset_args"] = data_args.dataset_config_name
+             kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+         else:
+             kwargs["dataset"] = data_args.dataset_name
+
+     languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
+     if len(languages) > 0:
+         kwargs["language"] = languages
+
+     if training_args.push_to_hub:
+         trainer.push_to_hub(**kwargs)
+     else:
+         trainer.create_model_card(**kwargs)
+
+     return results
+
+
+ def _mp_fn(index):
+     # For xla_spawn (TPUs)
+     main()
+
+
+ if __name__ == "__main__":
+     main()
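As the comments in the script note, custom data must be JSON lines with a single "translation" field holding one entry per language code, and the script accepts either command-line flags or a single .json argument file (the parse_json_file branch at the top of main()). Below is a minimal end-to-end sketch under those assumptions; the t5-small checkpoint, the en/ro language pair and every file name here are illustrative choices, not something this commit prescribes.

```python
# Hypothetical quick-start: write a tiny en->ro JSON-lines dataset in the format the
# script expects, write a single .json argument file, then launch run_translation.py
# (assumed to sit in the current directory) through its parse_json_file code path.
import json
import subprocess
import sys

samples = [
    {"translation": {"en": "The weather is nice today.", "ro": "Vremea este frumoasă astăzi."}},
    {"translation": {"en": "I like to read books.", "ro": "Îmi place să citesc cărți."}},
]
for split in ("train", "validation"):
    with open(f"{split}.json", "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

args = {
    "model_name_or_path": "t5-small",
    "source_lang": "en",
    "target_lang": "ro",
    "source_prefix": "translate English to Romanian: ",
    "train_file": "train.json",
    "validation_file": "validation.json",
    "output_dir": "tst-translation",
    "overwrite_output_dir": True,
    "do_train": True,
    "do_eval": True,
    "predict_with_generate": True,
    "per_device_train_batch_size": 4,
}
with open("translation_args.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)

# A single .json argument makes HfArgumentParser call parse_json_file(...) instead of
# parsing command-line flags; every key must match a field of ModelArguments,
# DataTrainingArguments or Seq2SeqTrainingArguments.
subprocess.check_call([sys.executable, "run_translation.py", "translation_args.json"])
```

The same keys can equally be passed as `--flag` arguments on the command line, since HfArgumentParser exposes every field of ModelArguments, DataTrainingArguments and Seq2SeqTrainingArguments as a CLI option.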