Joshua Lochner committed
Commit 490a61c
1 Parent(s): e77b67b

Merge duplicated training dataclasses

Files changed (4)
  1. src/preprocess.py +0 -42
  2. src/shared.py +101 -1
  3. src/train.py +15 -82
  4. src/train_classifier.py +44 -152
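
In short: the train/validation/test file fields, cache options, and max-sample limits that were previously declared separately in src/train.py, src/train_classifier.py, and PreprocessingDatasetArguments now live once in shared.py, and each script subclasses or reuses them. A minimal sketch of the resulting pattern (field names and defaults are taken from the diffs below; everything else about the project is assumed):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DatasetArguments:
    # Shared fields are declared once in src/shared.py ...
    train_file: Optional[str] = field(
        default='train.json',
        metadata={'help': 'The input training data file (a jsonlines file).'})
    overwrite_cache: bool = field(
        default=False,
        metadata={'help': 'Overwrite the cached training and evaluation sets'})


@dataclass
class ClassifierDatasetArguments(DatasetArguments):
    # ... and each script only overrides the defaults it needs.
    train_file: Optional[str] = field(
        default='c_train.json',
        metadata={'help': 'The input training data file (a jsonlines file).'})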
src/preprocess.py CHANGED
@@ -490,54 +490,12 @@ def download_file(url, filename):
 
 @dataclass
 class PreprocessingDatasetArguments(DatasetArguments):
-
-    train_file: Optional[str] = field(
-        default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
-    )
-    validation_file: Optional[str] = field(
-        default='valid.json',
-        metadata={
-            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
-    test_file: Optional[str] = field(
-        default='test.json',
-        metadata={
-            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
-
-    c_train_file: Optional[str] = field(
-        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
-    )
-    c_validation_file: Optional[str] = field(
-        default='c_valid.json',
-        metadata={
-            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
-    c_test_file: Optional[str] = field(
-        default='c_test.json',
-        metadata={
-            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
-
     # excess_file: Optional[str] = field(
     #     default='excess.json',
     #     metadata={
    #         'help': 'The excess segments left after the split'
     #     },
     # )
-    dataset_cache_dir: Optional[str] = field(
-        default=None,
-        metadata={
-            'help': 'Where to store the cached datasets'
-        },
-    )
-    overwrite_cache: bool = field(
-        default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
-    )
 
     positive_file: Optional[str] = field(
         default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
src/shared.py CHANGED
@@ -104,6 +104,10 @@ class DatasetArguments:
         },
     )
 
+    overwrite_cache: bool = field(
+        default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
+    )
+
     dataset_cache_dir: Optional[str] = field(
         default=None,
         metadata={
@@ -111,6 +115,35 @@ class DatasetArguments:
         },
     )
 
+    train_file: Optional[str] = field(
+        default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
+    )
+    validation_file: Optional[str] = field(
+        default='valid.json',
+        metadata={
+            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+    test_file: Optional[str] = field(
+        default='test.json',
+        metadata={
+            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+
+    def __post_init__(self):
+        if self.train_file is None or self.validation_file is None:
+            raise ValueError(
+                "Need either a GLUE task, a training/validation file or a dataset name.")
+        else:
+            train_extension = self.train_file.split(".")[-1]
+            assert train_extension in [
+                "csv", "json"], "`train_file` should be a csv or a json file."
+            validation_extension = self.validation_file.split(".")[-1]
+            assert (
+                validation_extension == train_extension
+            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
+
 
 @dataclass
 class OutputArguments:
@@ -178,7 +211,7 @@ def reset():
     print(torch.cuda.memory_summary(device=None, abbreviated=False))
 
 
-def load_datasets(dataset_args):
+def load_datasets(dataset_args: DatasetArguments):
 
     print('Reading datasets')
     data_files = {}
@@ -240,6 +273,39 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
     # * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
     # * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
 
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={'help': 'The number of processes to use for the preprocessing.'},
+    )
+    max_seq_length: int = field(
+        default=512,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
+
 
 logging.basicConfig()
 logger = logging.getLogger(__name__)
@@ -279,3 +345,37 @@ def train_from_checkpoint(trainer, last_checkpoint, training_args):
     trainer.save_model()  # Saves the tokenizer too for easy upload
 
     return train_result
+
+
+def prepare_datasets(raw_datasets, dataset_args: DatasetArguments, training_args: CustomTrainingArguments, preprocess_function):
+
+    with training_args.main_process_first(desc="dataset map pre-processing"):
+        raw_datasets = raw_datasets.map(
+            preprocess_function,
+            batched=True,
+            load_from_cache_file=not dataset_args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )
+
+    if 'train' not in raw_datasets:
+        raise ValueError('Train dataset missing')
+    train_dataset = raw_datasets['train']
+    if training_args.max_train_samples is not None:
+        train_dataset = train_dataset.select(
+            range(training_args.max_train_samples))
+
+    if 'validation' not in raw_datasets:
+        raise ValueError('Validation dataset missing')
+    eval_dataset = raw_datasets['validation']
+    if training_args.max_eval_samples is not None:
+        eval_dataset = eval_dataset.select(
+            range(training_args.max_eval_samples))
+
+    if 'test' not in raw_datasets:
+        raise ValueError('Test dataset missing')
+    predict_dataset = raw_datasets['test']
+    if training_args.max_predict_samples is not None:
+        predict_dataset = predict_dataset.select(
+            range(training_args.max_predict_samples))
+
+    return train_dataset, eval_dataset, predict_dataset
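
Taken together, load_datasets and the new prepare_datasets helper give both training scripts the same pipeline. A sketch of the intended call sequence, assuming src/ is on the import path; the identity preprocess_function is a stand-in for each script's own tokenization:

from transformers import HfArgumentParser

from shared import (CustomTrainingArguments, DatasetArguments,
                    load_datasets, prepare_datasets)

# HfArgumentParser turns each dataclass field into a --flag of the same name,
# e.g. --train_file, --overwrite_cache, --max_train_samples.
dataset_args, training_args = HfArgumentParser(
    (DatasetArguments, CustomTrainingArguments)).parse_args_into_dataclasses()


def preprocess_function(examples):
    # Stand-in: each script supplies its own tokenization here. prepare_datasets
    # applies it via datasets.map(batched=True) inside main_process_first().
    return examples


raw_datasets = load_datasets(dataset_args)  # DatasetDict with train/validation/test splits
train_dataset, eval_dataset, predict_dataset = prepare_datasets(
    raw_datasets, dataset_args, training_args, preprocess_function)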
src/train.py CHANGED
@@ -1,12 +1,17 @@
 from preprocess import PreprocessingDatasetArguments
-from shared import CustomTokens, load_datasets, CustomTrainingArguments, get_last_checkpoint, train_from_checkpoint
+from shared import (
+    CustomTokens,
+    prepare_datasets,
+    load_datasets,
+    CustomTrainingArguments,
+    get_last_checkpoint,
+    train_from_checkpoint
+)
 from model import ModelArguments
 import transformers
 import logging
 import os
 import sys
-from dataclasses import dataclass, field
-from typing import Optional
 from datasets import utils as d_utils
 from transformers import (
     DataCollatorForSeq2Seq,
@@ -35,38 +40,6 @@ logging.basicConfig(
 )
 
 
-
-@dataclass
-class DataTrainingArguments:
-    """
-    Arguments pertaining to what data we are going to input our model for training and eval.
-    """
-
-    preprocessing_num_workers: Optional[int] = field(
-        default=None,
-        metadata={'help': 'The number of processes to use for the preprocessing.'},
-    )
-
-    max_train_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'For debugging purposes or quicker training, truncate the number of training examples to this value if set.'
-        },
-    )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set.'
-        },
-    )
-    max_predict_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'For debugging purposes or quicker training, truncate the number of prediction examples to this value if set.'
-        },
-    )
-
-
 def main():
 
     # See all possible arguments in src/transformers/training_args.py
@@ -76,10 +49,9 @@ def main():
     hf_parser = HfArgumentParser((
         ModelArguments,
         PreprocessingDatasetArguments,
-        DataTrainingArguments,
         CustomTrainingArguments
     ))
-    model_args, dataset_args, data_training_args, training_args = hf_parser.parse_args_into_dataclasses()
+    model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
 
     log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
@@ -128,7 +100,6 @@ def main():
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
-
     # Detecting last checkpoint.
     last_checkpoint = get_last_checkpoint(training_args)
 
@@ -165,47 +136,8 @@ def main():
 
         return model_inputs
 
-    def prepare_dataset(dataset, desc):
-        return dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_training_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not dataset_args.overwrite_cache,
-            desc=desc,  # tokenizing train dataset
-        )
-        # train_dataset # TODO shuffle?
-
-    # if training_args.do_train:
-    if 'train' not in raw_datasets:  # TODO do checks above?
-        raise ValueError('Train dataset missing')
-    train_dataset = raw_datasets['train']
-    if data_training_args.max_train_samples is not None:
-        train_dataset = train_dataset.select(
-            range(data_training_args.max_train_samples))
-    with training_args.main_process_first(desc='train dataset map pre-processing'):
-        train_dataset = prepare_dataset(
-            train_dataset, desc='Running tokenizer on train dataset')
-
-    if 'validation' not in raw_datasets:
-        raise ValueError('Validation dataset missing')
-    eval_dataset = raw_datasets['validation']
-    if data_training_args.max_eval_samples is not None:
-        eval_dataset = eval_dataset.select(
-            range(data_training_args.max_eval_samples))
-    with training_args.main_process_first(desc='validation dataset map pre-processing'):
-        eval_dataset = prepare_dataset(
-            eval_dataset, desc='Running tokenizer on validation dataset')
-
-    if 'test' not in raw_datasets:
-        raise ValueError('Test dataset missing')
-    predict_dataset = raw_datasets['test']
-    if data_training_args.max_predict_samples is not None:
-        predict_dataset = predict_dataset.select(
-            range(data_training_args.max_predict_samples))
-    with training_args.main_process_first(desc='prediction dataset map pre-processing'):
-        predict_dataset = prepare_dataset(
-            predict_dataset, desc='Running tokenizer on prediction dataset')
+    train_dataset, eval_dataset, predict_dataset = prepare_datasets(
+        raw_datasets, dataset_args, training_args, preprocess_function)
 
     # Data collator
     data_collator = DataCollatorForSeq2Seq(
@@ -228,10 +160,11 @@ def main():
     )
 
     # Training
-    train_result = train_from_checkpoint(trainer, last_checkpoint, training_args)
+    train_result = train_from_checkpoint(
+        trainer, last_checkpoint, training_args)
 
     metrics = train_result.metrics
-    max_train_samples = data_training_args.max_train_samples or len(
+    max_train_samples = training_args.max_train_samples or len(
         train_dataset)
     metrics['train_samples'] = min(max_train_samples, len(train_dataset))
 
@@ -240,7 +173,7 @@ def main():
     trainer.save_state()
 
     kwargs = {'finetuned_from': model_args.model_name_or_path,
-              'tasks': 'summarization'}
+              'tasks': 'summarization'}
 
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
src/train_classifier.py CHANGED
@@ -3,14 +3,12 @@
 
 import logging
 import os
-import random
 import sys
 from dataclasses import dataclass, field
 from typing import Optional
 
 import datasets
 import numpy as np
-from datasets import load_metric
 
 import transformers
 from transformers import (
@@ -18,96 +16,39 @@ from transformers import (
     EvalPrediction,
     HfArgumentParser,
     Trainer,
-    default_data_collator,
     set_seed,
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from shared import CATEGORIES, load_datasets, CustomTrainingArguments, train_from_checkpoint, get_last_checkpoint
-from preprocess import PreprocessingDatasetArguments
+from shared import CATEGORIES, DatasetArguments, prepare_datasets, load_datasets, CustomTrainingArguments, train_from_checkpoint, get_last_checkpoint
 from model import get_model_tokenizer, ModelArguments
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.17.0")
-require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
+check_min_version('4.17.0')
+require_version('datasets>=1.8.0', 'To fix: pip install -r requirements.txt')
 
-os.environ["WANDB_DISABLED"] = "true"
+os.environ['WANDB_DISABLED'] = 'true'
 
 logger = logging.getLogger(__name__)
 
 
 @dataclass
-class DataArguments:
-    """
-    Arguments pertaining to what data we are going to input our model for training and eval.
-
-    Using `HfArgumentParser` we can turn this class
-    into argparse arguments to be able to specify them on
-    the command line.
-    """
-
-    max_seq_length: int = field(
-        default=512,
-        metadata={
-            "help": "The maximum total input sequence length after tokenization. Sequences longer "
-            "than this will be truncated, sequences shorter will be padded."
-        },
-    )
-    overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
-    )
-    pad_to_max_length: bool = field(
-        default=True,
-        metadata={
-            "help": "Whether to pad all samples to `max_seq_length`. "
-            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
-        },
-    )
-    max_train_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
-        },
-    )
-    max_eval_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-            "value if set."
-        },
-    )
-    max_predict_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
-            "value if set."
-        },
-    )
-
-    dataset_cache_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
-        'dataset_cache_dir']
-    data_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
-        'data_dir']
-    train_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
-        'c_train_file']
-    validation_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
-        'c_validation_file']
-    test_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
-        'c_test_file']
-
-    def __post_init__(self):
-        if self.train_file is None or self.validation_file is None:
-            raise ValueError(
-                "Need either a GLUE task, a training/validation file or a dataset name.")
-        else:
-            train_extension = self.train_file.split(".")[-1]
-            assert train_extension in [
-                "csv", "json"], "`train_file` should be a csv or a json file."
-            validation_extension = self.validation_file.split(".")[-1]
-            assert (
-                validation_extension == train_extension
-            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
+class ClassifierDatasetArguments(DatasetArguments):
+    train_file: Optional[str] = field(
+        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
+    )
+    validation_file: Optional[str] = field(
+        default='c_valid.json',
+        metadata={
+            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+    test_file: Optional[str] = field(
+        default='c_test.json',
+        metadata={
+            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
 
 
 def main():
@@ -115,14 +56,17 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser(
-        (ModelArguments, DataArguments, CustomTrainingArguments))
-    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    hf_parser = HfArgumentParser((
+        ModelArguments,
+        ClassifierDatasetArguments,
+        CustomTrainingArguments
+    ))
+    model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
 
     # Setup logging
     logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
+        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+        datefmt='%m/%d/%Y %H:%M:%S',
         handlers=[logging.StreamHandler(sys.stdout)],
     )
 
@@ -135,10 +79,10 @@ def main():
 
     # Log on each process the small summary:
     logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
+        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
     )
-    logger.info(f"Training/evaluation parameters {training_args}")
+    logger.info(f'Training/evaluation parameters {training_args}')
 
     # Detecting last checkpoint.
     last_checkpoint = get_last_checkpoint(training_args)
@@ -148,7 +92,7 @@ def main():
 
     # Loading a dataset from your local files.
     # CSV/JSON training and evaluation files are needed.
-    raw_datasets = load_datasets(data_args)
+    raw_datasets = load_datasets(dataset_args)
 
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -158,69 +102,26 @@ def main():
         'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
         'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
     }
-    model, tokenizer = get_model_tokenizer(model_args, training_args, config_args=config_args, model_type='classifier')
+    model, tokenizer = get_model_tokenizer(
+        model_args, training_args, config_args=config_args, model_type='classifier')
 
-
-    # Padding strategy
-    if data_args.pad_to_max_length:
-        padding = "max_length"
-    else:
-        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
-        padding = False
-
-    if data_args.max_seq_length > tokenizer.model_max_length:
+    if training_args.max_seq_length > tokenizer.model_max_length:
         logger.warning(
-            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            f'The max_seq_length passed ({training_args.max_seq_length}) is larger than the maximum length for the'
+            f'model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.'
         )
-    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+    max_seq_length = min(training_args.max_seq_length,
+                         tokenizer.model_max_length)
 
     def preprocess_function(examples):
         # Tokenize the texts
         result = tokenizer(
-            examples['text'], padding=padding, max_length=max_seq_length, truncation=True)
+            examples['text'], padding='max_length', max_length=max_seq_length, truncation=True)
         result['label'] = examples['label']
         return result
 
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        raw_datasets = raw_datasets.map(
-            preprocess_function,
-            batched=True,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset",
-        )
-    if training_args.do_train:
-        if "train" not in raw_datasets:
-            raise ValueError("--do_train requires a train dataset")
-        train_dataset = raw_datasets["train"]
-        if data_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(
-                range(data_args.max_train_samples))
-
-    if training_args.do_eval:
-        if "validation" not in raw_datasets:
-            raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = raw_datasets["validation"]
-        if data_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(
-                range(data_args.max_eval_samples))
-
-    if training_args.do_predict or data_args.test_file is not None:
-        if "test" not in raw_datasets:
-            raise ValueError("--do_predict requires a test dataset")
-        predict_dataset = raw_datasets["test"]
-        if data_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(
-                range(data_args.max_predict_samples))
-
-    # Log a few random samples from the training set:
-    if training_args.do_train:
-        for index in random.sample(range(len(train_dataset)), 3):
-            logger.info(
-                f"Sample {index} of the training set: {train_dataset[index]}.")
-
-    # Get the metric function
-    metric = load_metric("accuracy")
+    train_dataset, eval_dataset, predict_dataset = prepare_datasets(
+        raw_datasets, dataset_args, training_args, preprocess_function)
 
     # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
     # predictions and label_ids field) and has to return a dictionary string to float.
@@ -228,20 +129,11 @@ def main():
         preds = p.predictions[0] if isinstance(
             p.predictions, tuple) else p.predictions
         preds = np.argmax(preds, axis=1)
-        if data_args.task_name is not None:
-            result = metric.compute(predictions=preds, references=p.label_ids)
-            if len(result) > 1:
-                result["combined_score"] = np.mean(
-                    list(result.values())).item()
-            return result
-        else:
-            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+        return {'accuracy': (preds == p.label_ids).astype(np.float32).mean().item()}
 
     # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
     # we already did the padding.
-    if data_args.pad_to_max_length:
-        data_collator = default_data_collator
-    elif training_args.fp16:
+    if training_args.fp16:
         data_collator = DataCollatorWithPadding(
             tokenizer, pad_to_multiple_of=8)
     else:
@@ -264,24 +156,24 @@ def main():
 
     metrics = train_result.metrics
     max_train_samples = (
-        data_args.max_train_samples if data_args.max_train_samples is not None else len(
+        training_args.max_train_samples if training_args.max_train_samples is not None else len(
             train_dataset)
     )
-    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+    metrics['train_samples'] = min(max_train_samples, len(train_dataset))
 
     trainer.save_model()  # Saves the tokenizer too for easy upload
 
-    trainer.log_metrics("train", metrics)
-    trainer.save_metrics("train", metrics)
+    trainer.log_metrics('train', metrics)
+    trainer.save_metrics('train', metrics)
     trainer.save_state()
 
-    kwargs = {"finetuned_from": model_args.model_name_or_path,
-              "tasks": "text-classification"}
+    kwargs = {'finetuned_from': model_args.model_name_or_path,
+              'tasks': 'text-classification'}
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else:
         trainer.create_model_card(**kwargs)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     main()
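
A quick way to check that the merged dataclasses still parse as before is a small smoke test (hypothetical, not part of the commit; it assumes src/ is importable and that the other fields of the shared base class keep their defaults):

from transformers import HfArgumentParser

from train_classifier import ClassifierDatasetArguments

# Defaults come from the subclass; __post_init__ from the shared base class
# still validates that the file extensions match.
(dataset_args,) = HfArgumentParser(ClassifierDatasetArguments).parse_args_into_dataclasses(
    args=['--train_file', 'c_train.json', '--validation_file', 'c_valid.json'])
assert dataset_args.test_file == 'c_test.json'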