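"""Decode predictions from a fine-tuned seq2seq model (T5 or BART) over a cached
validation dataset and write the generated questions to a text file, one per line."""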
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser
from data_collator import T2TDataCollator

# Run generation on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger = logging.getLogger(__name__)


@dataclass
class EvalArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"}
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    num_beams: Optional[int] = field(
        default=4,
        metadata={"help": "num_beams to use for decoding"}
    )
    max_decoding_length: Optional[int] = field(
        default=32,
        metadata={"help": "maximum length for decoding"}
    )
    output_path: Optional[str] = field(
        default="hypothesis.txt",
        metadata={"help": "path to save the generated questions."}
    )


def get_predictions(model, tokenizer, data_loader, num_beams=4, max_length=32, length_penalty=1):
    """Run beam-search generation over `data_loader` and return the decoded output strings."""
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            outs = model.generate(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                num_beams=num_beams,
                max_length=max_length,
                length_penalty=length_penalty,
            )
            # Decode each sequence in the batch, dropping special tokens.
            prediction = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            predictions.extend(prediction)
    return predictions


def main():
    parser = HfArgumentParser((EvalArguments,))
    args = parser.parse_args_into_dataclasses()[0]

    # Make INFO-level messages from this script visible.
    logging.basicConfig(level=logging.INFO)

    # Fall back to the model path when no separate tokenizer path is given.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)

    # The validation set is expected to be a dataset cached with torch.save.
    valid_dataset = torch.load(args.valid_file_path)
    collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=args.model_type,
        mode="inference"
    )
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, collate_fn=collator)

    predictions = get_predictions(
        model=model,
        tokenizer=tokenizer,
        data_loader=loader,
        num_beams=args.num_beams,
        max_length=args.max_decoding_length
    )

    # Write one generated sequence per line.
    with open(args.output_path, 'w') as f:
        f.write("\n".join(predictions))

    logger.info(f"Output saved at {args.output_path}")


if __name__ == "__main__":
    main()
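

# Example invocation (script name and paths below are illustrative placeholders;
# the flag names and defaults come from EvalArguments above):
#   python evaluate.py \
#       --model_name_or_path path/to/finetuned-model \
#       --valid_file_path data/valid_data.pt \
#       --model_type t5 \
#       --num_beams 4 \
#       --max_decoding_length 32 \
#       --output_path hypothesis.txt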