# Scraped Hugging Face file-page header (not part of the program):
# PANH's picture
# Upload 15 files
# ffca110 verified
# raw
# history blame
# 6.34 kB
from pytorch_lightning import Trainer, seed_everything
from alignscore.dataloader import DSTDataLoader
from alignscore.model import BERTAlignModel
from pytorch_lightning.callbacks import ModelCheckpoint
from argparse import ArgumentParser
import os
def _checkpoint_name(args, dataset_names, max_len=200):
    """Return the file-name stem shared by every checkpoint of this run.

    The stem encodes the checkpoint comment and model name, the
    "scratch_" / "no_mlm_" switches, the training datasets, the per-dataset
    sample cap and the batch layout ``batch_size x num_devices x grad_accum``.

    Joining every dataset name can exceed the common 255-byte file-name
    limit. The full training mixture alone is roughly 270 characters. When
    that happens, saving a checkpoint fails with "File name too long", and
    only after training has already run. So when the stem would be longer
    than ``max_len`` bytes, the dataset list is replaced by
    ``<count>datasets-<8-hex sha1 digest of the joined names>``. That keeps
    the name deterministic and distinct per mixture. Shorter stems are
    exactly the same as before.

    Args:
        args: namespace providing ckpt_comment, model_name,
            use_pretrained_model, do_mlm, max_samples_per_dataset,
            batch_size, devices and accumulate_grad_batch.
        dataset_names: ordered dataset names used for training.
        max_len: maximum stem length in UTF-8 bytes before the dataset list
            is abbreviated. It leaves headroom for the
            "_epoch=NN_step=N.ckpt" / "_final.ckpt" suffixes.

    Returns:
        The checkpoint file-name stem (no directory, no extension).
    """
    import hashlib

    def compose(dataset_part):
        return '_'.join((
            f"{args.ckpt_comment}{args.model_name.replace('/', '-')}",
            f"{'scratch_' if not args.use_pretrained_model else ''}"
            f"{'no_mlm_' if not args.do_mlm else ''}{dataset_part}",
            str(args.max_samples_per_dataset),
            f"{args.batch_size}x{len(args.devices)}x{args.accumulate_grad_batch}",
        ))

    joined = '_'.join(dataset_names)
    name = compose(joined)
    if len(name.encode('utf-8')) > max_len:
        digest = hashlib.sha1(joined.encode('utf-8')).hexdigest()[:8]
        name = compose(f"{len(dataset_names)}datasets-{digest}")
    return name


def train(datasets, args):
    """Train a BERTAlignModel on a dataset mixture and save its checkpoints.

    Args:
        datasets: mapping of dataset name -> config dict. It is passed
            unchanged to DSTDataLoader as ``dataset_config``, and its keys
            are used in the checkpoint file name.
        args: parsed command-line namespace. Reads model_name, batch_size,
            num_workers, do_mlm, use_pretrained_model, adam_epsilon,
            learning_rate, weight_decay, warm_up_proportion, ckpt_comment,
            max_samples_per_dataset, devices, accumulate_grad_batch,
            ckpt_save_path and num_epoch.

    Checkpoints are written to ``args.ckpt_save_path`` every 10000 training
    steps. A final ``<stem>_final.ckpt`` is written after ``trainer.fit``
    returns.
    """
    dm = DSTDataLoader(
        dataset_config=datasets,
        model_name=args.model_name,
        sample_mode='seq',
        train_batch_size=args.batch_size,
        eval_batch_size=16,
        num_workers=args.num_workers,
        # presumably the fraction of samples kept for training; confirm in DSTDataLoader
        train_eval_split=0.95,
        need_mlm=args.do_mlm,
    )
    dm.setup()

    model = BERTAlignModel(
        model=args.model_name,
        using_pretrained=args.use_pretrained_model,
        adam_epsilon=args.adam_epsilon,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps_portion=args.warm_up_proportion,
    )
    # The constructor is not given the MLM switch, so set it on the instance.
    model.need_mlm = args.do_mlm

    checkpoint_name = _checkpoint_name(args, list(datasets.keys()))

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_save_path,
        filename=checkpoint_name + "_{epoch:02d}_{step}",
        every_n_train_steps=10000,
        # No `monitor` is set. Per Lightning's ModelCheckpoint docs this
        # keeps only the most recent checkpoint; confirm for the installed version.
        save_top_k=1,
    )

    # NOTE(review): args.val_check_interval (parsed by the CLI) is not
    # forwarded here, so validation runs at Lightning's default cadence.
    trainer = Trainer(
        accelerator='gpu',
        max_epochs=args.num_epoch,
        devices=args.devices,
        # NOTE(review): the "dp" (DataParallel) strategy is not available in
        # newer pytorch_lightning releases; confirm the pinned version.
        strategy="dp",
        precision=32,
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=args.accumulate_grad_batch,
    )
    trainer.fit(model, datamodule=dm)
    trainer.save_checkpoint(os.path.join(args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))
    print("Training is finished.")
def str2bool(value):
    """Parse a command-line boolean such as ``--do-mlm false``.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for any
    non-empty string. As a result, ``--use-pretrained-model False`` used to
    leave the flag enabled. This converter accepts the usual spellings
    explicitly and rejects anything else.

    Args:
        value: the raw option text, or an already-boolean value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Every dataset this script can train on. Maps name -> task type and the
    # JSON file name relative to --data-path.
    ALL_TRAINING_DATASETS = {
        # NLI
        'mnli': {'task_type': 'nli', 'data_path': 'mnli.json'},
        'doc_nli': {'task_type': 'bin_nli', 'data_path': 'doc_nli.json'},
        'snli': {'task_type': 'nli', 'data_path': 'snli.json'},
        'anli_r1': {'task_type': 'nli', 'data_path': 'anli_r1.json'},
        'anli_r2': {'task_type': 'nli', 'data_path': 'anli_r2.json'},
        'anli_r3': {'task_type': 'nli', 'data_path': 'anli_r3.json'},
        # fact checking
        'nli_fever': {'task_type': 'fact_checking', 'data_path': 'nli_fever.json'},
        'vitaminc': {'task_type': 'fact_checking', 'data_path': 'vitaminc.json'},
        # paraphrase
        'paws': {'task_type': 'paraphrase', 'data_path': 'paws.json'},
        'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'paws_qqp.json'},
        'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'paws_unlabeled.json'},
        'qqp': {'task_type': 'paraphrase', 'data_path': 'qqp.json'},
        'wiki103': {'task_type': 'paraphrase', 'data_path': 'wiki103.json'},
        # QA
        'squad_v2': {'task_type': 'qa', 'data_path': 'squad_v2_new.json'},
        'race': {'task_type': 'qa', 'data_path': 'race.json'},
        'adversarial_qa': {'task_type': 'qa', 'data_path': 'adversarial_qa.json'},
        'drop': {'task_type': 'qa', 'data_path': 'drop.json'},
        'hotpot_qa_distractor': {'task_type': 'qa', 'data_path': 'hotpot_qa_distractor.json'},
        'hotpot_qa_fullwiki': {'task_type': 'qa', 'data_path': 'hotpot_qa_fullwiki.json'},
        'newsqa': {'task_type': 'qa', 'data_path': 'newsqa.json'},
        'quoref': {'task_type': 'qa', 'data_path': 'quoref.json'},
        'ropes': {'task_type': 'qa', 'data_path': 'ropes.json'},
        'boolq': {'task_type': 'qa', 'data_path': 'boolq.json'},
        'eraser_multi_rc': {'task_type': 'qa', 'data_path': 'eraser_multi_rc.json'},
        'quail': {'task_type': 'qa', 'data_path': 'quail.json'},
        'sciq': {'task_type': 'qa', 'data_path': 'sciq.json'},
        'strategy_qa': {'task_type': 'qa', 'data_path': 'strategy_qa.json'},
        # Coreference
        'gap': {'task_type': 'coreference', 'data_path': 'gap.json'},
        # Summarization
        'wikihow': {'task_type': 'summarization', 'data_path': 'wikihow.json'},
        # Information Retrieval
        'msmarco': {'task_type': 'ir', 'data_path': 'msmarco.json'},
        # STS
        'stsb': {'task_type': 'sts', 'data_path': 'stsb.json'},
        'sick': {'task_type': 'sts', 'data_path': 'sick.json'},
    }

    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=2022)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--accumulate-grad-batch', type=int, default=1)
    parser.add_argument('--num-epoch', type=int, default=3)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--warm-up-proportion', type=float, default=0.06)
    parser.add_argument('--adam-epsilon', type=float, default=1e-6)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    parser.add_argument('--learning-rate', type=float, default=1e-5)
    # NOTE(review): parsed but not currently forwarded to the Trainer by train().
    parser.add_argument('--val-check-interval', type=float, default=1. / 4)
    parser.add_argument('--devices', nargs='+', type=int, required=True)
    parser.add_argument('--model-name', type=str, default="roberta-large")
    parser.add_argument('--ckpt-save-path', type=str, required=True)
    parser.add_argument('--ckpt-comment', type=str, default="")
    # '--trainin-datasets' (original misspelling) is kept as an alias so
    # existing launch scripts keep working.
    parser.add_argument(
        '--training-datasets', '--trainin-datasets',
        dest='training_datasets',
        nargs='+',
        type=str,
        default=list(ALL_TRAINING_DATASETS.keys()),
        choices=list(ALL_TRAINING_DATASETS.keys()),
    )
    parser.add_argument('--data-path', type=str, required=True)
    parser.add_argument('--max-samples-per-dataset', type=int, default=500000)
    # Accept "--flag true/false" (as before, but now parsed correctly) or a
    # bare "--flag", which means True.
    parser.add_argument('--do-mlm', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--use-pretrained-model', type=str2bool, nargs='?', const=True, default=True)
    args = parser.parse_args()

    seed_everything(args.seed)

    # Per-run dataset config: the static entry plus the sample cap and an
    # absolute data path under --data-path.
    datasets = {
        name: {
            **ALL_TRAINING_DATASETS[name],
            "size": args.max_samples_per_dataset,
            "data_path": os.path.join(args.data_path, ALL_TRAINING_DATASETS[name]['data_path']),
        }
        for name in args.training_datasets
    }
    train(datasets, args)