#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import os
import json
import torch
import argparse
import numpy as np

from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import add_results_to_json
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E


class AVSR(torch.nn.Module):
    def __init__(self, modality, model_path, model_conf, rnnlm=None, rnnlm_conf=None,
                 penalty=0., ctc_weight=0.1, lm_weight=0., beam_size=40, device="cuda:0"):
        super(AVSR, self).__init__()
        self.device = device

        # Pick the E2E transformer definition that matches the input modality.
        if modality == "audiovisual":
            from espnet.nets.pytorch_backend.e2e_asr_transformer_av import E2E
        else:
            from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E

        # The training configuration is either a plain dict or the third
        # element of an ESPnet (idim, odim, args) triple.
        with open(model_conf, "rb") as f:
            confs = json.load(f)
        args = confs if isinstance(confs, dict) else confs[2]
        self.train_args = argparse.Namespace(**args)

        # Build the output token list: characters come from the training args,
        # unigram subwords from a bundled units file.
        labels_type = getattr(self.train_args, "labels_type", "char")
        if labels_type == "char":
            self.token_list = self.train_args.char_list
        elif labels_type == "unigram5000":
            file_path = os.path.join(os.path.dirname(__file__), "tokens", "unigram5000_units.txt")
            with open(file_path) as f:
                self.token_list = ["<blank>"] + [word.split()[0] for word in f.read().splitlines()] + ["<eos>"]
        else:
            raise ValueError(f"unsupported labels_type: {labels_type}")
        self.odim = len(self.token_list)

        self.model = E2E(self.odim, self.train_args)
        self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
        self.model.to(device=self.device).eval()

        self.beam_search = get_beam_search_decoder(
            self.model, self.token_list, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size
        )
        self.beam_search.to(device=self.device).eval()

    def infer(self, data):
        with torch.no_grad():
            if isinstance(data, tuple):
                # Audiovisual input arrives as a (video, audio) pair.
                enc_feats = self.model.encode(data[0].to(self.device), data[1].to(self.device))
            else:
                enc_feats = self.model.encode(data.to(self.device))
            nbest_hyps = self.beam_search(enc_feats)
            # Keep only the single best hypothesis.
            nbest_hyps = [h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), 1)]]
            transcription = add_results_to_json(nbest_hyps, self.token_list)
            # Turn SentencePiece word boundaries back into spaces and drop <eos>.
            transcription = transcription.replace("▁", " ").strip()
            return transcription.replace("<eos>", "")


def get_beam_search_decoder(model, token_list, rnnlm=None, rnnlm_conf=None,
                            penalty=0, ctc_weight=0.1, lm_weight=0., beam_size=40):
    # ESPnet transformer models reserve the last vocabulary index for <sos>/<eos>.
    sos = model.odim - 1
    eos = model.odim - 1
    scorers = model.scorers()

    if not rnnlm:
        lm = None
    else:
        # Load an external language model for shallow fusion during decoding.
        lm_args = get_model_conf(rnnlm, rnnlm_conf)
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(token_list), lm_args)
        torch_load(rnnlm, lm)
        lm.eval()

    # A None scorer is skipped by the beam search, so "lm" is safe to register here.
    scorers["lm"] = lm
    scorers["length_bonus"] = LengthBonus(len(token_list))
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )
    return BatchBeamSearch(
        beam_size=beam_size,
        vocab_size=len(token_list),
        weights=weights,
        scorers=scorers,
        sos=sos,
        eos=eos,
        token_list=token_list,
        # Pre-beam pruning on decoder scores is meaningless for pure-CTC decoding.
        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
    )
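

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The checkpoint and
# config file names and the dummy input shape are illustrative assumptions;
# real inputs come from the repository's audio/video preprocessing pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    avsr = AVSR(
        modality="video",         # "audio", "video", or "audiovisual"
        model_path="model.pth",   # hypothetical checkpoint path
        model_conf="model.json",  # hypothetical training-config JSON
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )
    # For "audiovisual", pass a (video, audio) tuple instead of a single tensor.
    dummy_video = torch.randn(120, 88, 88)  # assumed (frames, height, width) clip
    print(avsr.infer(dummy_video))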