"""
Translate sentences read from the standard input with a pre-trained
encoder-decoder translation model, and write the translations (one per line)
to --output_path. Input sentences are expected to be pre-tokenized
consistently with the dictionary of the reloaded model.
"""

import os
import io
import sys
import argparse

import torch

from src.utils import AttrDict
from src.utils import bool_flag, initialize_exp
from src.data.dictionary import Dictionary
from src.model.transformer import TransformerModel


def get_parser():
    """
    Generate a parameters parser.
    """
    parser = argparse.ArgumentParser(description="Translate sentences")

    # main parameters
    parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
    parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
    parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch")

    # model / output paths
    parser.add_argument("--model_path", type=str, default="", help="Model path")
    parser.add_argument("--output_path", type=str, default="", help="Output path")

    # source / target languages
    parser.add_argument("--src_lang", type=str, default="", help="Source language")
    parser.add_argument("--tgt_lang", type=str, default="", help="Target language")

    return parser


def main(params):

    # run on GPU, in evaluation-only mode
    params.device = torch.device('cuda')
    params.eval_only = True
    params.log_file_prefix = False

    # initialize the experiment / logger
    logger = initialize_exp(params)

    # re-parse the command-line arguments and reload the model checkpoint
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # copy the dictionary parameters from the trained model
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build the dictionary, rebuild the encoder / decoder, and reload their weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    # open the output file
    f = io.open(params.output_path, 'w', encoding='utf-8')

    # translate one batch at a time
    for i in range(0, len(src_sent), params.batch_size):

        # prepare the batch: a (max_len x batch_size) tensor of word indices,
        # padded with pad_index, with an EOS delimiter at both ends of each sentence
        word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
                    for s in src_sent[i:i + params.batch_size]]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # copy the sentence if it is not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)  # per-token language IDs (source language)

        # encode the source batch, then generate translations in the target language
        encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
        encoded = encoded.transpose(0, 1)
        decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))

        # convert generated indices back to words
        for j in range(decoded.size(1)):

            # remove the EOS delimiters: the first token is always EOS, and
            # generation may produce a second EOS marking the end of the sentence
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # export the translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()


if __name__ == '__main__':

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()

    # check parameters
    assert os.path.isfile(params.model_path)
    assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
    assert params.output_path and not os.path.isfile(params.output_path)

    # translate (no gradients are needed at inference time)
    with torch.no_grad():
        main(params)
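

# Example invocation (a sketch: the script name, language codes, and file paths
# below are illustrative, not taken from this repository):
#
#   cat sentences.src.bpe | python translate.py \
#       --exp_name translate \
#       --src_lang en --tgt_lang fr \
#       --model_path ./dumped/model.pth --output_path ./output.fr
#
# Sentences are read from stdin (one tokenized sentence per line) and the
# translations are written, one per line, to --output_path.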