import json

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,  # we set LR=2e-4 for pre-training experiments
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=100000,
                                         gradient_accumulation_steps=4,
                                         eval_steps=100,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=50,
                                         evaluation_strategy="steps",
                                         remove_unused_columns=False)

# lang_module = LangModule("google/mt5-small")
lang_module = LangModule("Helsinki-NLP/opus-mt-en-cs")

# BLEU over validation predictions decides convergence of both objectives
metrics_args = {"additional_sep_char": "▁"}
val_metrics = [BLEU(**metrics_args, decides_convergence=True)]

# English SQuAD: drop overly long contexts and format inputs as "question: ... context: ..."
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

train_contexts_questions_en = ["question: %s context: %s" % (q, c)
                               for q, c in zip(squad_train["question"], squad_train["context"])]
val_contexts_questions_en = ["question: %s context: %s" % (q, c)
                             for q, c in zip(squad_en["validation"]["question"],
                                             squad_en["validation"]["context"])]
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]

generative_qa_en = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions_en,
                                     val_texts_or_path=val_contexts_questions_en[:200],
                                     labels_or_path=train_answers_en,
                                     val_labels_or_path=val_answers_en[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-en")

# Czech SQuAD: load from a local JSON file and hold out the last 200 entries for validation
squad_dataset = json.load(open("data/czech_squad.json"))

contexts_questions = []
answers = []

for i, entry in squad_dataset.items():
    contexts_questions.append("question: %s context: %s" % (entry["question"], entry["context"]))
    answers.append(entry["answers"]["text"][0])

train_contexts_questions = contexts_questions[:-200]
val_contexts_questions = contexts_questions[-200:]
train_answers = answers[:-200]
val_answers = answers[-200:]

generative_qa_cs = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions,
                                     val_texts_or_path=val_contexts_questions[:200],
                                     labels_or_path=train_answers,
                                     val_labels_or_path=val_answers[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-cs")

# Train both objectives in parallel until all of them converge
schedule = ParallelSchedule(objectives=[generative_qa_en, generative_qa_cs],
                            args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
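
# --- Optional inference sketch (not part of the original script) ---
# Minimal example of querying the fine-tuned model after training, using the
# standard Hugging Face seq2seq API. Assumptions: the checkpoint directory
# below is illustrative (adjust it to wherever Adaptor saved the trained
# module under `output_dir`), and `squad_en` from the script above is reused
# purely to obtain a sample question-context pair.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint_dir = "train_dir"  # assumption: path of the saved checkpoint

inf_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
inf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)

sample = "question: %s context: %s" % (squad_en["validation"]["question"][0],
                                       squad_en["validation"]["context"][0])
inputs = inf_tokenizer(sample, return_tensors="pt", truncation=True)
outputs = inf_model.generate(**inputs, max_new_tokens=32)
print(inf_tokenizer.decode(outputs[0], skip_special_tokens=True))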