cnut1648 committed
Commit c635b38
1 Parent(s): ce7510e

Create mednli.py

Files changed (1)
  1. mednli.py (+602, −0)
mednli.py ADDED
@@ -0,0 +1,602 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning the library models for sequence classification on MedNLI (adapted from the GLUE example)."""
# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset

import evaluate
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.22.2")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

logger = logging.getLogger(__name__)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys.keys():
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in ["csv", "json", "jsonl"], "`train_file` should be a csv, json or jsonl file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv, json or jsonl) as `train_file`."


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
    # named 'label' if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                assert (
                    test_extension == train_extension
                ), "`test_file` should have the same extension (csv or json) as `train_file`."
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key in data_files.keys():
            logger.info(f"load a local file for {key}: {data_files[key]}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset(
                "csv",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset(
                "json",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        # MedNLI stores its labels in a "gold_label" column; rename it to the "label" column this script expects.
        raw_datasets = raw_datasets.rename_column("gold_label", "label")
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()
            assert label_list == ["contradiction", "entailment", "neutral"]
            # need 0 for entailment
            label_list = ["entailment", "neutral", "contradiction"]
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in config.label2id.items()}
    elif data_args.task_name is not None and not is_regression:
        model.config.label2id = {l: i for i, l in enumerate(label_list)}
        model.config.id2label = {id: label for label, id in config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks)
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        else:
            train_dataset = raw_datasets["train"]
            if data_args.max_train_samples is not None:
                max_train_samples = min(len(train_dataset), data_args.max_train_samples)
                train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = evaluate.load("glue", data_args.task_name)
    else:
        metric = evaluate.load("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it
    # if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]

        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]

        for predict_dataset, task in zip(predict_datasets, tasks):
            # The MedNLI test set carries labels, so we evaluate on it rather than merely predicting.
            metrics = trainer.evaluate(eval_dataset=predict_dataset)

            max_predict_samples = (
                data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
            )
            metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

            trainer.log_metrics("test", metrics)
            trainer.save_metrics("test", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
    if data_args.task_name is not None:
        kwargs["language"] = "en"
        kwargs["dataset_tags"] = "glue"
        kwargs["dataset_args"] = data_args.task_name
        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
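For reference, a minimal sketch of how this script might be driven on local MedNLI files. The file names, model identifier, and output directory below are illustrative assumptions, not part of the commit; the script only requires JSONL records with `sentence1`, `sentence2`, and `gold_label` fields, which it renames to `label` and maps so that entailment becomes id 0.

# Hypothetical driver; paths, model, and output_dir are assumptions for illustration.
import sys

from mednli import main  # assumes the committed file is importable as mednli.py

# One MedNLI-style JSONL record the loader expects (gold_label is renamed to "label"):
# {"sentence1": "...", "sentence2": "...", "gold_label": "contradiction"}

sys.argv = [
    "mednli.py",
    "--model_name_or_path", "bert-base-uncased",
    "--train_file", "mli_train_v1.jsonl",       # assumed local MedNLI splits
    "--validation_file", "mli_dev_v1.jsonl",
    "--test_file", "mli_test_v1.jsonl",
    "--do_train",
    "--do_eval",
    "--do_predict",
    "--max_seq_length", "128",
    "--output_dir", "./mednli_out",
]
main()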