michal-stefanik commited on
Commit
b25e454
1 Parent(s): 7263f32

Upload train_mt5_qa_en+cs.py

Browse files
Files changed (1) hide show
  1. train_mt5_qa_en+cs.py +80 -0
train_mt5_qa_en+cs.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from adaptor.adapter import Adapter
4
+ from adaptor.evaluators.generative import BLEU
5
+ from adaptor.lang_module import LangModule
6
+ from adaptor.objectives.seq2seq import Sequence2Sequence
7
+ from adaptor.schedules import ParallelSchedule
8
+ from adaptor.utils import AdaptationArguments, StoppingStrategy
9
+ from datasets import load_dataset
10
+
11
+ training_arguments = AdaptationArguments(output_dir="train_dir",
12
+ learning_rate=5e-5, # we set LR=2e-4 for pre-training experiments
13
+ # stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
14
+ stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
15
+ do_train=True,
16
+ do_eval=True,
17
+ warmup_steps=1000,
18
+ max_steps=100000,
19
+ gradient_accumulation_steps=4,
20
+ eval_steps=100,
21
+ logging_steps=10,
22
+ save_steps=1000,
23
+ num_train_epochs=50,
24
+ evaluation_strategy="steps",
25
+ remove_unused_columns=False)
26
+
27
+ # lang_module = LangModule("google/mt5-small")
28
+ lang_module = LangModule("Helsinki-NLP/opus-mt-en-cs")
29
+
30
+ metrics_args = {"additional_sep_char": "▁"}
31
+
32
+ val_metrics = [BLEU(**metrics_args, decides_convergence=True)]
33
+
34
+ squad_en = load_dataset("squad")
35
+ squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
36
+
37
+ train_contexts_questions_en = ["question: %s context: %s" % (q, c) for q, c in zip(squad_train["question"],
38
+ squad_train["context"])]
39
+ val_contexts_questions_en = ["question: %s context: %s" % (q, c) for q, c in zip(squad_en["validation"]["question"],
40
+ squad_en["validation"]["context"])]
41
+ train_answers_en = [a["text"][0] for a in squad_train["answers"]]
42
+ val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]
43
+
44
+ generative_qa_en = Sequence2Sequence(lang_module,
45
+ texts_or_path=train_contexts_questions_en,
46
+ val_texts_or_path=val_contexts_questions_en[:200],
47
+ labels_or_path=train_answers_en,
48
+ val_labels_or_path=val_answers_en[:200],
49
+ batch_size=8,
50
+ val_evaluators=val_metrics,
51
+ objective_id="SQUAD-en")
52
+
53
+ squad_dataset = json.load(open("data/czech_squad.json"))
54
+
55
+ contexts_questions = []
56
+ answers = []
57
+
58
+ for i, entry in squad_dataset.items():
59
+ contexts_questions.append("question: %s context: %s" % (entry["question"], entry["context"]))
60
+ answers.append(entry["answers"]["text"][0])
61
+
62
+ train_contexts_questions = contexts_questions[:-200]
63
+ val_contexts_questions = contexts_questions[-200:]
64
+ train_answers = answers[:-200]
65
+ val_answers = answers[-200:]
66
+
67
+ generative_qa_cs = Sequence2Sequence(lang_module,
68
+ texts_or_path=train_contexts_questions,
69
+ val_texts_or_path=val_contexts_questions[:200],
70
+ labels_or_path=train_answers,
71
+ val_labels_or_path=val_answers[:200],
72
+ batch_size=8,
73
+ val_evaluators=val_metrics,
74
+ objective_id="SQUAD-cs")
75
+
76
+ schedule = ParallelSchedule(objectives=[generative_qa_en, generative_qa_cs],
77
+ args=training_arguments)
78
+
79
+ adapter = Adapter(lang_module, schedule, args=training_arguments)
80
+ adapter.train()