# BART inference script for grammatical error correction (GEC).
# (Non-code scrape artifacts removed from the top of this file.)
import os
import sys
import json
import torch
import argparse
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer
sys.path.insert(0, '..')
from utils.text_utils import detokenize_sent
from utils.spacy_tokenizer import spacy_tokenize_gec, spacy_tokenize_bea19
# Command-line interface: model checkpoint, input/output paths, and a flag
# selecting the BEA-19 tokenization pipeline instead of the default one.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path')
parser.add_argument('-i', '--input_path')
parser.add_argument('-o', '--output_path')
parser.add_argument('--bea19', action='store_true')
args = parser.parse_args()
# Tokenizer comes from the public base checkpoint; weights come from the
# fine-tuned checkpoint at --model_path.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# force_bos_token_to_be_generated pins BOS during generation (BART-specific).
# NOTE(review): this kwarg was removed in newer transformers releases in favor
# of forced_bos_token_id — confirm against the pinned transformers version.
model = BartForConditionalGeneration.from_pretrained(args.model_path, force_bos_token_to_be_generated=True)
model.eval()
model.cuda()
def run_model(sents):
    """Run beam-search generation over a batch of sentences.

    Sentences whose tokenization exceeds ``inp_max_len`` are skipped and
    echoed back unchanged (as a single-element candidate list); all others
    get ``num_ret_seqs`` beam candidates.

    Args:
        sents: list of str, detokenized input sentences.

    Returns:
        list (same length/order as ``sents``) of lists of candidate strings.
    """
    num_ret_seqs = 10
    inp_max_len = 66
    # Pad each sentence individually to a fixed length so they can be
    # concatenated into one batch tensor below.
    encoded = [tokenizer(s, return_tensors='pt', padding='max_length', max_length=inp_max_len) for s in sents]
    oidx2bidx = {}  # original index -> index within the filtered batch
    final_batch = []
    for oidx, elm in enumerate(encoded):
        # Drop inputs that tokenized longer than the padded length.
        if elm['input_ids'].size(1) <= inp_max_len:
            oidx2bidx[oidx] = len(final_batch)
            final_batch.append(elm)
    # Edge case: every sentence was over-length — nothing to generate for.
    # Without this guard, final_batch[0] below raises IndexError.
    if not final_batch:
        return [[s] for s in sents]
    batch = {key: torch.cat([elm[key] for elm in final_batch], dim=0) for key in final_batch[0]}
    with torch.no_grad():
        generated_ids = model.generate(batch['input_ids'].cuda(),
            attention_mask=batch['attention_mask'].cuda(),
            num_beams=10, num_return_sequences=num_ret_seqs, max_length=65)
    _out = tokenizer.batch_decode(generated_ids.detach().cpu(), skip_special_tokens=True)
    # batch_decode returns a flat list; regroup into per-sentence chunks.
    outs = []
    for i in range(0, len(_out), num_ret_seqs):
        outs.append(_out[i:i+num_ret_seqs])
    # Re-align with the original order; skipped sentences fall back to
    # themselves as the sole "prediction".
    final_outs = [[sents[oidx]] if oidx not in oidx2bidx else outs[oidx2bidx[oidx]] for oidx in range(len(sents))]
    return final_outs
def run_for_wiki_yahoo_conll():
    """Correct args.input_path line-by-line and write the top prediction
    per line to args.output_path (GEC spaCy tokenization)."""
    # 'with' ensures the input handle is closed (original leaked it).
    with open(args.input_path) as inf:
        sents = [detokenize_sent(l.strip()) for l in inf]
    b_size = 40
    outs = []
    for j in tqdm(range(0, len(sents), b_size)):
        sents_batch = sents[j:j+b_size]
        outs_batch = run_model(sents_batch)
        for sent, preds in zip(sents_batch, outs_batch):
            # Candidates come back detokenized, then re-tokenized with the
            # GEC spaCy tokenizer so output matches the eval format.
            preds_detoked = [detokenize_sent(pred) for pred in preds]
            preds = [' '.join(spacy_tokenize_gec(pred)) for pred in preds_detoked]
            outs.append({'src': sent, 'preds': preds})
    # os.makedirs replaces the shell 'mkdir -p' (no subshell, no injection
    # risk); guard for a bare filename whose dirname is ''.
    out_dir = os.path.dirname(args.output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output_path, 'w') as outf:
        for out in outs:
            # Only the best (first) beam candidate is written.
            print(out['preds'][0], file=outf)
def run_for_bea19():
    """Correct args.input_path line-by-line and write the top prediction
    per line to args.output_path (BEA-19 spaCy tokenization)."""
    # 'with' ensures the input handle is closed (original leaked it).
    with open(args.input_path) as inf:
        sents = [detokenize_sent(l.strip()) for l in inf]
    b_size = 40
    outs = []
    for j in tqdm(range(0, len(sents), b_size)):
        sents_batch = sents[j:j+b_size]
        outs_batch = run_model(sents_batch)
        for sent, preds in zip(sents_batch, outs_batch):
            # Candidates come back detokenized, then re-tokenized with the
            # BEA-19 spaCy tokenizer so output matches the eval format.
            preds_detoked = [detokenize_sent(pred) for pred in preds]
            preds = [' '.join(spacy_tokenize_bea19(pred)) for pred in preds_detoked]
            outs.append({'src': sent, 'preds': preds})
    # os.makedirs replaces the shell 'mkdir -p' (no subshell, no injection
    # risk); guard for a bare filename whose dirname is ''.
    out_dir = os.path.dirname(args.output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output_path, 'w') as outf:
        for out in outs:
            # Only the best (first) beam candidate is written.
            print(out['preds'][0], file=outf)
# Entry point: choose the output tokenization pipeline by dataset flag.
if args.bea19:
    run_for_bea19()
else:
    run_for_wiki_yahoo_conll()