Jerome2046 committed
Commit d64db71
1 Parent(s): 3dbb033

Create main.py

Files changed (1)
  1. main.py +411 -0
main.py ADDED
@@ -0,0 +1,411 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for sequence to sequence.
+ """
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+
+ import logging
+ import os
+ import sys
+ import json
+
+ import numpy as np
+ from datasets import load_dataset
+ import jieba
+ from rouge_chinese import Rouge
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+ import torch
+
+ import transformers
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+     DataCollatorForSeq2Seq,
+     HfArgumentParser,
+     Seq2SeqTrainingArguments,
+     set_seed,
+ )
+ from trainer_seq2seq import Seq2SeqTrainer
+
+ from arguments import ModelArguments, DataTrainingArguments
+
+ logger = logging.getLogger(__name__)
+
+ def main():
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+
+     if training_args.should_log:
+         # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+         transformers.utils.logging.set_verbosity_info()
+
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     # datasets.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     # Log on each process the small summary:
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Load dataset
+     data_files = {}
+     if data_args.train_file is not None:
+         data_files["train"] = data_args.train_file
+         extension = data_args.train_file.split(".")[-1]
+     if data_args.validation_file is not None:
+         data_files["validation"] = data_args.validation_file
+         extension = data_args.validation_file.split(".")[-1]
+     if data_args.test_file is not None:
+         data_files["test"] = data_args.test_file
+         extension = data_args.test_file.split(".")[-1]
+
+     raw_datasets = load_dataset(
+         extension,
+         data_files=data_files,
+         cache_dir=model_args.cache_dir,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     # Load pretrained model and tokenizer
+     config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+     config.pre_seq_len = model_args.pre_seq_len
+     config.prefix_projection = model_args.prefix_projection
+
+     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+
+     if model_args.ptuning_checkpoint is not None:
+         # Evaluation
+         # Loading extra state dict of prefix encoder
+         model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+         prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
+         new_prefix_state_dict = {}
+         for k, v in prefix_state_dict.items():
+             if k.startswith("transformer.prefix_encoder."):
+                 new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+         model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+     else:
+         model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+
+     if model_args.quantization_bit is not None:
+         print(f"Quantized to {model_args.quantization_bit} bit")
+         model = model.quantize(model_args.quantization_bit)
+     if model_args.pre_seq_len is not None:
+         # P-tuning v2
+         model = model.half()
+         model.transformer.prefix_encoder.float()
+     else:
+         # Finetune
+         model = model.float()
+
+     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if training_args.do_train:
+         column_names = raw_datasets["train"].column_names
+     elif training_args.do_eval:
+         column_names = raw_datasets["validation"].column_names
+     elif training_args.do_predict:
+         column_names = raw_datasets["test"].column_names
+     else:
+         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+         return
+
+     # Get the column names for input/target.
+     prompt_column = data_args.prompt_column
+     response_column = data_args.response_column
+     history_column = data_args.history_column
+
+     # Temporarily set max_target_length for training.
+     max_target_length = data_args.max_target_length
+
+     def preprocess_function_eval(examples):
+         inputs, targets = [], []
+         for i in range(len(examples[prompt_column])):
+             if examples[prompt_column][i] and examples[response_column][i]:
+                 query = examples[prompt_column][i]
+                 history = examples[history_column][i] if history_column is not None else None
+                 prompt = tokenizer.build_prompt(query, history)
+                 inputs.append(prompt)
+                 targets.append(examples[response_column][i])
+
+         inputs = [prefix + inp for inp in inputs]
+         model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
+         labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
+
+         if data_args.ignore_pad_token_for_loss:
+             labels["input_ids"] = [
+                 [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+             ]
+         model_inputs["labels"] = labels["input_ids"]
+
+         return model_inputs
+
+     def preprocess_function_train(examples):
+         max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
+
+         model_inputs = {
+             "input_ids": [],
+             "labels": [],
+         }
+         for i in range(len(examples[prompt_column])):
+             if examples[prompt_column][i] and examples[response_column][i]:
+                 query, answer = examples[prompt_column][i], examples[response_column][i]
+
+                 history = examples[history_column][i] if history_column is not None else None
+                 prompt = tokenizer.build_prompt(query, history)
+
+                 prompt = prefix + prompt
+                 a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
+                                          max_length=data_args.max_source_length)
+                 b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
+                                          max_length=data_args.max_target_length)
+
+                 context_length = len(a_ids)
+                 input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
+                 labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
+
+                 pad_len = max_seq_length - len(input_ids)
+                 input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
+                 labels = labels + [tokenizer.pad_token_id] * pad_len
+                 if data_args.ignore_pad_token_for_loss:
+                     labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
+
+                 model_inputs["input_ids"].append(input_ids)
+                 model_inputs["labels"].append(labels)
+
+         return model_inputs
+
+     def print_dataset_example(example):
+         print("input_ids", example["input_ids"])
+         print("inputs", tokenizer.decode(example["input_ids"]))
+         print("label_ids", example["labels"])
+         print("labels", tokenizer.decode(example["labels"]))
+
+     if training_args.do_train:
+         if "train" not in raw_datasets:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = raw_datasets["train"]
+         if data_args.max_train_samples is not None:
+             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+             train_dataset = train_dataset.select(range(max_train_samples))
+         with training_args.main_process_first(desc="train dataset map pre-processing"):
+             train_dataset = train_dataset.map(
+                 preprocess_function_train,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on train dataset",
+             )
+         print_dataset_example(train_dataset[0])
+
+     if training_args.do_eval:
+         max_target_length = data_args.val_max_target_length
+         if "validation" not in raw_datasets:
+             raise ValueError("--do_eval requires a validation dataset")
+         eval_dataset = raw_datasets["validation"]
+         if data_args.max_eval_samples is not None:
+             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+             eval_dataset = eval_dataset.select(range(max_eval_samples))
+         with training_args.main_process_first(desc="validation dataset map pre-processing"):
+             eval_dataset = eval_dataset.map(
+                 preprocess_function_eval,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on validation dataset",
+             )
+         print_dataset_example(eval_dataset[0])
+
+     if training_args.do_predict:
+         max_target_length = data_args.val_max_target_length
+         if "test" not in raw_datasets:
+             raise ValueError("--do_predict requires a test dataset")
+         predict_dataset = raw_datasets["test"]
+         if data_args.max_predict_samples is not None:
+             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+             predict_dataset = predict_dataset.select(range(max_predict_samples))
+         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+             predict_dataset = predict_dataset.map(
+                 preprocess_function_eval,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=column_names,
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 desc="Running tokenizer on prediction dataset",
+             )
+         print_dataset_example(predict_dataset[0])
+
+     # Data collator
+     label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+     data_collator = DataCollatorForSeq2Seq(
+         tokenizer,
+         model=model,
+         label_pad_token_id=label_pad_token_id,
+         pad_to_multiple_of=None,
+         padding=False
+     )
+
+     # Metric
+     def compute_metrics(eval_preds):
+         preds, labels = eval_preds
+         if isinstance(preds, tuple):
+             preds = preds[0]
+         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+         if data_args.ignore_pad_token_for_loss:
+             # Replace -100 in the labels as we can't decode them.
+             labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         score_dict = {
+             "rouge-1": [],
+             "rouge-2": [],
+             "rouge-l": [],
+             "bleu-4": []
+         }
+         for pred, label in zip(decoded_preds, decoded_labels):
+             hypothesis = list(jieba.cut(pred))
+             reference = list(jieba.cut(label))
+             rouge = Rouge()
+             scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
+             result = scores[0]
+
+             for k, v in result.items():
+                 score_dict[k].append(round(v["f"] * 100, 4))
+             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+         for k, v in score_dict.items():
+             score_dict[k] = float(np.mean(v))
+         return score_dict
+
+     # Override the decoding parameters of Seq2SeqTrainer
+     training_args.generation_max_length = (
+         training_args.generation_max_length
+         if training_args.generation_max_length is not None
+         else data_args.val_max_target_length
+     )
+     training_args.generation_num_beams = (
+         data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+     )
+     # Initialize our Trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset if training_args.do_train else None,
+         eval_dataset=eval_dataset if training_args.do_eval else None,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+         save_changed=model_args.pre_seq_len is not None
+     )
+
+     # Training
+     if training_args.do_train:
+         checkpoint = None
+         if training_args.resume_from_checkpoint is not None:
+             checkpoint = training_args.resume_from_checkpoint
+         # elif last_checkpoint is not None:
+         #     checkpoint = last_checkpoint
+         model.gradient_checkpointing_enable()
+         model.enable_input_require_grads()
+         train_result = trainer.train(resume_from_checkpoint=checkpoint)
+         # trainer.save_model()  # Saves the tokenizer too for easy upload
+
+         metrics = train_result.metrics
+         max_train_samples = (
+             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+         )
+         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+         trainer.log_metrics("train", metrics)
+         trainer.save_metrics("train", metrics)
+         trainer.save_state()
+
+     # Evaluation
+     results = {}
+     max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
+     if training_args.do_eval:
+         logger.info("*** Evaluate ***")
+         metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
+         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+     if training_args.do_predict:
+         logger.info("*** Predict ***")
+         predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
+         metrics = predict_results.metrics
+         max_predict_samples = (
+             data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
+         )
+         metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
+
+         trainer.log_metrics("predict", metrics)
+         trainer.save_metrics("predict", metrics)
+
+         if trainer.is_world_process_zero():
+             if training_args.predict_with_generate:
+                 predictions = tokenizer.batch_decode(
+                     predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                 )
+                 predictions = [pred.strip() for pred in predictions]
+                 labels = tokenizer.batch_decode(
+                     predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                 )
+                 labels = [label.strip() for label in labels]
+                 output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
+                 with open(output_prediction_file, "w", encoding="utf-8") as writer:
+                     for p, l in zip(predictions, labels):
+                         res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
+                         writer.write(f"{res}\n")
+     return results
+
+
+ def _mp_fn(index):
+     # For xla_spawn (TPUs)
+     main()
+
+
+ if __name__ == "__main__":
+     main()
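
Note: the key detail in preprocess_function_train above is how labels are laid out so that only the answer tokens (plus the end-of-sequence token) contribute to the loss. The minimal sketch below reproduces that layout with made-up integer token ids; PAD_ID, EOS_ID, and the helper build_example are illustrative stand-ins and are not part of the committed file.

# Toy illustration of the label-masking scheme used in preprocess_function_train.
PAD_ID, EOS_ID = 0, 2          # hypothetical special-token ids
IGNORE_INDEX = -100            # loss-masking value used by the script

def build_example(a_ids, b_ids, max_source_length, max_target_length):
    max_seq_length = max_source_length + max_target_length + 1
    context_length = len(a_ids)
    input_ids = a_ids + b_ids + [EOS_ID]
    labels = [PAD_ID] * context_length + b_ids + [EOS_ID]
    pad_len = max_seq_length - len(input_ids)
    input_ids += [PAD_ID] * pad_len
    labels += [PAD_ID] * pad_len
    # With ignore_pad_token_for_loss, every pad position (including the prompt
    # placeholder positions) is replaced by -100 so it is ignored by the loss.
    labels = [(l if l != PAD_ID else IGNORE_INDEX) for l in labels]
    return input_ids, labels

if __name__ == "__main__":
    ids, lbls = build_example(a_ids=[11, 12, 13], b_ids=[21, 22],
                              max_source_length=4, max_target_length=3)
    print(ids)   # [11, 12, 13, 21, 22, 2, 0, 0]
    print(lbls)  # [-100, -100, -100, 21, 22, 2, -100, -100]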