import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser

from data_collator import T2TDataCollator

# torch.cuda.is_available is a function; call it, otherwise the truthy function
# object would always select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger = logging.getLogger(__name__)


@dataclass
class EvalArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"}
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    num_beams: Optional[int] = field(
        default=4,
        metadata={"help": "num_beams to use for decoding"}
    )
    max_decoding_length: Optional[int] = field(
        default=32,
        metadata={"help": "maximum length for decoding"}
    )
    output_path: Optional[str] = field(
        default="hypothesis.txt",
        metadata={"help": "path to save the generated questions."}
    )


def get_predictions(model, tokenizer, data_loader, num_beams=4, max_length=32, length_penalty=1):
    """Run beam-search generation over the data loader and decode the outputs."""
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            outs = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                num_beams=num_beams,
                max_length=max_length,
                length_penalty=length_penalty,
            )
            prediction = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            predictions.extend(prediction)

    return predictions


def main():
    # Configure logging so the final logger.info message is actually emitted.
    logging.basicConfig(level=logging.INFO)

    parser = HfArgumentParser((EvalArguments,))
    args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)

    # Load the cached (pre-tokenized) validation dataset and wrap it with an
    # inference-mode collator before batching.
    valid_dataset = torch.load(args.valid_file_path)
    collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=args.model_type,
        mode="inference"
    )
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, collate_fn=collator)

    predictions = get_predictions(
        model=model,
        tokenizer=tokenizer,
        data_loader=loader,
        num_beams=args.num_beams,
        max_length=args.max_decoding_length
    )

    # Write one generated question per line.
    with open(args.output_path, 'w') as f:
        f.write("\n".join(predictions))

    logger.info(f"Output saved at {args.output_path}")


if __name__ == "__main__":
    main()
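
# A minimal usage sketch, kept as a comment since this file is a standalone script.
# The model identifier, cached-dataset path, and script filename below are
# hypothetical placeholders, not values defined by this repository:
#
#   python eval.py \
#       --model_name_or_path t5-base \
#       --model_type t5 \
#       --valid_file_path data/valid_data.pt \
#       --num_beams 4 \
#       --max_decoding_length 32 \
#       --output_path hypothesis.txt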