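# mt5-base-generative-QA_en-cs / train_mt5_qa_en+cs.py
# Uploaded by michal-stefanik (revision b25e454).
# Trains a seq2seq model for generative question answering on English SQuAD
# and a Czech SQuAD-format dataset jointly, using the `adaptor` library.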
import json
from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset
# Trainer configuration shared by the schedule and the Adapter below.
# NOTE(review): AdaptationArguments appears to build on Hugging Face
# `TrainingArguments`; if so, `max_steps` (> 0) takes precedence over
# `num_train_epochs`, i.e. training is capped at 100k optimizer steps -
# confirm against the installed adaptor/transformers versions.
training_arguments = AdaptationArguments(
    output_dir="train_dir",
    learning_rate=5e-5,  # we set LR=2e-4 for pre-training experiments
    # Presumably ends training once every objective has converged, as judged by
    # evaluators created with `decides_convergence=True` (BLEU in `val_metrics`).
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    do_train=True,
    do_eval=True,
    warmup_steps=1000,
    max_steps=100000,
    gradient_accumulation_steps=4,
    eval_steps=100,  # evaluation_strategy="steps": validate every 100 steps
    logging_steps=10,
    save_steps=1000,
    num_train_epochs=50,
    evaluation_strategy="steps",
    remove_unused_columns=False,
)
# Pretrained seq2seq backbone shared by both QA objectives.
# (An earlier variant of this script used "google/mt5-small" instead.)
lang_module = LangModule("Helsinki-NLP/opus-mt-en-cs")

# Validation metric shared by both objectives. "▁" is presumably the
# SentencePiece word-boundary marker, passed so BLEU can treat it as a
# separator - confirm against adaptor's BLEU evaluator.
metrics_args = {"additional_sep_char": "▁"}
bleu = BLEU(decides_convergence=True, **metrics_args)
val_metrics = [bleu]
# English SQuAD objective. Training examples whose context has 2000 or more
# characters are dropped; validation uses the first 200 (unfiltered) examples.
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
squad_val = squad_en["validation"]

# Model input is a "question: ... context: ..." prompt; the target is the
# first reference answer of each example.
train_contexts_questions_en = ["question: %s context: %s" % pair
                               for pair in zip(squad_train["question"], squad_train["context"])]
val_contexts_questions_en = ["question: %s context: %s" % pair
                             for pair in zip(squad_val["question"], squad_val["context"])]
train_answers_en = [answer["text"][0] for answer in squad_train["answers"]]
val_answers_en = [answer["text"][0] for answer in squad_val["answers"]]

generative_qa_en = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions_en,
                                     val_texts_or_path=val_contexts_questions_en[:200],
                                     labels_or_path=train_answers_en,
                                     val_labels_or_path=val_answers_en[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-en")
# Czech QA objective built from a local SQuAD-format JSON file.
# NOTE(review): assumes a top-level object whose values carry "question",
# "context" and "answers" -> {"text": [...]} keys (as read below) - confirm.
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the
# platform default encoding (which mangles Czech diacritics on e.g. Windows),
# and close the file once parsed.
with open("data/czech_squad.json", encoding="utf-8") as cs_file:
    squad_dataset = json.load(cs_file)

# Same "question: ... context: ..." prompt as the English objective; the target
# is the first listed answer. Unlike the English set, contexts are not
# length-filtered here.
contexts_questions = ["question: %s context: %s" % (entry["question"], entry["context"])
                      for entry in squad_dataset.values()]
answers = [entry["answers"]["text"][0] for entry in squad_dataset.values()]

# Hold out the last `cs_val_size` entries for validation. An explicit split
# index (rather than `[:-cs_val_size]`) stays correct for a hold-out of 0 and
# puts everything in validation when the dataset is smaller than the hold-out.
cs_val_size = 200
split_at = max(len(contexts_questions) - cs_val_size, 0)
train_contexts_questions = contexts_questions[:split_at]
val_contexts_questions = contexts_questions[split_at:]
train_answers = answers[:split_at]
val_answers = answers[split_at:]

generative_qa_cs = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions,
                                     val_texts_or_path=val_contexts_questions,
                                     labels_or_path=train_answers,
                                     val_labels_or_path=val_answers,
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-cs")
# Train the English and Czech QA objectives together under adaptor's
# ParallelSchedule; both the schedule and the Adapter share `training_arguments`.
# NOTE(review): ParallelSchedule presumably combines the objectives within each
# training step - confirm in adaptor's schedules documentation.
qa_objectives = [generative_qa_en, generative_qa_cs]
schedule = ParallelSchedule(objectives=qa_objectives, args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()