import logging
from functools import partial
from typing import Dict

import torch
from transformers import (
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    Trainer,
    set_seed,
)
from transformers.trainer_callback import PrinterCallback
from transformers.utils.logging import enable_explicit_format

from collators import BiencoderCollator
from config import Arguments
from loaders import RetrievalDataLoader
from logger_config import LoggerCallback, logger
from metrics import accuracy, batch_mrr
from models import BiencoderModel
from trainers import BiencoderTrainer

|
def _common_setup(args: Arguments):
    # Only the main process logs at INFO level; other ranks are silenced to WARNING.
    if args.process_index > 0:
        logger.setLevel(logging.WARNING)
    enable_explicit_format()
    set_seed(args.seed)

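# Shape sketch (an assumption inferred from the label construction below, not
# stated in the original source): with B queries per batch and
# args.train_n_passages candidate passages per query, `scores` is
# [B, num_passages] and the positive passage for query i sits at column
# i * train_n_passages. For example, with B=2 and train_n_passages=4,
# the positives are at columns 0 and 4.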
|
def _compute_metrics(args: Arguments, eval_pred: EvalPrediction) -> Dict[str, float]:
    # The similarity scores are the last element of the prediction tuple.
    preds = eval_pred.predictions
    scores = torch.tensor(preds[-1]).float()
    # The positive for query i is at column i * train_n_passages; the modulo
    # keeps the labels in range when the score matrix has fewer columns than
    # batch_size * train_n_passages.
    labels = torch.arange(0, scores.shape[0], dtype=torch.long) * args.train_n_passages
    labels = labels % scores.shape[1]

    topk_metrics = accuracy(output=scores, target=labels, topk=(1, 3))
    mrr = batch_mrr(output=scores, target=labels)

    return {'mrr': mrr, 'acc1': topk_metrics[0], 'acc3': topk_metrics[1]}

|
def main():
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    _common_setup(args)
    logger.info('Args={}'.format(str(args)))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModel = BiencoderModel.build(args=args)
    logger.info(model)
    logger.info('Vocab size: {}'.format(len(tokenizer)))

    # Pad to a multiple of 8 under fp16 to make better use of tensor cores.
    data_collator = BiencoderCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if args.fp16 else None)

    retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = retrieval_data_loader.train_dataset
    eval_dataset = retrieval_data_loader.eval_dataset

    trainer: Trainer = BiencoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        data_collator=data_collator,
        compute_metrics=partial(_compute_metrics, args),
        tokenizer=tokenizer,
    )
    # Replace the default stdout printer with the project's logger-based callback.
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)
    # Give the data loader and model a handle back to the trainer.
    retrieval_data_loader.trainer = trainer
    model.trainer = trainer

    if args.do_train:
        train_result = trainer.train()
        trainer.save_model()

        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        metrics["eval_samples"] = len(eval_dataset)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    return

|
if __name__ == "__main__":
    main()
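# Example launch (hypothetical script name, paths, and values; the full flag
# set is defined by the Arguments dataclass in config.py):
#
#   python train_biencoder.py \
#       --model_name_or_path bert-base-uncased \
#       --do_train --do_eval \
#       --fp16 \
#       --seed 42 \
#       --train_n_passages 8 \
#       --output_dir ./checkpoints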