GreenRaptor commited on
Commit
52fe7b2
1 Parent(s): 468ac2e

Create infer.py

Browse files
Files changed (1) hide show
  1. infer.py +436 -0
infer.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Run inference for pre-processed data with a trained model.
9
+ """
10
+
11
+ import ast
12
+ import logging
13
+ import math
14
+ import os
15
+ import sys
16
+
17
+ import editdistance
18
+ import numpy as np
19
+ import torch
20
+ from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
21
+ from fairseq.data.data_utils import post_process
22
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
23
+
24
+
25
+ logging.basicConfig()
26
+ logging.root.setLevel(logging.INFO)
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def add_asr_eval_argument(parser):
32
+ parser.add_argument("--kspmodel", default=None, help="sentence piece model")
33
+ parser.add_argument(
34
+ "--wfstlm", default=None, help="wfstlm on dictonary output units"
35
+ )
36
+ parser.add_argument(
37
+ "--rnnt_decoding_type",
38
+ default="greedy",
39
+ help="wfstlm on dictonary\
40
+ output units",
41
+ )
42
+ try:
43
+ parser.add_argument(
44
+ "--lm-weight",
45
+ "--lm_weight",
46
+ type=float,
47
+ default=0.2,
48
+ help="weight for lm while interpolating with neural score",
49
+ )
50
+ except:
51
+ pass
52
+ parser.add_argument(
53
+ "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
54
+ )
55
+ parser.add_argument(
56
+ "--w2l-decoder",
57
+ choices=["viterbi", "kenlm", "fairseqlm"],
58
+ help="use a w2l decoder",
59
+ )
60
+ parser.add_argument("--lexicon", help="lexicon for w2l decoder")
61
+ parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
62
+ parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
63
+ parser.add_argument("--beam-threshold", type=float, default=25.0)
64
+ parser.add_argument("--beam-size-token", type=float, default=100)
65
+ parser.add_argument("--word-score", type=float, default=1.0)
66
+ parser.add_argument("--unk-weight", type=float, default=-math.inf)
67
+ parser.add_argument("--sil-weight", type=float, default=0.0)
68
+ parser.add_argument(
69
+ "--dump-emissions",
70
+ type=str,
71
+ default=None,
72
+ help="if present, dumps emissions into this file and exits",
73
+ )
74
+ parser.add_argument(
75
+ "--dump-features",
76
+ type=str,
77
+ default=None,
78
+ help="if present, dumps features into this file and exits",
79
+ )
80
+ parser.add_argument(
81
+ "--load-emissions",
82
+ type=str,
83
+ default=None,
84
+ help="if present, loads emissions from this file",
85
+ )
86
+ return parser
87
+
88
+
89
+ def check_args(args):
90
+ # assert args.path is not None, "--path required for generation!"
91
+ # assert args.results_path is not None, "--results_path required for generation!"
92
+ assert (
93
+ not args.sampling or args.nbest == args.beam
94
+ ), "--sampling requires --nbest to be equal to --beam"
95
+ assert (
96
+ args.replace_unk is None or args.raw_text
97
+ ), "--replace-unk requires a raw text dataset (--raw-text)"
98
+
99
+
100
+ def get_dataset_itr(args, task, models):
101
+ return task.get_batch_iterator(
102
+ dataset=task.dataset(args.gen_subset),
103
+ max_tokens=args.max_tokens,
104
+ max_sentences=args.batch_size,
105
+ max_positions=(sys.maxsize, sys.maxsize),
106
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
107
+ required_batch_size_multiple=args.required_batch_size_multiple,
108
+ num_shards=args.num_shards,
109
+ shard_id=args.shard_id,
110
+ num_workers=args.num_workers,
111
+ data_buffer_size=args.data_buffer_size,
112
+ ).next_epoch_itr(shuffle=False)
113
+
114
+
115
+ def process_predictions(
116
+ args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
117
+ ):
118
+ for hypo in hypos[: min(len(hypos), args.nbest)]:
119
+ hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
120
+
121
+ if "words" in hypo:
122
+ hyp_words = " ".join(hypo["words"])
123
+ else:
124
+ hyp_words = post_process(hyp_pieces, args.post_process)
125
+
126
+ if res_files is not None:
127
+ print(
128
+ "{} ({}-{})".format(hyp_pieces, speaker, id),
129
+ file=res_files["hypo.units"],
130
+ )
131
+ print(
132
+ "{} ({}-{})".format(hyp_words, speaker, id),
133
+ file=res_files["hypo.words"],
134
+ )
135
+
136
+ tgt_pieces = tgt_dict.string(target_tokens)
137
+ tgt_words = post_process(tgt_pieces, args.post_process)
138
+
139
+ if res_files is not None:
140
+ print(
141
+ "{} ({}-{})".format(tgt_pieces, speaker, id),
142
+ file=res_files["ref.units"],
143
+ )
144
+ print(
145
+ "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
146
+ )
147
+
148
+ if not args.quiet:
149
+ logger.info("HYPO:" + hyp_words)
150
+ logger.info("TARGET:" + tgt_words)
151
+ logger.info("___________________")
152
+
153
+ hyp_words = hyp_words.split()
154
+ tgt_words = tgt_words.split()
155
+ return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
156
+
157
+
158
+ def prepare_result_files(args):
159
+ def get_res_file(file_prefix):
160
+ if args.num_shards > 1:
161
+ file_prefix = f"{args.shard_id}_{file_prefix}"
162
+ path = os.path.join(
163
+ args.results_path,
164
+ "{}-{}-{}.txt".format(
165
+ file_prefix, os.path.basename(args.path), args.gen_subset
166
+ ),
167
+ )
168
+ return open(path, "w", buffering=1)
169
+
170
+ if not args.results_path:
171
+ return None
172
+
173
+ return {
174
+ "hypo.words": get_res_file("hypo.word"),
175
+ "hypo.units": get_res_file("hypo.units"),
176
+ "ref.words": get_res_file("ref.word"),
177
+ "ref.units": get_res_file("ref.units"),
178
+ }
179
+
180
+
181
+ def optimize_models(args, use_cuda, models):
182
+ """Optimize ensemble for generation"""
183
+ for model in models:
184
+ model.make_generation_fast_(
185
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
186
+ need_attn=args.print_alignment,
187
+ )
188
+ if args.fp16:
189
+ model.half()
190
+ if use_cuda:
191
+ model.cuda()
192
+
193
+
194
+ def apply_half(t):
195
+ if t.dtype is torch.float32:
196
+ return t.to(dtype=torch.half)
197
+ return t
198
+
199
+
200
+ class ExistingEmissionsDecoder(object):
201
+ def __init__(self, decoder, emissions):
202
+ self.decoder = decoder
203
+ self.emissions = emissions
204
+
205
+ def generate(self, models, sample, **unused):
206
+ ids = sample["id"].cpu().numpy()
207
+ try:
208
+ emissions = np.stack(self.emissions[ids])
209
+ except:
210
+ print([x.shape for x in self.emissions[ids]])
211
+ raise Exception("invalid sizes")
212
+ emissions = torch.from_numpy(emissions)
213
+ return self.decoder.decode(emissions)
214
+
215
+
216
+ def main(args, task=None, model_state=None):
217
+ check_args(args)
218
+
219
+ use_fp16 = args.fp16
220
+ if args.max_tokens is None and args.batch_size is None:
221
+ args.max_tokens = 4000000
222
+ logger.info(args)
223
+
224
+ use_cuda = torch.cuda.is_available() and not args.cpu
225
+
226
+ logger.info("| decoding with criterion {}".format(args.criterion))
227
+
228
+ task = tasks.setup_task(args)
229
+
230
+ # Load ensemble
231
+ if args.load_emissions:
232
+ models, criterions = [], []
233
+ task.load_dataset(args.gen_subset)
234
+ else:
235
+ logger.info("| loading model(s) from {}".format(args.path))
236
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
237
+ utils.split_paths(args.path, separator="\\"),
238
+ arg_overrides=ast.literal_eval(args.model_overrides),
239
+ task=task,
240
+ suffix=args.checkpoint_suffix,
241
+ strict=(args.checkpoint_shard_count == 1),
242
+ num_shards=args.checkpoint_shard_count,
243
+ state=model_state,
244
+ )
245
+ optimize_models(args, use_cuda, models)
246
+ task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
247
+
248
+
249
+ # Set dictionary
250
+ tgt_dict = task.target_dictionary
251
+
252
+ logger.info(
253
+ "| {} {} {} examples".format(
254
+ args.data, args.gen_subset, len(task.dataset(args.gen_subset))
255
+ )
256
+ )
257
+
258
+ # hack to pass transitions to W2lDecoder
259
+ if args.criterion == "asg_loss":
260
+ raise NotImplementedError("asg_loss is currently not supported")
261
+ # trans = criterions[0].asg.trans.data
262
+ # args.asg_transitions = torch.flatten(trans).tolist()
263
+
264
+ # Load dataset (possibly sharded)
265
+ itr = get_dataset_itr(args, task, models)
266
+
267
+ # Initialize generator
268
+ gen_timer = StopwatchMeter()
269
+
270
+ def build_generator(args):
271
+ w2l_decoder = getattr(args, "w2l_decoder", None)
272
+ if w2l_decoder == "viterbi":
273
+ from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
274
+
275
+ return W2lViterbiDecoder(args, task.target_dictionary)
276
+ elif w2l_decoder == "kenlm":
277
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
278
+
279
+ return W2lKenLMDecoder(args, task.target_dictionary)
280
+ elif w2l_decoder == "fairseqlm":
281
+ from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
282
+
283
+ return W2lFairseqLMDecoder(args, task.target_dictionary)
284
+ else:
285
+ print(
286
+ "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
287
+ )
288
+
289
+ # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
290
+ generator = build_generator(args)
291
+
292
+ if args.load_emissions:
293
+ generator = ExistingEmissionsDecoder(
294
+ generator, np.load(args.load_emissions, allow_pickle=True)
295
+ )
296
+ logger.info("loaded emissions from " + args.load_emissions)
297
+
298
+ num_sentences = 0
299
+
300
+ if args.results_path is not None and not os.path.exists(args.results_path):
301
+ os.makedirs(args.results_path)
302
+
303
+ max_source_pos = (
304
+ utils.resolve_max_positions(
305
+ task.max_positions(), *[model.max_positions() for model in models]
306
+ ),
307
+ )
308
+
309
+ if max_source_pos is not None:
310
+ max_source_pos = max_source_pos[0]
311
+ if max_source_pos is not None:
312
+ max_source_pos = max_source_pos[0] - 1
313
+
314
+ if args.dump_emissions:
315
+ emissions = {}
316
+ if args.dump_features:
317
+ features = {}
318
+ models[0].bert.proj = None
319
+ else:
320
+ res_files = prepare_result_files(args)
321
+ errs_t = 0
322
+ lengths_t = 0
323
+ with progress_bar.build_progress_bar(args, itr) as t:
324
+ wps_meter = TimeMeter()
325
+ for sample in t:
326
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
327
+ if use_fp16:
328
+ sample = utils.apply_to_sample(apply_half, sample)
329
+ if "net_input" not in sample:
330
+ continue
331
+
332
+ prefix_tokens = None
333
+ if args.prefix_size > 0:
334
+ prefix_tokens = sample["target"][:, : args.prefix_size]
335
+
336
+ gen_timer.start()
337
+ if args.dump_emissions:
338
+ with torch.no_grad():
339
+ encoder_out = models[0](**sample["net_input"])
340
+ emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
341
+ emm = emm.transpose(0, 1).cpu().numpy()
342
+ for i, id in enumerate(sample["id"]):
343
+ emissions[id.item()] = emm[i]
344
+ continue
345
+ elif args.dump_features:
346
+ with torch.no_grad():
347
+ encoder_out = models[0](**sample["net_input"])
348
+ feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
349
+ for i, id in enumerate(sample["id"]):
350
+ padding = (
351
+ encoder_out["encoder_padding_mask"][i].cpu().numpy()
352
+ if encoder_out["encoder_padding_mask"] is not None
353
+ else None
354
+ )
355
+ features[id.item()] = (feat[i], padding)
356
+ continue
357
+ hypos = task.inference_step(generator, models, sample, prefix_tokens)
358
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
359
+ gen_timer.stop(num_generated_tokens)
360
+
361
+ for i, sample_id in enumerate(sample["id"].tolist()):
362
+ speaker = None
363
+ # id = task.dataset(args.gen_subset).ids[int(sample_id)]
364
+ id = sample_id
365
+ toks = (
366
+ sample["target"][i, :]
367
+ if "target_label" not in sample
368
+ else sample["target_label"][i, :]
369
+ )
370
+ target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
371
+ # Process top predictions
372
+ errs, length = process_predictions(
373
+ args,
374
+ hypos[i],
375
+ None,
376
+ tgt_dict,
377
+ target_tokens,
378
+ res_files,
379
+ speaker,
380
+ id,
381
+ )
382
+ errs_t += errs
383
+ lengths_t += length
384
+
385
+ wps_meter.update(num_generated_tokens)
386
+ t.log({"wps": round(wps_meter.avg)})
387
+ num_sentences += (
388
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
389
+ )
390
+
391
+ wer = None
392
+ if args.dump_emissions:
393
+ emm_arr = []
394
+ for i in range(len(emissions)):
395
+ emm_arr.append(emissions[i])
396
+ np.save(args.dump_emissions, emm_arr)
397
+ logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
398
+ elif args.dump_features:
399
+ feat_arr = []
400
+ for i in range(len(features)):
401
+ feat_arr.append(features[i])
402
+ np.save(args.dump_features, feat_arr)
403
+ logger.info(f"saved {len(features)} emissions to {args.dump_features}")
404
+ else:
405
+ if lengths_t > 0:
406
+ wer = errs_t * 100.0 / lengths_t
407
+ logger.info(f"WER: {wer}")
408
+
409
+ logger.info(
410
+ "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
411
+ "sentences/s, {:.2f} tokens/s)".format(
412
+ num_sentences,
413
+ gen_timer.n,
414
+ gen_timer.sum,
415
+ num_sentences / gen_timer.sum,
416
+ 1.0 / gen_timer.avg,
417
+ )
418
+ )
419
+ logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
420
+ return task, wer
421
+
422
+
423
+ def make_parser():
424
+ parser = options.get_generation_parser()
425
+ parser = add_asr_eval_argument(parser)
426
+ return parser
427
+
428
+
429
+ def cli_main():
430
+ parser = make_parser()
431
+ args = options.parse_args_and_arch(parser)
432
+ main(args)
433
+
434
+
435
+ if __name__ == "__main__":
436
+ cli_main()