# msmarco-german-mt5-base-v1 / train_script.py
# Author: nreimers; commit 31e641c ("upload")
import argparse
import logging
from torch.utils.data import Dataset, IterableDataset
import gzip
import json
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
import sys
from datetime import datetime
import torch
import random
from shutil import copyfile
import os
import wandb
import random
import re
from datasets import load_dataset
import tqdm
# Route log records to stdout using a timestamped "time - level - logger - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Command-line options, parsed at import time into the module-level `args`
# that main() reads.
parser = argparse.ArgumentParser()
# Suffix of the mMARCO dataset configs to load ("queries-<lang>", "collection-<lang>").
parser.add_argument("--lang", required=True)
parser.add_argument("--model_name", default="google/mt5-base")
# Training schedule.
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=32)
# Token limits for the passage (source) and query (target) sequences.
parser.add_argument("--max_source_length", type=int, default=320)
parser.add_argument("--max_target_length", type=int, default=64)
# Number of leading qrels pairs held out for evaluation.
parser.add_argument("--eval_size", type=int, default=1000)
args = parser.parse_args()

# Open the Weights & Biases run, named after the language and base model.
wandb.init(project="doc2query", name="{}-{}".format(args.lang, args.model_name))
def _read_train_eval_pairs(qrels_path, queries, collection, eval_size):
    """Build (query, passage) text pairs from a tab-separated qrels file.

    Each non-empty line must have exactly four tab-separated fields. The first
    is the query id and the third is the passage id; the other two are ignored.
    The first ``eval_size`` pairs (in file order) form the eval split and the
    rest form the train split.

    Args:
        qrels_path: Path to the qrels TSV file.
        queries: Mapping from integer query id to query text.
        collection: Indexable collection where ``collection[did]`` is a row
            with ``'id'`` and ``'text'`` keys. Row position must equal id.
        eval_size: Number of leading pairs reserved for evaluation.

    Returns:
        ``(train_pairs, eval_pairs)``: two lists of ``(query_text, passage_text)``.

    Raises:
        ValueError: A line does not have four fields, or a collection row's id
            differs from its position.
        KeyError: A query id is missing from ``queries``.
    """
    train_pairs = []
    eval_pairs = []
    with open(qrels_path, encoding="utf-8") as f_in:
        for line in f_in:
            fields = line.strip()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line at EOF).
                continue
            qid, _, did, _ = fields.split("\t")
            qid = int(qid)
            did = int(did)
            # Fetch the row once instead of indexing the collection twice.
            doc = collection[did]
            # Lookup-by-position is only valid if position == passage id. Check
            # it explicitly: an `assert` would be stripped under `python -O` and
            # wrong pairs would be built silently.
            if doc["id"] != did:
                raise ValueError(
                    f"Collection row {did} has id {doc['id']!r}; "
                    "expected row position to equal passage id"
                )
            pair = (queries[qid], doc["text"])
            if len(eval_pairs) < eval_size:
                eval_pairs.append(pair)
            else:
                train_pairs.append(pair)
    return train_pairs, eval_pairs


def _snapshot_script(output_dir):
    """Copy this script into ``output_dir`` and append the invoking command line.

    Returns:
        Path of the copied script.
    """
    script_path = os.path.join(output_dir, "train_script.py")
    copyfile(__file__, script_path)
    with open(script_path, "a", encoding="utf-8") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
    return script_path


def _make_data_collator(tokenizer, max_source_length, max_target_length, pad_to_multiple_of=None):
    """Return a collate function mapping (query, passage) pairs to model inputs.

    The passage (second element) is the encoder input and the query (first
    element) is the label sequence. Label padding is replaced with -100 so it
    is ignored in the loss.

    Args:
        tokenizer: Tokenizer used for both inputs and targets.
        max_source_length: Truncation length for passages.
        max_target_length: Truncation length for queries.
        pad_to_multiple_of: If set, pad sequence lengths up to a multiple of it.
    """
    label_pad_token_id = -100

    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        model_inputs = tokenizer(
            inputs,
            max_length=max_source_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
            pad_to_multiple_of=pad_to_multiple_of,
        )
        # Tokenize the targets in target mode.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets,
                max_length=max_target_length,
                padding=True,
                truncation=True,
                pad_to_multiple_of=pad_to_multiple_of,
            )
        # Replace every pad token in the labels by -100 so padding is ignored in the loss.
        model_inputs["labels"] = torch.tensor(
            [
                [tok if tok != tokenizer.pad_token_id else label_pad_token_id for tok in label]
                for label in labels["input_ids"]
            ]
        )
        return model_inputs

    return data_collator


def main():
    """Fine-tune a seq2seq model to generate queries from passages (doc2query).

    Reads options from the module-level ``args``. Loads the queries and the
    passage collection for ``args.lang`` from ``unicamp-dl/mmarco`` and pairs
    them via ``qrels.train.tsv`` in the working directory. Trains with
    ``Seq2SeqTrainer`` and saves the model to a timestamped directory under
    ``output/``.

    Raises:
        ValueError: ``--eval_size`` < 1, or no training pairs remain after the
            eval split.
    """
    # Evaluation runs every `save_steps`, so an eval split is mandatory.
    if args.eval_size < 1:
        raise ValueError("--eval_size must be at least 1")

    ############ Load dataset
    queries = {
        row["id"]: row["text"]
        for row in tqdm.tqdm(load_dataset("unicamp-dl/mmarco", f"queries-{args.lang}")["train"])
    }
    # Passages are looked up by row position; the reader verifies position == id.
    collection = load_dataset("unicamp-dl/mmarco", f"collection-{args.lang}")["collection"]
    train_pairs, eval_pairs = _read_train_eval_pairs(
        "qrels.train.tsv", queries, collection, args.eval_size
    )
    if not train_pairs:
        raise ValueError(
            f"No training pairs left after reserving {len(eval_pairs)} for evaluation; "
            "lower --eval_size"
        )
    print(f"Train pairs: {len(train_pairs)}")

    ############ Model
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 1000
    output_dir = (
        "output/" + args.lang + "-" + args.model_name.replace("/", "-")
        + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    )
    print("Output dir:", output_dir)

    # Keep a copy of this script (plus its command line) next to the checkpoints.
    os.makedirs(output_dir, exist_ok=True)
    _snapshot_script(output_dir)

    ############ Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        bf16=True,
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    # Show one example of each split: the passage is the input, the query the target.
    print("Input:", train_pairs[0][1])
    print("Target:", train_pairs[0][0])
    print("Input:", eval_pairs[0][1])
    print("Target:", eval_pairs[0][0])

    # Pad to a multiple of 8 under either mixed-precision mode. Checking only
    # fp16 would never trigger here, because this run enables bf16.
    pad_to_multiple_of = 8 if (training_args.fp16 or training_args.bf16) else None
    data_collator = _make_data_collator(
        tokenizer, args.max_source_length, args.max_target_length, pad_to_multiple_of
    )

    ## Define the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_pairs,
        eval_dataset=eval_pairs,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    ### Train, then save the final model to output_dir
    trainer.train()
    trainer.save_model()
# Launch training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# Script was called via:
#python train_hf_trainer_multilingual.py --lang german