#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import argparse
import ast
import logging
import math
import os
import sys

import editdistance
import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def add_asr_eval_argument(parser):
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type (e.g. greedy)",
    )
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except argparse.ArgumentError:
        # --lm-weight may already be registered by the generation parser
        pass
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser
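

# Typical flashlight LM-decoder settings (illustrative values; "lexicon.txt"
# and "lm.bin" are placeholder paths):
#   --w2l-decoder kenlm --lexicon lexicon.txt --kenlm-model lm.bin \
#   --lm-weight 2.0 --word-score -1.0 --sil-weight 0.0 --beam-size-token 100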


def check_args(args):
    # assert args.path is not None, "--path required for generation!"
    # assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task, models):
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
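

# With --num-shards N and --shard-id k, each of N parallel jobs decodes a
# disjoint 1/N slice of the dataset; prepare_result_files() below prefixes the
# output files with the shard id so the slices can be merged afterwards.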


def process_predictions(
    args, hypos, tgt_dict, target_tokens, res_files, speaker, id
):
    """Write hypothesis/reference pairs for one utterance and return its
    (edit_distance, reference_length) counts."""
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                "{} ({}-{})".format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print(
                "{} ({}-{})".format(tgt_words, speaker, id),
                file=res_files["ref.words"],
            )

        if not args.quiet:
            logger.info("HYPO: " + hyp_words)
            logger.info("TARGET: " + tgt_words)
            logger.info("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()

    # scored on the last processed hypothesis (the only one when --nbest is 1)
    return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
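

# The caller accumulates the per-utterance pairs returned above; the corpus
# score reported at the end of main() is
#     WER = 100 * total_edit_distance / total_reference_words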


def prepare_result_files(args):
    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    if not args.results_path:
        return None

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }
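

# With --results-path set, decoding produces four line-buffered files named
# {prefix}-{checkpoint-name}-{subset}.txt, e.g. hypo.word-checkpoint_best.pt-test.txt.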


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


class ExistingEmissionsDecoder(object):
    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except Exception:
            # np.stack fails when per-utterance emissions have mismatched shapes
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes")
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)
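

# Emissions dumped with --dump-emissions can later be re-decoded without
# rerunning the acoustic model (paths are placeholders):
#   python infer.py $DATA ... --dump-emissions emissions.npy
#   python infer.py $DATA ... --load-emissions emissions.npy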


def main(args, task=None, model_state=None):
    check_args(args)

    use_fp16 = args.fp16
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer


def make_parser():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    return parser


def cli_main():
    parser = make_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
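

# Example invocation (illustrative; all paths are placeholders and the exact
# flag set depends on the fairseq version and task):
#   python examples/speech_recognition/infer.py /path/to/data \
#       --task audio_pretraining --gen-subset test --path /path/to/checkpoint.pt \
#       --w2l-decoder viterbi --max-tokens 4000000 --post-process letter \
#       --results-path /path/to/results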