sadanyh commited on
Commit
47662c2
โ€ข
1 Parent(s): 9118d6a

create translate.py

Browse files
Files changed (1) hide show
  1. translate.py +144 -0
translate.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Translate sentences from the input stream.
8
+ # The model will be faster is sentences are sorted by length.
9
+ # Input sentences must have the same tokenization and BPE codes than the ones used in the model.
10
+ #
11
+ # Usage:
12
+ # cat source_sentences.bpe | \
13
+ # python translate.py --exp_name translate \
14
+ # --src_lang en --tgt_lang fr \
15
+ # --model_path trained_model.pth --output_path output
16
+ #
17
+
18
+ import os
19
+ import io
20
+ import sys
21
+ import argparse
22
+ import torch
23
+
24
+ from src.utils import AttrDict
25
+ from src.utils import bool_flag, initialize_exp
26
+ from src.data.dictionary import Dictionary
27
+ from src.model.transformer import TransformerModel
28
+
29
+
30
+ def get_parser():
31
+ """
32
+ Generate a parameters parser.
33
+ """
34
+ # parse parameters
35
+ parser = argparse.ArgumentParser(description="Translate sentences")
36
+
37
+ # main parameters
38
+ parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
39
+ parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
40
+ parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
41
+ parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch")
42
+
43
+ # model / output paths
44
+ parser.add_argument("--model_path", type=str, default="", help="Model path")
45
+ parser.add_argument("--output_path", type=str, default="", help="Output path")
46
+
47
+ # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
48
+ # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")
49
+
50
+ # source language / target language
51
+ parser.add_argument("--src_lang", type=str, default="", help="Source language")
52
+ parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
53
+
54
+ return parser
55
+
56
+
57
+ def main(params):
58
+ params.device = torch.device('cuda')
59
+ params.eval_only = True
60
+ params.log_file_prefix = False
61
+
62
+ # initialize the experiment
63
+ logger = initialize_exp(params)
64
+
65
+ # generate parser / parse parameters
66
+ parser = get_parser()
67
+ params = parser.parse_args()
68
+ reloaded = torch.load(params.model_path)
69
+ model_params = AttrDict(reloaded['params'])
70
+ logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))
71
+
72
+ # update dictionary parameters
73
+ for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
74
+ setattr(params, name, getattr(model_params, name))
75
+
76
+ # build dictionary / build encoder / build decoder / reload weights
77
+ dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
78
+ encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
79
+ decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
80
+ encoder.load_state_dict(reloaded['encoder'])
81
+ decoder.load_state_dict(reloaded['decoder'])
82
+ params.src_id = model_params.lang2id[params.src_lang]
83
+ params.tgt_id = model_params.lang2id[params.tgt_lang]
84
+
85
+ # read sentences from stdin
86
+ src_sent = []
87
+ for line in sys.stdin.readlines():
88
+ assert len(line.strip().split()) > 0
89
+ src_sent.append(line)
90
+ logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))
91
+
92
+ f = io.open(params.output_path, 'w', encoding='utf-8')
93
+
94
+ for i in range(0, len(src_sent), params.batch_size):
95
+
96
+ # prepare batch
97
+ word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
98
+ for s in src_sent[i:i + params.batch_size]]
99
+ lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
100
+ batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
101
+ batch[0] = params.eos_index
102
+ for j, s in enumerate(word_ids):
103
+ if lengths[j] > 2: # if sentence not empty
104
+ batch[1:lengths[j] - 1, j].copy_(s)
105
+ batch[lengths[j] - 1, j] = params.eos_index
106
+ langs = batch.clone().fill_(params.src_id)
107
+
108
+ # encode source batch and translate it
109
+ encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
110
+ encoded = encoded.transpose(0, 1)
111
+ decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
112
+
113
+ # convert sentences to words
114
+ for j in range(decoded.size(1)):
115
+
116
+ # remove delimiters
117
+ sent = decoded[:, j]
118
+ delimiters = (sent == params.eos_index).nonzero().view(-1)
119
+ assert len(delimiters) >= 1 and delimiters[0].item() == 0
120
+ sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
121
+
122
+ # output translation
123
+ source = src_sent[i + j].strip()
124
+ target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
125
+ sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
126
+ f.write(target + "\n")
127
+
128
+ f.close()
129
+
130
+
131
+ if __name__ == '__main__':
132
+
133
+ # generate parser / parse parameters
134
+ parser = get_parser()
135
+ params = parser.parse_args()
136
+
137
+ # check parameters
138
+ assert os.path.isfile(params.model_path)
139
+ assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
140
+ assert params.output_path and not os.path.isfile(params.output_path)
141
+
142
+ # translate
143
+ with torch.no_grad():
144
+ main(params)