#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import os
import json
import torch
import argparse
import numpy as np

from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import add_results_to_json
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E


class AVSR(torch.nn.Module):
    """Inference wrapper around an ``E2E`` model and a batch beam search.

    The constructor loads the training configuration and the model weights,
    builds the token list, and prepares a ``BatchBeamSearch`` decoder.
    ``infer`` encodes the input and returns the best hypothesis as text.
    """

    def __init__(self, modality, model_path, model_conf, rnnlm=None, rnnlm_conf=None,
                 penalty=0., ctc_weight=0.1, lm_weight=0., beam_size=40, device="cuda:0"):
        """Load the model, its configuration and the beam-search decoder.

        Args:
            modality: ``"audiovisual"`` selects the ``E2E`` class from
                ``e2e_asr_transformer_av``; any other value selects the one
                from ``e2e_asr_transformer``.
            model_path: Path to the state dict loaded with ``torch.load``.
            model_conf: Path to a JSON file that is either a dict of training
                arguments or a sequence whose third element is that dict.
            rnnlm: Optional language-model path forwarded to
                ``get_beam_search_decoder``.
            rnnlm_conf: Optional language-model configuration path.
            penalty: Weight of the length-bonus scorer.
            ctc_weight: Weight of the CTC scorer (decoder gets ``1 - ctc_weight``).
            lm_weight: Weight of the language-model scorer.
            beam_size: Beam width used by the decoder.
            device: Device the model and the decoder are moved to.

        Raises:
            ValueError: If the configured ``labels_type`` is neither
                ``"char"`` nor ``"unigram5000"``.
        """
        super(AVSR, self).__init__()
        self.device = device

        # The local import intentionally shadows the module-level E2E so the
        # audiovisual variant can be swapped in.
        if modality == "audiovisual":
            from espnet.nets.pytorch_backend.e2e_asr_transformer_av import E2E
        else:
            from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E

        with open(model_conf, "rb") as f:
            confs = json.load(f)
        args = confs if isinstance(confs, dict) else confs[2]
        self.train_args = argparse.Namespace(**args)

        labels_type = getattr(self.train_args, "labels_type", "char")
        if labels_type == "char":
            self.token_list = self.train_args.char_list
        elif labels_type == "unigram5000":
            file_path = os.path.join(
                os.path.dirname(__file__), "tokens", "unigram5000_units.txt"
            )
            # Use a context manager so the handle is closed deterministically.
            # NOTE(review): the file is read with the platform default
            # encoding, as before — confirm whether it should be UTF-8.
            with open(file_path) as f:
                units = [word.split()[0] for word in f.read().splitlines()]
            # NOTE(review): the two sentinel entries are empty strings as
            # written; they may have lost their original text — confirm.
            self.token_list = [''] + units + ['']
        else:
            # Fail here instead of with an AttributeError on token_list below.
            raise ValueError(
                f"Unsupported labels_type: {labels_type!r} "
                "(expected 'char' or 'unigram5000')"
            )
        self.odim = len(self.token_list)

        self.model = E2E(self.odim, self.train_args)
        # map_location keeps every tensor on CPU storage while loading; the
        # model is moved to the target device afterwards.
        self.model.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage)
        )
        self.model.to(device=self.device).eval()

        self.beam_search = get_beam_search_decoder(
            self.model, self.token_list, rnnlm, rnnlm_conf,
            penalty, ctc_weight, lm_weight, beam_size,
        )
        self.beam_search.to(device=self.device).eval()

    def infer(self, data):
        """Decode ``data`` and return the top hypothesis as a string.

        Args:
            data: Either a single tensor, or a tuple whose first two elements
                are tensors that are both passed to ``self.model.encode``.

        Returns:
            The transcription with ``"▁"`` replaced by spaces and surrounding
            whitespace stripped.
        """
        with torch.no_grad():
            if isinstance(data, tuple):
                enc_feats = self.model.encode(
                    data[0].to(self.device), data[1].to(self.device)
                )
            else:
                enc_feats = self.model.encode(data.to(self.device))
            nbest_hyps = self.beam_search(enc_feats)
            # Keep only the best hypothesis (or none if the search is empty).
            nbest_hyps = [h.asdict() for h in nbest_hyps[:1]]
            transcription = add_results_to_json(nbest_hyps, self.token_list)
            transcription = transcription.replace("▁", " ").strip()
        # NOTE(review): replacing "" with "" is a no-op as written; the
        # search string may have lost its original text — confirm.
        return transcription.replace("", "")


def get_beam_search_decoder(model, token_list, rnnlm=None, rnnlm_conf=None,
                            penalty=0, ctc_weight=0.1, lm_weight=0., beam_size=40):
    """Build a ``BatchBeamSearch`` decoder for ``model``.

    Args:
        model: Model exposing ``odim`` and ``scorers()``.
        token_list: Output tokens; its length is used as the vocabulary size.
        rnnlm: Optional language-model path; when falsy no LM is loaded.
        rnnlm_conf: Optional language-model configuration path.
        penalty: Weight of the ``length_bonus`` scorer.
        ctc_weight: Weight of the ``ctc`` scorer; ``decoder`` gets
            ``1.0 - ctc_weight``.
        lm_weight: Weight of the ``lm`` scorer.
        beam_size: Beam width.

    Returns:
        The configured ``BatchBeamSearch`` instance.
    """
    # Both start- and end-of-sequence use the last output index.
    sos = model.odim - 1
    eos = model.odim - 1
    scorers = model.scorers()

    lm = None
    if rnnlm:
        lm_args = get_model_conf(rnnlm, rnnlm_conf)
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(token_list), lm_args)
        torch_load(rnnlm, lm)
        lm.eval()

    # The "lm" entry is registered even when no LM was loaded (value None).
    scorers["lm"] = lm
    scorers["length_bonus"] = LengthBonus(len(token_list))
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )

    return BatchBeamSearch(
        beam_size=beam_size,
        vocab_size=len(token_list),
        weights=weights,
        scorers=scorers,
        sos=sos,
        eos=eos,
        token_list=token_list,
        # With pure CTC scoring there is no decoder score to pre-prune on.
        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
    )