sdyy committed
Commit 8265fd0 · verified · 1 Parent(s): bfa5d34

Upload run_translation.py

Files changed (1)
  1. run_translation.py +696 -0
run_translation.py ADDED
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
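#
# An illustrative invocation (example model, dataset, and paths only; adjust them to your own setup):
#
#   python run_translation.py \
#       --model_name_or_path Helsinki-NLP/opus-mt-en-ro \
#       --do_train --do_eval \
#       --source_lang en --target_lang ro \
#       --dataset_name wmt16 --dataset_config_name ro-en \
#       --output_dir /tmp/tst-translation \
#       --per_device_train_batch_size 4 --per_device_eval_batch_size 4 \
#       --overwrite_output_dir --predict_with_generate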

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import evaluate
import numpy as np
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.42.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

logger = logging.getLogger(__name__)

# A list of all multilingual tokenizers that require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
    target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the :obj:`decoder_start_token_id`. Useful for "
                "multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to "
                "be the target language token (usually it is the target language token)."
            )
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        elif self.source_lang is None or self.target_lang is None:
            raise ValueError("Need to specify the source language and the target language.")

        # accepting both json and jsonl file extensions, as
        # many jsonlines files actually have a .json extension
        valid_extensions = ["json", "jsonl"]

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in valid_extensions, "`train_file` should be a jsonlines file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in valid_extensions, "`validation_file` should be a jsonlines file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_translation", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "google-t5/t5-small",
        "google-t5/t5-base",
        "google-t5/t5-large",
        "google-t5/t5-3b",
        "google-t5/t5-11b",
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
            "`--source_prefix 'translate English to German: ' `"
        )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
    # source and target languages (unless you adapt what follows).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
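    # An illustrative jsonlines record for an English-German pair (the inner keys are the source and target
    # language codes, see --source_lang/--target_lang):
    #   {"translation": {"en": "Hello world.", "de": "Hallo Welt."}}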
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        if extension == "jsonl":
            builder_name = "json"  # the "json" builder reads both .json and .jsonl files
        else:
            builder_name = extension  # e.g. "parquet"
        raw_datasets = load_dataset(
            builder_name,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Set decoder_start_token_id
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # For translation we set the codes of our source and target languages (only useful for mBART, the others will
    # ignore those attributes).
    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        assert data_args.target_lang is not None and data_args.source_lang is not None, (
            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
            "--target_lang arguments."
        )

        tokenizer.src_lang = data_args.source_lang
        tokenizer.tgt_lang = data_args.target_lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
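        # For example, when targeting Romanian with an mBART-50 checkpoint one would typically pass
        # `--forced_bos_token ro_RO` (an illustrative value; use the language code defined by the tokenizer in use).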
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

    # Get the language codes for input/target.
    source_lang = data_args.source_lang.split("_")[0]
    target_lang = data_args.target_lang.split("_")[0]

    # Check whether the source/target length fits in the model, if it has absolute positional embeddings
    if (
        hasattr(model.config, "max_position_embeddings")
        and not hasattr(model.config, "relative_attention_max_distance")
        and model.config.max_position_embeddings < data_args.max_source_length
    ):
        raise ValueError(
            f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
            f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
            f" `--max_source_length` to {model.config.max_position_embeddings} or using a model with larger position "
            "embeddings"
        )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_function(examples):
        inputs = [ex[source_lang] for ex in examples["translation"]]
        targets = [ex[target_lang] for ex in examples["translation"]]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Tokenize targets with the `text_target` keyword argument
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
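        # (-100 is the default ignore_index of PyTorch's cross-entropy loss, so those positions are skipped.)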
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Metric
    metric = evaluate.load("sacrebleu", cache_dir=model_args.cache_dir)

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace -100s used for padding as we can't decode them
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
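        # sacrebleu scores each prediction against one or more references, which is why postprocess_text wraps
        # every reference string in a list.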

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}

        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        predict_results = trainer.predict(
            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
        )
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = predict_results.predictions
                predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
                predictions = tokenizer.batch_decode(
                    predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    writer.write("\n".join(predictions))

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
    if len(languages) > 0:
        kwargs["language"] = languages

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()