import os
import sys
import json
import torch
import argparse
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer

sys.path.insert(0, '..')
from utils.text_utils import detokenize_sent
from utils.spacy_tokenizer import spacy_tokenize_gec, spacy_tokenize_bea19

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path')
parser.add_argument('-i', '--input_path')
parser.add_argument('-o', '--output_path')
parser.add_argument('--bea19', action='store_true')
args = parser.parse_args()

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# force_bos_token_to_be_generated is a legacy flag from older transformers releases.
model = BartForConditionalGeneration.from_pretrained(args.model_path, force_bos_token_to_be_generated=True)
model.eval()
model.cuda()


def run_model(sents):
    """Run beam search on a batch of sentences and return an n-best list for each one."""
    num_ret_seqs = 10
    inp_max_len = 66
    batch = [tokenizer(s, return_tensors='pt', padding='max_length', max_length=inp_max_len) for s in sents]

    # Drop sentences whose tokenization exceeds the input length limit; they are passed
    # through unchanged later. oidx2bidx maps original index -> index in the filtered batch.
    oidx2bidx = {}
    final_batch = []
    for oidx, elm in enumerate(batch):
        if elm['input_ids'].size(1) <= inp_max_len:
            oidx2bidx[oidx] = len(final_batch)
            final_batch.append(elm)
    if not final_batch:
        # Every sentence was too long; fall back to returning each source sentence as-is.
        return [[s] for s in sents]
    batch = {key: torch.cat([elm[key] for elm in final_batch], dim=0) for key in final_batch[0]}

    with torch.no_grad():
        generated_ids = model.generate(batch['input_ids'].cuda(),
                                       attention_mask=batch['attention_mask'].cuda(),
                                       num_beams=10,
                                       num_return_sequences=num_ret_seqs,
                                       max_length=65)
    _out = tokenizer.batch_decode(generated_ids.detach().cpu(), skip_special_tokens=True)

    # Group the flat decoder output into one n-best list per input sentence.
    outs = []
    for i in range(0, len(_out), num_ret_seqs):
        outs.append(_out[i:i + num_ret_seqs])

    # Sentences filtered out above fall back to their source string.
    final_outs = [[sents[oidx]] if oidx not in oidx2bidx else outs[oidx2bidx[oidx]]
                  for oidx in range(len(sents))]
    return final_outs


def run_for_wiki_yahoo_conll():
    sents = [detokenize_sent(l.strip()) for l in open(args.input_path)]
    b_size = 40
    outs = []
    for j in tqdm(range(0, len(sents), b_size)):
        sents_batch = sents[j:j + b_size]
        outs_batch = run_model(sents_batch)
        for sent, preds in zip(sents_batch, outs_batch):
            preds_detoked = [detokenize_sent(pred) for pred in preds]
            preds = [' '.join(spacy_tokenize_gec(pred)) for pred in preds_detoked]
            outs.append({'src': sent, 'preds': preds})

    os.system('mkdir -p {}'.format(os.path.dirname(args.output_path)))
    with open(args.output_path, 'w') as outf:
        for out in outs:
            # Write only the top-ranked hypothesis for each source sentence.
            print(out['preds'][0], file=outf)


def run_for_bea19():
    sents = [detokenize_sent(l.strip()) for l in open(args.input_path)]
    b_size = 40
    outs = []
    for j in tqdm(range(0, len(sents), b_size)):
        sents_batch = sents[j:j + b_size]
        outs_batch = run_model(sents_batch)
        for sent, preds in zip(sents_batch, outs_batch):
            preds_detoked = [detokenize_sent(pred) for pred in preds]
            # BEA-19 uses its own spaCy tokenization; everything else mirrors run_for_wiki_yahoo_conll.
            preds = [' '.join(spacy_tokenize_bea19(pred)) for pred in preds_detoked]
            outs.append({'src': sent, 'preds': preds})

    os.system('mkdir -p {}'.format(os.path.dirname(args.output_path)))
    with open(args.output_path, 'w') as outf:
        for out in outs:
            print(out['preds'][0], file=outf)


if args.bea19:
    run_for_bea19()
else:
    run_for_wiki_yahoo_conll()
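
# Example invocation (a sketch; the script name and all paths are illustrative placeholders,
# not files shipped with the repo):
#   python generate.py -m <finetuned_bart_dir> -i <input_sentences.txt> -o <output_dir>/preds.txt
#   python generate.py -m <finetuned_bart_dir> -i <bea19_sentences.txt> -o <output_dir>/bea19_preds.txt --bea19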