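"""Decode sentences with a fine-tuned BART grammatical error correction model.

Reads one sentence per line from --input_path, generates candidate corrections
with beam search, and writes the top candidate per sentence to --output_path.
The --bea19 flag switches the output tokenization to the BEA-2019 convention.
"""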
import os
import sys
import torch
import argparse
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer

sys.path.insert(0, '..')
from utils.text_utils import detokenize_sent
from utils.spacy_tokenizer import spacy_tokenize_gec, spacy_tokenize_bea19

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path')
parser.add_argument('-i', '--input_path')
parser.add_argument('-o', '--output_path')
parser.add_argument('--bea19', action='store_true')
args = parser.parse_args()
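
# Example invocation (script name and paths are illustrative):
#   python bart_decode.py -m /path/to/finetuned-bart -i test.src -o outputs/test.pred
#   python bart_decode.py -m /path/to/finetuned-bart -i bea19.src -o outputs/bea19.pred --bea19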


tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# Note: force_bos_token_to_be_generated is only recognized by older
# transformers releases; recent versions use forced_bos_token_id instead.
model = BartForConditionalGeneration.from_pretrained(args.model_path, force_bos_token_to_be_generated=True)
model.eval()
model.cuda()  # inference assumes a single CUDA-capable GPU

def run_model(sents):
    num_ret_seqs = 10
    inp_max_len = 66
    # Tokenize each sentence separately so over-length inputs can be filtered out.
    batch = [tokenizer(s, return_tensors='pt', padding='max_length', max_length=inp_max_len) for s in sents]
    oidx2bidx = {}  # original index -> index in the filtered batch
    final_batch = []
    for oidx, elm in enumerate(batch):
        # With padding='max_length' and no truncation, input_ids only exceed
        # inp_max_len when the sentence itself is too long; skip those.
        if elm['input_ids'].size(1) <= inp_max_len:
            oidx2bidx[oidx] = len(final_batch)
            final_batch.append(elm)
    if not final_batch:
        # Every sentence was too long: fall back to returning the sources unchanged.
        return [[s] for s in sents]
    batch = {key: torch.cat([elm[key] for elm in final_batch], dim=0) for key in final_batch[0]}
    with torch.no_grad():
        generated_ids = model.generate(batch['input_ids'].cuda(),
                                       attention_mask=batch['attention_mask'].cuda(),
                                       num_beams=10, num_return_sequences=num_ret_seqs, max_length=65)
    _out = tokenizer.batch_decode(generated_ids.detach().cpu(), skip_special_tokens=True)
    # Group the flat list of decoded sequences into num_ret_seqs candidates per input.
    outs = []
    for i in range(0, len(_out), num_ret_seqs):
        outs.append(_out[i:i+num_ret_seqs])
    # Sentences filtered out above fall back to their original text.
    final_outs = [[sents[oidx]] if oidx not in oidx2bidx else outs[oidx2bidx[oidx]] for oidx in range(len(sents))]
    return final_outs


def run_inference(spacy_tokenize_fn):
    # Shared inference loop for all test sets; the caller picks the spaCy
    # tokenizer (CoNLL/Wiki/Yahoo vs. BEA-2019) used to format the outputs.
    with open(args.input_path) as inf:
        sents = [detokenize_sent(l.strip()) for l in inf]
    b_size = 40
    outs = []
    for j in tqdm(range(0, len(sents), b_size)):
        sents_batch = sents[j:j+b_size]
        outs_batch = run_model(sents_batch)
        for sent, preds in zip(sents_batch, outs_batch):
            # Detokenize the model output, then re-tokenize with the chosen
            # spaCy tokenizer so the predictions match the reference format.
            preds_detoked = [detokenize_sent(pred) for pred in preds]
            preds = [' '.join(spacy_tokenize_fn(pred)) for pred in preds_detoked]
            outs.append({'src': sent, 'preds': preds})
    os.makedirs(os.path.dirname(args.output_path) or '.', exist_ok=True)
    with open(args.output_path, 'w') as outf:
        for out in outs:
            # Write only the top-ranked candidate for each source sentence.
            print(out['preds'][0], file=outf)


if args.bea19:
    run_inference(spacy_tokenize_bea19)  # BEA-2019 test set tokenization
else:
    run_inference(spacy_tokenize_gec)    # CoNLL / Wiki / Yahoo tokenization