#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""

import argparse
import ast
import logging
import math
import os
import sys

import editdistance
import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def add_asr_eval_argument(parser):
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type (e.g. greedy)",
    )
    # --lm-weight may already be registered by the generation parser; ignore the
    # conflict if so.
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except argparse.ArgumentError:
        pass
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser


def check_args(args):
    # assert args.path is not None, "--path required for generation!"
    # assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task, models):
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)


def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                "{} ({}-{})".format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print(
                "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
            )

        if not args.quiet:
            logger.info("HYPO:" + hyp_words)
            logger.info("TARGET:" + tgt_words)
            logger.info("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        # note: returning inside the loop means only the top hypothesis
        # contributes to the reported error counts
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)


def prepare_result_files(args):
    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    if not args.results_path:
        return None

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


class ExistingEmissionsDecoder(object):
    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except Exception:
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes")
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)


def main(args, task=None, model_state=None):
    check_args(args)

    use_fp16 = args.fp16
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer


def make_parser():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    return parser


def cli_main():
    parser = make_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
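

# A minimal usage sketch for this script. The paths, manifest directory, and
# checkpoint/subset names below are illustrative assumptions, not files shipped
# with fairseq; the flags come from fairseq's generation options, the task
# arguments, and add_asr_eval_argument above. Viterbi decoding of a CTC-trained
# model over a pre-processed manifest might look like:
#
#   python examples/speech_recognition/infer.py /path/to/manifest \
#       --task audio_pretraining --labels ltr --criterion ctc \
#       --path /path/to/checkpoint.pt --gen-subset dev_other \
#       --results-path /path/to/results --w2l-decoder viterbi \
#       --max-tokens 4000000 --post-process letter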