sadanyh committed
Commit
67eacb2
•
1 Parent(s): 21d3796

Delete translate_mine.py with huggingface_hub

Files changed (1)
  1. translate_mine.py +0 -144
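
The commit message says the deletion was made with the `huggingface_hub` client. As a point of reference, a file deletion like this can be scripted through `HfApi.delete_file`; the sketch below is hypothetical (the repo id is a placeholder, and the token is assumed to come from a prior `huggingface-cli login`):

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the cached token from `huggingface-cli login`
    api.delete_file(
        path_in_repo="translate_mine.py",  # file to remove from the repo
        repo_id="user/repo",               # placeholder: the target repository id
        commit_message="Delete translate_mine.py with huggingface_hub",
    )

Each such call creates a single commit on the target repository, which matches the one-parent commit shown above.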
translate_mine.py DELETED
@@ -1,144 +0,0 @@
- # Copyright (c) 2019-present, Facebook, Inc.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- # Translate sentences from the input stream.
- # The model will be faster if sentences are sorted by length.
- # Input sentences must have the same tokenization and BPE codes as the ones used in the model.
- #
- # Usage:
- #     cat source_sentences.bpe | \
- #     python translate_mine.py --exp_name translate \
- #     --src_lang en --tgt_lang fr \
- #     --model_path trained_model.pth --output_path output
- #
-
- import os
- import io
- import sys
- import argparse
- import torch
-
- from src.utils import AttrDict
- from src.utils import bool_flag, initialize_exp
- from src.data.dictionary import Dictionary
- from src.model.transformer import TransformerModel
-
-
- def get_parser():
-     """
-     Generate a parameters parser.
-     """
-     # parse parameters
-     parser = argparse.ArgumentParser(description="Translate sentences")
-
-     # main parameters
-     parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path")
-     parser.add_argument("--exp_name", type=str, default="", help="Experiment name")
-     parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
-     parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch")
-
-     # model / output paths
-     parser.add_argument("--model_path", type=str, default="", help="Model path")
-     parser.add_argument("--output_path", type=str, default="", help="Output path")
-
-     # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)")
-     # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count")
-
-     # source language / target language
-     parser.add_argument("--src_lang", type=str, default="", help="Source language")
-     parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
-
-     return parser
-
-
- def main(params):
-     params.device = torch.device('cuda')
-     params.eval_only = True
-     params.log_file_prefix = False
-
-     # initialize the experiment
-     logger = initialize_exp(params)
-
-     # generate parser / parse parameters
-     parser = get_parser()
-     params = parser.parse_args()
-     reloaded = torch.load(params.model_path)
-     model_params = AttrDict(reloaded['params'])
-     logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))
-
-     # update dictionary parameters
-     for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
-         setattr(params, name, getattr(model_params, name))
-
-     # build dictionary / build encoder / build decoder / reload weights
-     dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
-     encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
-     decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
-     encoder.load_state_dict(reloaded['encoder'])
-     decoder.load_state_dict(reloaded['decoder'])
-     params.src_id = model_params.lang2id[params.src_lang]
-     params.tgt_id = model_params.lang2id[params.tgt_lang]
-
-     # read sentences from stdin
-     src_sent = []
-     for line in sys.stdin.readlines():
-         assert len(line.strip().split()) > 0
-         src_sent.append(line)
-     logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))
-
-     f = io.open(params.output_path, 'w', encoding='utf-8')
-
-     for i in range(0, len(src_sent), params.batch_size):
-
-         # prepare batch
-         word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()])
-                     for s in src_sent[i:i + params.batch_size]]
-         lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
-         batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
-         batch[0] = params.eos_index
-         for j, s in enumerate(word_ids):
-             if lengths[j] > 2:  # if sentence not empty
-                 batch[1:lengths[j] - 1, j].copy_(s)
-             batch[lengths[j] - 1, j] = params.eos_index
-         langs = batch.clone().fill_(params.src_id)
-
-         # encode source batch and translate it
-         encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
-         encoded = encoded.transpose(0, 1)
-         decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10))
-
-         # convert sentences to words
-         for j in range(decoded.size(1)):
-
-             # remove delimiters
-             sent = decoded[:, j]
-             delimiters = (sent == params.eos_index).nonzero().view(-1)
-             assert len(delimiters) >= 1 and delimiters[0].item() == 0
-             sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
-
-             # output translation
-             source = src_sent[i + j].strip()
-             target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
-             sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
-             f.write(target + "\n")
-
-     f.close()
-
-
- if __name__ == '__main__':
-
-     # generate parser / parse parameters
-     parser = get_parser()
-     params = parser.parse_args()
-
-     # check parameters
-     assert os.path.isfile(params.model_path)
-     assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
-     assert params.output_path and not os.path.isfile(params.output_path)
-
-     # translate
-     with torch.no_grad():
-         main(params)
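
For anyone skimming the deleted script, the batch it prepares is column-major: each column holds one sentence, wrapped in `eos` tokens at both ends and padded with `pad` up to the longest sentence in the batch; `langs` is the same shape filled with the source language id. A self-contained sketch of that layout with made-up token ids (the special-token indices below are placeholders, not the model's real values):

    import torch

    pad_index, eos_index = 2, 1               # placeholder special-token ids
    word_ids = [torch.LongTensor([5, 6, 7]),  # toy ids for a 3-token sentence
                torch.LongTensor([8, 9])]     # toy ids for a 2-token sentence

    lengths = torch.LongTensor([len(s) + 2 for s in word_ids])  # +2 for the two <eos>
    batch = torch.full((lengths.max().item(), len(word_ids)), pad_index, dtype=torch.long)
    batch[0] = eos_index                      # every column starts with <eos>
    for j, s in enumerate(word_ids):
        n = lengths[j].item()
        batch[1:n - 1, j].copy_(s)            # tokens fill the middle of the column
        batch[n - 1, j] = eos_index           # closing <eos>

    print(batch.t())
    # tensor([[1, 5, 6, 7, 1],
    #         [1, 8, 9, 1, 2]])  # shorter sentence is right-padded with pad_index

The same convention explains the output handling: the script asserts that the first decoded `eos` delimiter sits at position 0 and cuts each output column at the second delimiter.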