import os
import re
import json
import argparse
import random

import numpy as np
import torch
import nltk
import evaluate
from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from model import T5ForConditionalGeneration, T5ForMultimodalGeneration
from utils_data import img_shape, load_data_std, load_data_img, ScienceQADatasetStd, ScienceQADatasetImg
from utils_prompt import *
from utils_evaluate import get_scores
from rich.table import Column, Table
from rich import box
from rich.console import Console

console = Console(record=True)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='data')
    parser.add_argument('--output_dir', type=str, default='experiments')
    parser.add_argument('--model', type=str, default='allenai/unifiedqa-t5-base')
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--bs', type=int, default=16)
    parser.add_argument('--input_len', type=int, default=512)
    parser.add_argument('--output_len', type=int, default=64)
    parser.add_argument('--eval_bs', type=int, default=16)
    parser.add_argument('--eval_acc', type=int, default=None, help='evaluation accumulation steps')
    parser.add_argument('--train_split', type=str, default='train', choices=['train', 'trainval', 'minitrain'])
    parser.add_argument('--val_split', type=str, default='val', choices=['test', 'val', 'minival'])
    parser.add_argument('--test_split', type=str, default='test', choices=['test', 'minitest'])
    parser.add_argument('--use_generate', action='store_true', help='only for the baseline, to improve inference speed')
    parser.add_argument('--final_eval', action='store_true', help='only evaluate the model at the final epoch')
    parser.add_argument('--user_msg', type=str, default="baseline", help='experiment type, used in the save_dir name')
    parser.add_argument('--img_type', type=str, default=None, choices=['detr', 'clip', 'resnet'], help='type of image features')
    parser.add_argument('--eval_le', type=str, default=None, help='generated rationales for the dev set')
    parser.add_argument('--test_le', type=str, default=None, help='generated rationales for the test set')
    parser.add_argument('--evaluate_dir', type=str, default=None, help='directory of the model to evaluate')
    parser.add_argument('--caption_file', type=str, default='data/captions.json')
    parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
    parser.add_argument('--prompt_format', type=str, default='QCM-A', help='prompt format template',
                        choices=['QCM-A', 'QCM-LE', 'QCMG-A', 'QCM-LEA', 'QCM-ALE'])
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    args = parser.parse_args()
    return args

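# Illustrative invocations (the script name, paths, and flag values below are assumptions,
# not defaults enforced by this file; adjust them to your setup):
#
#   # stage 1: rationale generation with multimodal features
#   python main.py --data_root data --img_type detr --prompt_format QCM-LE \
#       --user_msg rationale --final_eval
#
#   # stage 2: answer inference, feeding back the rationales produced by stage 1
#   python main.py --data_root data --img_type detr --prompt_format QCMG-A \
#       --user_msg answer \
#       --eval_le <path to predictions_ans_eval.json from stage 1> \
#       --test_le <path to predictions_ans_test.json from stage 1>
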
def T5Trainer(dataframe, args):
    torch.manual_seed(args.seed)  # pytorch random seed
    np.random.seed(args.seed)  # numpy random seed
    torch.backends.cudnn.deterministic = True

    if args.evaluate_dir is not None:
        args.model = args.evaluate_dir

    tokenizer = T5Tokenizer.from_pretrained(args.model)

    console.log(f"""[Model]: Loading {args.model}...\n""")
    console.log(f"[Data]: Reading data...\n")
    problems = dataframe['problems']
    qids = dataframe['qids']
    train_qids = qids['train']
    test_qids = qids['test']
    val_qids = qids['val']

    if args.evaluate_dir is not None:
        save_dir = args.evaluate_dir
    else:
        model_name = args.model.replace("/", "-")
        gpu_count = torch.cuda.device_count()
        save_dir = f"{args.output_dir}/{args.user_msg}_{model_name}_{args.img_type}_{args.prompt_format}_lr{args.lr}_bs{args.bs * gpu_count}_op{args.output_len}_ep{args.epoch}"
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

    padding_idx = tokenizer._convert_token_to_id(tokenizer.pad_token)
    if args.img_type is not None:
        # multimodal model: T5 conditioned on pre-extracted image features
        patch_size = img_shape[args.img_type]
        model = T5ForMultimodalGeneration.from_pretrained(args.model, patch_size=patch_size, padding_idx=padding_idx, save_dir=save_dir)
        name_maps = dataframe['name_maps']
        image_features = dataframe['image_features']
        train_set = ScienceQADatasetImg(
            problems,
            train_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
        )
        eval_set = ScienceQADatasetImg(
            problems,
            val_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
            args.eval_le,
        )
        test_set = ScienceQADatasetImg(
            problems,
            test_qids,
            name_maps,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            image_features,
            args.test_le,
        )
    else:
        # text-only baseline
        model = T5ForConditionalGeneration.from_pretrained(args.model)
        train_set = ScienceQADatasetStd(
            problems,
            train_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
        )
        eval_set = ScienceQADatasetStd(
            problems,
            val_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            args.eval_le,
        )
        test_set = ScienceQADatasetStd(
            problems,
            test_qids,
            tokenizer,
            args.input_len,
            args.output_len,
            args,
            args.test_le,
        )

    datacollator = DataCollatorForSeq2Seq(tokenizer)
    print("model parameters: ", model.num_parameters())

    def extract_ans(ans):
        pattern = re.compile(r'The answer is \(([A-Z])\)')
        res = pattern.findall(ans)
        if len(res) == 1:
            answer = res[0]  # 'A', 'B', ...
        else:
            answer = "FAILED"
        return answer

    # accuracy for answer inference
    def compute_metrics_acc(eval_preds):
        if args.use_generate:
            preds, targets = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
        else:
            preds = eval_preds.predictions[0]
            targets = eval_preds.label_ids
            preds = preds.argmax(axis=2)
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        correct = 0
        assert len(preds) == len(targets)
        for idx, pred in enumerate(preds):
            reference = targets[idx]
            reference = extract_ans(reference)
            extract_pred = extract_ans(pred)
            best_option = extract_pred
            if reference == best_option:
                correct += 1
        return {'accuracy': 1.0 * correct / len(targets)}

    # rougel for rationale generation
    metric = evaluate.load("rouge")

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels

    def compute_metrics_rougel(eval_preds):
        if args.use_generate:
            preds, targets = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
        else:
            preds = eval_preds.predictions[0]
            targets = eval_preds.label_ids
            preds = preds.argmax(axis=2)
        # compute generation length on token ids, before they are decoded to strings
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        decoded_preds, decoded_labels = postprocess_text(preds, targets)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        result = {k: round(v * 100, 4) for k, v in result.items()}
        result["gen_len"] = np.mean(prediction_lens)
        return result
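    # Note on both metric functions above: without --use_generate, Seq2SeqTrainer returns
    # decoder logits, so eval_preds.predictions[0] has shape (batch, seq_len, vocab) and
    # argmax over the last axis recovers token ids before decoding; with --use_generate
    # (predict_with_generate=True), the predictions are already generated token ids.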

    # only use the last model for evaluation to save time
    if args.final_eval:
        training_args = Seq2SeqTrainingArguments(
            save_dir,
            do_train=True if args.evaluate_dir is None else False,
            do_eval=False,
            evaluation_strategy="no",
            logging_strategy="steps",
            save_strategy="epoch",
            save_total_limit=2,
            learning_rate=args.lr,
            eval_accumulation_steps=args.eval_acc,
            per_device_train_batch_size=args.bs,
            per_device_eval_batch_size=args.eval_bs,
            weight_decay=0.01,
            num_train_epochs=args.epoch,
            predict_with_generate=args.use_generate,
            report_to="none",
        )
    # evaluate at each epoch
    else:
        training_args = Seq2SeqTrainingArguments(
            save_dir,
            do_train=True if args.evaluate_dir is None else False,
            do_eval=True,
            evaluation_strategy="epoch",
            logging_strategy="steps",
            save_strategy="epoch",
            save_total_limit=2,
            learning_rate=args.lr,
            eval_accumulation_steps=args.eval_acc,
            per_device_train_batch_size=args.bs,
            per_device_eval_batch_size=args.eval_bs,
            weight_decay=0.01,
            num_train_epochs=args.epoch,
            metric_for_best_model="accuracy" if args.prompt_format != "QCM-LE" else "rougeL",
            predict_with_generate=args.use_generate,
            load_best_model_at_end=True,
            report_to="none",
        )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=eval_set,
        data_collator=datacollator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_acc if args.prompt_format != "QCM-LE" else compute_metrics_rougel,
    )

    if args.evaluate_dir is None:
        trainer.train()
        trainer.save_model(save_dir)

    metrics = trainer.evaluate(eval_dataset=test_set)
    trainer.log_metrics("test", metrics)
    trainer.save_metrics("test", metrics)

    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
    if trainer.is_world_process_zero():
        if args.use_generate:
            preds, targets = predict_results.predictions, predict_results.label_ids
        else:
            preds = predict_results.predictions[0]
            targets = predict_results.label_ids
            preds = preds.argmax(axis=2)

        preds = tokenizer.batch_decode(
            preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        targets = tokenizer.batch_decode(
            targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        results_ans = {}
        results_rationale = {}
        results_reference = {}

        num_fail = 0
        for idx, qid in enumerate(test_qids):
            pred = preds[int(idx)]
            ref = targets[int(idx)]
            extract_pred = extract_ans(pred)
            if extract_pred != "FAILED":
                if extract_pred in args.options:
                    extract_pred = args.options.index(extract_pred)
                else:
                    extract_pred = random.choice(range(0, len(args.options)))
            else:
                num_fail += 1
                extract_pred = random.choice(range(len(args.options)))  # randomly choose one option
            results_ans[str(qid)] = extract_pred
            results_rationale[str(qid)] = pred
            results_reference[str(qid)] = ref

        scores = get_scores(results_ans, results_rationale, results_reference, os.path.join(args.data_root, "scienceqa/problems.json"))
        preds = [pred.strip() for pred in preds]
        output_data = {
            "num_fail": num_fail,
            "scores": scores,
            "preds": preds,
            "labels": targets,
        }
        output_prediction_file = os.path.join(save_dir, "predictions_ans_test.json")
        with open(output_prediction_file, "w") as writer:
            writer.write(json.dumps(output_data, indent=4))

    # generate the rationale for the eval set
    if args.prompt_format == "QCM-LE":
        torch.cuda.empty_cache()
        del predict_results, preds, targets

        predict_results = trainer.predict(test_dataset=eval_set, max_length=args.output_len)
        if trainer.is_world_process_zero():
            if args.use_generate:
                preds, targets = predict_results.predictions, predict_results.label_ids
            else:
                preds = predict_results.predictions[0]
                targets = predict_results.label_ids
                preds = preds.argmax(axis=2)

            preds = tokenizer.batch_decode(
                preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            targets = tokenizer.batch_decode(
                targets, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            preds = [pred.strip() for pred in preds]
            output_data = {"preds": preds, "labels": targets}
            output_prediction_file = os.path.join(save_dir, "predictions_ans_eval.json")
            with open(output_prediction_file, "w") as writer:
                writer.write(json.dumps(output_data, indent=4))


if __name__ == '__main__':

    # training logger to log training progress
    training_logger = Table(
        Column("Epoch", justify="center"),
        Column("Steps", justify="center"),
        Column("Loss", justify="center"),
        title="Training Status",
        pad_edge=False,
        box=box.ASCII,
    )

    args = parse_args()
    print("args", args)
    print('====Input Arguments====')
    print(json.dumps(vars(args), indent=2, sort_keys=False))

    random.seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    if args.img_type is not None:
        # problems, question ids for the train/val/test splits, image-feature lookup tables
        problems, qids, name_maps, image_features = load_data_img(args)
        dataframe = {'problems': problems, 'qids': qids, 'name_maps': name_maps, 'image_features': image_features}
    else:
        # problems and question ids for the train/val/test splits
        problems, qids = load_data_std(args)
        dataframe = {'problems': problems, 'qids': qids}

    T5Trainer(
        dataframe=dataframe,
        args=args,
    )
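
# The JSON files written by T5Trainer can be inspected afterwards, e.g. (illustrative,
# the path below is an assumption):
#
#   import json
#   with open("experiments/<run_dir>/predictions_ans_test.json") as f:
#       out = json.load(f)
#   print(out["scores"], out["num_fail"])
#
# Intended two-stage usage (an assumption inferred from --eval_le/--test_le and the
# QCM-LE branch above): a QCM-LE run writes rationales to predictions_ans_eval.json and
# predictions_ans_test.json; a subsequent QCMG-A run consumes them via --eval_le/--test_le
# to perform answer inference.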