michal-stefanik committed on
Commit
36e28a3
1 Parent(s): e8fa15d

Upload train_roberta_extractive_qa.py

Browse files
Files changed (1) hide show
  1. train_roberta_extractive_qa.py +100 -0
train_roberta_extractive_qa.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# TODO: BEFORE RUNNING: pip install git+https://github.com/gaussalgo/adaptor.git@QA_to_objectives
"""Train a multilingual extractive question-answering model on Czech and English SQuAD.

The script builds one ``ExtractiveQA`` objective per language and trains both
objectives together with a ``ParallelSchedule`` from the ``adaptor`` library.
"""

import json

from adaptor.adapter import Adapter
from adaptor.evaluators.question_answering import BLEUForQA
from adaptor.lang_module import LangModule
from adaptor.objectives.question_answering import ExtractiveQA
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

# Base checkpoint shared by both objectives.
# NOTE(review): the file name mentions RoBERTa, but a multilingual BERT
# checkpoint is loaded here -- confirm which one is intended.
model_name = "bert-base-multilingual-cased"

lang_module = LangModule(model_name)

# NOTE(review): eval_steps=1 requests an evaluation after every training step,
# which looks like a debugging leftover -- confirm before a long run.
# NOTE(review): both max_steps and num_train_epochs are set; presumably the
# first limit that is reached ends the training -- verify against adaptor.
training_arguments = AdaptationArguments(
    output_dir="train_dir",
    learning_rate=4e-5,
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    do_train=True,
    do_eval=True,
    warmup_steps=1000,
    max_steps=100000,
    gradient_accumulation_steps=1,
    eval_steps=1,
    logging_steps=10,
    save_steps=1000,
    num_train_epochs=30,
    evaluation_strategy="steps",
)

# The same evaluator list is handed to both objectives below.
val_metrics = [BLEUForQA(decides_convergence=True)]

# --- Czech SQuAD: load, filter and split into train / validation ------------
# The JSON file is read as a mapping whose values are entries with the keys
# "question", "context" and "answers" (where "answers"["text"] is a list).
with open("data/czech_squad.json", encoding="utf-8") as squad_file:
    squad_dataset = json.load(squad_file)

questions = []
contexts = []
answers = []
skipped = 0

for entry in squad_dataset.values():
    answer_texts = entry["answers"]["text"]
    # Extractive QA needs the answer to be a literal span of the context;
    # entries without any answer, or whose first answer does not occur in the
    # context, are counted and left out.
    # Over-long contexts are not filtered here: they will be automatically
    # truncated from the input anyway.
    if answer_texts and answer_texts[0] in entry["context"]:
        questions.append(entry["question"])
        contexts.append(entry["context"])
        answers.append(answer_texts[0])
    else:
        skipped += 1

print(f"Skipped examples from SQuAD-cs: {skipped}")

# The last `num_val_cs` usable examples are held out for validation.
num_val_cs = 200

train_questions = questions[:-num_val_cs]
val_questions = questions[-num_val_cs:]

train_answers = answers[:-num_val_cs]
val_answers = answers[-num_val_cs:]

train_context = contexts[:-num_val_cs]
val_context = contexts[-num_val_cs:]

# Declaration of the extractive question answering objective for Czech SQuAD.
# Questions are passed as the first text, contexts as the paired text and the
# answer strings as labels.
# NOTE(review): the variable is called "generative_*" although the objective
# is ExtractiveQA; the name is kept because the schedule below refers to it.
generative_qa_cs = ExtractiveQA(
    lang_module,
    texts_or_path=train_questions,
    text_pair_or_path=train_context,
    labels_or_path=train_answers,
    val_texts_or_path=val_questions,
    val_text_pair_or_path=val_context,
    val_labels_or_path=val_answers,
    batch_size=1,
    val_evaluators=val_metrics,
    objective_id="SQUAD-cs",
)

# --- English SQuAD -----------------------------------------------------------
# Only the first `num_val_en` validation examples are used for evaluation.
num_val_en = 200

squad_en = load_dataset("squad")
# Drop training examples whose context is 2000 characters or longer.
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
# Slicing a Dataset yields a plain dict of column lists, so the validation
# subset is materialised once instead of slicing every full column separately.
squad_val = squad_en["validation"][:num_val_en]

# Each example may carry several reference answers; the first one is the label.
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_val["answers"]]

# NOTE(review): the variable is called "generative_*" although the objective
# is ExtractiveQA; the name is kept because the schedule below refers to it.
generative_qa_en = ExtractiveQA(
    lang_module,
    texts_or_path=squad_train["question"],
    text_pair_or_path=squad_train["context"],
    labels_or_path=train_answers_en,
    val_texts_or_path=squad_val["question"],
    val_text_pair_or_path=squad_val["context"],
    val_labels_or_path=val_answers_en,
    batch_size=10,
    val_evaluators=val_metrics,
    objective_id="SQUAD-en",
)

# Both objectives are trained together under one schedule; the schedule and
# the adapter receive the same training arguments.
schedule = ParallelSchedule(
    objectives=[generative_qa_cs, generative_qa_en],
    args=training_arguments,
)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()