#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""
import ast
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, auto

import hydra
from hydra.core.config_store import ConfigStore
import logging
import math
import os
from omegaconf import OmegaConf
from typing import Optional
import sys

import editdistance
import torch

from hydra.core.hydra_config import HydraConfig

from fairseq import checkpoint_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import FairseqDataclass, FairseqConfig
from fairseq.logging.meters import StopwatchMeter
from omegaconf import open_dict

from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoderConfig

logging.root.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


class DecoderType(Enum):
    VITERBI = auto()
    KENLM = auto()
    FAIRSEQ = auto()
    KALDI = auto()


@dataclass
class UnsupGenerateConfig(FairseqDataclass):
    fairseq: FairseqConfig = FairseqConfig()
    lm_weight: float = field(
        default=2.0,
        metadata={"help": "language model weight"},
    )
    w2l_decoder: DecoderType = field(
        default=DecoderType.VITERBI,
        metadata={"help": "type of decoder to use"},
    )
    kaldi_decoder_config: Optional[KaldiDecoderConfig] = None
    lexicon: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to lexicon. This is also used to 'phonemize' for unsupervised param tuning"
        },
    )
    lm_model: Optional[str] = field(
        default=None,
        metadata={"help": "path to language model (kenlm or fairseq)"},
    )
    unit_lm: bool = field(
        default=False,
        metadata={"help": "whether to use unit lm"},
    )
    beam_threshold: float = field(
        default=50.0,
        metadata={"help": "beam score threshold"},
    )
    beam_size_token: float = field(
        default=100.0,
        metadata={"help": "max tokens per beam"},
    )
    beam: int = field(
        default=5,
        metadata={"help": "decoder beam size"},
    )
    nbest: int = field(
        default=1,
        metadata={"help": "number of results to return"},
    )
    word_score: float = field(
        default=1.0,
        metadata={"help": "word score to add at end of word"},
    )
    unk_weight: float = field(
        default=-math.inf,
        metadata={"help": "unknown token weight"},
    )
    sil_weight: float = field(
        default=0.0,
        metadata={"help": "silence token weight"},
    )
    targets: Optional[str] = field(
        default=None,
        metadata={"help": "extension of ground truth labels to compute UER"},
    )
    results_path: Optional[str] = field(
        default=None,
        metadata={"help": "where to store results"},
    )
    post_process: Optional[str] = field(
        default=None,
        metadata={"help": "how to post process results"},
    )
    vocab_usage_power: float = field(
        default=2,
        metadata={"help": "for unsupervised param tuning"},
    )
    viterbi_transcript: Optional[str] = field(
        default=None,
        metadata={"help": "for unsupervised param tuning"},
    )
    min_lm_ppl: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )
    min_vt_uer: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )
    blank_weight: float = field(
        default=0,
        metadata={"help": "value to add or set for blank emission"},
    )
    blank_mode: str = field(
        default="set",
        metadata={
            "help": "can be add or set, how to modify blank emission with blank weight"
        },
    )
    sil_is_blank: bool = field(
        default=False,
        metadata={"help": "if true, <SIL> token is same as blank token"},
    )
    unsupervised_tuning: bool = field(
        default=False,
        metadata={
            "help": "if true, returns a score based on unsupervised param selection metric instead of UER"
        },
    )
    is_ax: bool = field(
        default=False,
        metadata={
            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
        },
    )
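

# Example invocation (illustrative only; the config names and paths below are
# assumptions, and any field above or nested fairseq.* option can be overridden
# on the Hydra command line):
#
#   python w2vu_generate.py --config-dir config/generate --config-name viterbi \
#       fairseq.task.data=/path/to/features \
#       fairseq.common_eval.path=/path/to/checkpoint.pt \
#       fairseq.dataset.gen_subset=valid results_path=/path/to/results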


def get_dataset_itr(cfg, task):
    return task.get_batch_iterator(
        dataset=task.dataset(cfg.fairseq.dataset.gen_subset),
        max_tokens=cfg.fairseq.dataset.max_tokens,
        max_sentences=cfg.fairseq.dataset.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=cfg.fairseq.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.fairseq.dataset.required_batch_size_multiple,
        num_shards=cfg.fairseq.dataset.num_shards,
        shard_id=cfg.fairseq.dataset.shard_id,
        num_workers=cfg.fairseq.dataset.num_workers,
        data_buffer_size=cfg.fairseq.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)


def process_predictions(
    cfg: UnsupGenerateConfig,
    hypos,
    tgt_dict,
    target_tokens,
    res_files,
):
    retval = []
    word_preds = []
    transcriptions = []
    dec_scores = []

    for i, hypo in enumerate(hypos[: min(len(hypos), cfg.nbest)]):
        if torch.is_tensor(hypo["tokens"]):
            tokens = hypo["tokens"].int().cpu()
            tokens = tokens[tokens >= tgt_dict.nspecial]
            hyp_pieces = tgt_dict.string(tokens)
        else:
            hyp_pieces = " ".join(hypo["tokens"])

        if "words" in hypo and len(hypo["words"]) > 0:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, cfg.post_process)

        to_write = {}
        if res_files is not None:
            to_write[res_files["hypo.units"]] = hyp_pieces
            to_write[res_files["hypo.words"]] = hyp_words

        tgt_words = ""
        if target_tokens is not None:
            if isinstance(target_tokens, str):
                tgt_pieces = tgt_words = target_tokens
            else:
                tgt_pieces = tgt_dict.string(target_tokens)
                tgt_words = post_process(tgt_pieces, cfg.post_process)

            if res_files is not None:
                to_write[res_files["ref.units"]] = tgt_pieces
                to_write[res_files["ref.words"]] = tgt_words

        if not cfg.fairseq.common_eval.quiet:
            logger.info(f"HYPO {i}:" + hyp_words)
            if tgt_words:
                logger.info("TARGET:" + tgt_words)

            if "am_score" in hypo and "lm_score" in hypo:
                logger.info(
                    f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}"
                )
            elif "score" in hypo:
                logger.info(f"DECODER SCORE: {hypo['score']}")

            logger.info("___________________")

        hyp_words_arr = hyp_words.split()
        tgt_words_arr = tgt_words.split()

        retval.append(
            (
                editdistance.eval(hyp_words_arr, tgt_words_arr),
                len(hyp_words_arr),
                len(tgt_words_arr),
                hyp_pieces,
                hyp_words,
            )
        )
        word_preds.append(hyp_words_arr)
        transcriptions.append(to_write)
        dec_scores.append(-hypo.get("score", 0))  # negate because kaldi returns NLL

    if len(retval) > 1:
        best = None
        for r, t in zip(retval, transcriptions):
            if best is None or r[0] < best[0][0]:
                best = r, t

        for dest, tran in best[1].items():
            print(tran, file=dest)
            dest.flush()
        return best[0]

    assert len(transcriptions) == 1
    for dest, tran in transcriptions[0].items():
        print(tran, file=dest)

    return retval[0]


def prepare_result_files(cfg: UnsupGenerateConfig):
    def get_res_file(file_prefix):
        if cfg.fairseq.dataset.num_shards > 1:
            file_prefix = f"{cfg.fairseq.dataset.shard_id}_{file_prefix}"

        path = os.path.join(
            cfg.results_path,
            "{}{}.txt".format(
                cfg.fairseq.dataset.gen_subset,
                file_prefix,
            ),
        )
        return open(path, "w", buffering=1)

    if not cfg.results_path:
        return None

    return {
        "hypo.words": get_res_file(""),
        "hypo.units": get_res_file("_units"),
        "ref.words": get_res_file("_ref"),
        "ref.units": get_res_file("_ref_units"),
        "hypo.nbest.words": get_res_file("_nbest_words"),
    }


def optimize_models(cfg: UnsupGenerateConfig, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.eval()
        if cfg.fairseq.common.fp16:
            model.half()
        if use_cuda:
            model.cuda()
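

# GenResult bundles the counters accumulated by generate(): utterance count,
# edit-distance errors and reference lengths (for WER/UER), hypothesis lengths in
# units and in words, the summed LM score, feature/sentence/symbol counts, and
# error/length totals against a previously saved viterbi transcript.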
GenResult = namedtuple(
    "GenResult",
    [
        "count",
        "errs_t",
        "gen_timer",
        "lengths_hyp_unit_t",
        "lengths_hyp_t",
        "lengths_t",
        "lm_score_t",
        "num_feats",
        "num_sentences",
        "num_symbols",
        "vt_err_t",
        "vt_length_t",
    ],
)


def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    task = tasks.setup_task(cfg.fairseq.task)
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary
    logger.info(
        "| {} {} {} examples".format(
            cfg.fairseq.task.data,
            cfg.fairseq.dataset.gen_subset,
            len(task.dataset(cfg.fairseq.dataset.gen_subset)),
        )
    )
    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)
    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None
            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder)
            )

    generator = build_generator(cfg)

    kenlm = None
    fairseq_lm = None
    if cfg.lm_model is not None:
        import kenlm

        kenlm = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    if cfg.results_path is not None and not os.path.exists(cfg.results_path):
        os.makedirs(cfg.results_path)

    res_files = prepare_result_files(cfg)
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0
    all_hyp_pieces = []
    all_hyp_words = []

    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")])
        - tgt_dict.nspecial
    )
    targets = None
    if cfg.targets is not None:
        tgt_path = os.path.join(
            cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." + cfg.targets
        )
        if os.path.exists(tgt_path):
            with open(tgt_path, "r") as f:
                targets = f.read().splitlines()

    viterbi_transcript = None
    if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
        logger.info(f"loading viterbi transcript from {cfg.viterbi_transcript}")
        with open(cfg.viterbi_transcript, "r") as vf:
            viterbi_transcript = vf.readlines()
            viterbi_transcript = [v.rstrip().split() for v in viterbi_transcript]

    gen_timer.start()

    start = 0
    end = len(itr)
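
    # The Kaldi decoder runs asynchronously and returns futures: a first pass over
    # the data collects one list of futures per batch, and the scoring loop below
    # resolves each of them with .result().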
    hypo_futures = None
    if cfg.w2l_decoder == DecoderType.KALDI:
        logger.info("Extracting features")
        hypo_futures = []
        samples = []
        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if "net_input" not in sample or i < start or i >= end:
                    continue
                if "padding_mask" not in sample["net_input"]:
                    sample["net_input"]["padding_mask"] = None

                hypos, num_feats = gen_hypos(
                    generator, models, num_feats, sample, task, use_cuda
                )
                hypo_futures.append(hypos)
                samples.append(sample)

        itr = list(zip(hypo_futures, samples))
        start = 0
        end = len(itr)
        logger.info("Finished extracting features")

    with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
        for i, sample in enumerate(t):
            if i < start or i >= end:
                continue

            if hypo_futures is not None:
                hypos, sample = sample
                hypos = [h.result() for h in hypos]
            else:
                if "net_input" not in sample:
                    continue

                hypos, num_feats = gen_hypos(
                    generator, models, num_feats, sample, task, use_cuda
                )

            for i, sample_id in enumerate(sample["id"].tolist()):
                if targets is not None:
                    target_tokens = targets[sample_id]
                elif "target" in sample or "target_label" in sample:
                    toks = (
                        sample["target"][i, :]
                        if "target_label" not in sample
                        else sample["target_label"][i, :]
                    )
                    target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                else:
                    target_tokens = None

                # Process top predictions
                (
                    errs,
                    length_hyp,
                    length,
                    hyp_pieces,
                    hyp_words,
                ) = process_predictions(
                    cfg,
                    hypos[i],
                    tgt_dict,
                    target_tokens,
                    res_files,
                )
                errs_t += errs
                lengths_hyp_t += length_hyp
                lengths_hyp_unit_t += (
                    len(hyp_pieces) if len(hyp_pieces) > 0 else len(hyp_words)
                )
                lengths_t += length
                count += 1
                all_hyp_pieces.append(hyp_pieces)
                all_hyp_words.append(hyp_words)

            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    lm_score_sum = 0
    if kenlm is not None:
        if cfg.unit_lm:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_pieces)
        else:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_words)
    elif fairseq_lm is not None:
        lm_score_sum = sum(fairseq_lm.score([h.split() for h in all_hyp_words])[0])

    vt_err_t = 0
    vt_length_t = 0
    if viterbi_transcript is not None:
        unit_hyps = []
        if cfg.targets is not None and cfg.lexicon is not None:
            lex = {}
            with open(cfg.lexicon, "r") as lf:
                for line in lf:
                    items = line.rstrip().split()
                    lex[items[0]] = items[1:]

            for h in all_hyp_pieces:
                hyp_ws = []
                for w in h.split():
                    assert w in lex, w
                    hyp_ws.extend(lex[w])
                unit_hyps.append(hyp_ws)
        else:
            unit_hyps.extend([h.split() for h in all_hyp_words])

        vt_err_t = sum(
            editdistance.eval(vt, h) for vt, h in zip(viterbi_transcript, unit_hyps)
        )
        vt_length_t = sum(len(h) for h in viterbi_transcript)

    if res_files is not None:
        for r in res_files.values():
            r.close()

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )
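

# Run one batch through the decoder. When the batch carries precomputed dense
# features (the wav2vec-U setup), "dense_x_only" signals the model that only the
# generator output is needed, and num_feats accumulates batch-size x time-steps
# for logging.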
def gen_hypos(generator, models, num_feats, sample, task, use_cuda):
    sample = utils.move_to_cuda(sample) if use_cuda else sample

    if "features" in sample["net_input"]:
        sample["net_input"]["dense_x_only"] = True
        num_feats += (
            sample["net_input"]["features"].shape[0]
            * sample["net_input"]["features"].shape[1]
        )
    hypos = task.inference_step(generator, models, sample, None)
    return hypos, num_feats


def main(cfg: UnsupGenerateConfig, model=None):
    if (
        cfg.fairseq.dataset.max_tokens is None
        and cfg.fairseq.dataset.batch_size is None
    ):
        cfg.fairseq.dataset.max_tokens = 1024000

    use_cuda = torch.cuda.is_available() and not cfg.fairseq.common.cpu

    task = tasks.setup_task(cfg.fairseq.task)

    overrides = ast.literal_eval(cfg.fairseq.common_eval.model_overrides)

    if cfg.fairseq.task._name == "unpaired_audio_text":
        overrides["model"] = {
            "blank_weight": cfg.blank_weight,
            "blank_mode": cfg.blank_mode,
            "blank_is_sil": cfg.sil_is_blank,
            "no_softmax": True,
            "segmentation": {
                "type": "NONE",
            },
        }
    else:
        overrides["model"] = {
            "blank_weight": cfg.blank_weight,
            "blank_mode": cfg.blank_mode,
        }

    if model is None:
        # Load ensemble
        logger.info("| loading model(s) from {}".format(cfg.fairseq.common_eval.path))
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            cfg.fairseq.common_eval.path.split("\\"),
            arg_overrides=overrides,
            task=task,
            suffix=cfg.fairseq.checkpoint.checkpoint_suffix,
            strict=(cfg.fairseq.checkpoint.checkpoint_shard_count == 1),
            num_shards=cfg.fairseq.checkpoint.checkpoint_shard_count,
        )
        optimize_models(cfg, use_cuda, models)
    else:
        models = [model]
        saved_cfg = cfg.fairseq

    with open_dict(saved_cfg.task):
        saved_cfg.task.shuffle = False
        saved_cfg.task.sort_by_length = False

    gen_result = generate(cfg, models, saved_cfg, use_cuda)

    wer = None
    if gen_result.lengths_t > 0:
        wer = gen_result.errs_t * 100.0 / gen_result.lengths_t
        logger.info(f"WER: {wer}")
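
    # KenLM's score() returns a total log10 probability (including the
    # end-of-sentence token), so perplexity is 10^(-score / token_count);
    # num_sentences is added to the hypothesis length to account for the
    # </s> token scored once per sentence.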
    lm_ppl = float("inf")
    if gen_result.lm_score_t != 0 and gen_result.lengths_hyp_t > 0:
        hyp_len = gen_result.lengths_hyp_t
        lm_ppl = math.pow(
            10, -gen_result.lm_score_t / (hyp_len + gen_result.num_sentences)
        )
        logger.info(f"LM PPL: {lm_ppl}")

    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
        " sentences/s, {:.2f} tokens/s)".format(
            gen_result.num_sentences,
            gen_result.gen_timer.n,
            gen_result.gen_timer.sum,
            gen_result.num_sentences / gen_result.gen_timer.sum,
            1.0 / gen_result.gen_timer.avg,
        )
    )

    vt_diff = None
    if gen_result.vt_length_t > 0:
        vt_diff = gen_result.vt_err_t / gen_result.vt_length_t
        vt_diff = max(cfg.min_vt_uer, vt_diff)

    lm_ppl = max(cfg.min_lm_ppl, lm_ppl)

    if not cfg.unsupervised_tuning:
        weighted_score = wer
    else:
        weighted_score = math.log(lm_ppl) * (vt_diff or 1.0)

    res = (
        f"| Generate {cfg.fairseq.dataset.gen_subset} with beam={cfg.beam}, "
        f"lm_weight={cfg.kaldi_decoder_config.acoustic_scale if cfg.kaldi_decoder_config else cfg.lm_weight}, "
        f"word_score={cfg.word_score}, sil_weight={cfg.sil_weight}, blank_weight={cfg.blank_weight}, "
        f"WER: {wer}, LM_PPL: {lm_ppl}, num feats: {gen_result.num_feats}, "
        f"length: {gen_result.lengths_hyp_t}, UER to viterbi: {(vt_diff or 0) * 100}, score: {weighted_score}"
    )

    logger.info(res)
    # print(res)

    return task, weighted_score


# Hydra entry point; config_path is assumed to resolve to fairseq/config relative
# to this example directory.
@hydra.main(
    config_path=os.path.join("../../..", "fairseq", "config"), config_name="config"
)
def hydra_main(cfg):
    with open_dict(cfg):
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        cfg.job_logging_cfg = OmegaConf.to_container(
            HydraConfig.get().job_logging, resolve=True
        )

    cfg = OmegaConf.create(
        OmegaConf.to_container(cfg, resolve=False, enum_to_str=False)
    )
    OmegaConf.set_struct(cfg, True)
    logger.info(cfg)

    utils.import_user_module(cfg.fairseq.common)

    _, score = main(cfg)

    if cfg.is_ax:
        return score, None

    return score


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=UnsupGenerateConfig)
    hydra_main()


if __name__ == "__main__":
    cli_main()