import json

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

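# Shared training configuration for both objectives; training stops once every
# objective reports convergence (ALL_OBJECTIVES_CONVERGED), with evaluation
# every 100 steps and checkpoints every 1,000 steps.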
training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,  # we set LR=2e-4 for pre-training experiments
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=100000,
                                         gradient_accumulation_steps=4,
                                         eval_steps=100,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=50,
                                         evaluation_strategy="steps",
                                         remove_unused_columns=False)

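# Base model: a Marian MT model for English->Czech translation; the commented-out
# mT5-small line can be swapped in as a multilingual alternative.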
# lang_module = LangModule("google/mt5-small")
lang_module = LangModule("Helsinki-NLP/opus-mt-en-cs")

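# BLEU on the validation split decides convergence of each objective;
# "▁" is presumably stripped as the SentencePiece word-boundary marker before scoring.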
metrics_args = {"additional_sep_char": "▁"}

val_metrics = [BLEU(**metrics_args, decides_convergence=True)]

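# English SQuAD: keep only training examples whose context is shorter than
# 2,000 characters to bound input length.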
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

train_contexts_questions_en = ["question: %s context: %s" % (q, c) for q, c in zip(squad_train["question"],
                                                                                   squad_train["context"])]
val_contexts_questions_en = ["question: %s context: %s" % (q, c) for q, c in zip(squad_en["validation"]["question"],
                                                                                 squad_en["validation"]["context"])]
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]

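# English generative QA objective: the model reads "question: ... context: ..."
# and generates the answer as free text; only the first 200 validation examples
# are used for periodic evaluation.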
generative_qa_en = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions_en,
                                     val_texts_or_path=val_contexts_questions_en[:200],
                                     labels_or_path=train_answers_en,
                                     val_labels_or_path=val_answers_en[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-en")

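# Czech SQuAD-like dataset; the file path is project-specific and the format is
# assumed to be {id: {"question", "context", "answers": {"text": [...]}}},
# matching the loop below.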
with open("data/czech_squad.json", encoding="utf-8") as f:
    squad_dataset = json.load(f)

contexts_questions = []
answers = []

for entry in squad_dataset.values():
    contexts_questions.append("question: %s context: %s" % (entry["question"], entry["context"]))
    answers.append(entry["answers"]["text"][0])

train_contexts_questions = contexts_questions[:-200]
val_contexts_questions = contexts_questions[-200:]
train_answers = answers[:-200]
val_answers = answers[-200:]

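# Czech generative QA objective, mirroring the English one; the last 200
# examples serve as the validation split.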
generative_qa_cs = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions,
                                     val_texts_or_path=val_contexts_questions[:200],
                                     labels_or_path=train_answers,
                                     val_labels_or_path=val_answers[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-cs")

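# ParallelSchedule interleaves updates from both objectives throughout training,
# rather than training them sequentially.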
schedule = ParallelSchedule(objectives=[generative_qa_en, generative_qa_cs],
                            args=training_arguments)

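# The Adapter (an extension of the Hugging Face Trainer) runs the full
# multi-objective training loop.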
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()