import argparse
from pathlib import Path

import torch
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

_model = None  # type: PPGModel
_device = None


class PPGModel(torch.nn.Module):
    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """
        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )
        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend, e.g. STFT and feature extraction.
            # The data_loader may send a time-domain signal in this case:
            # speech (Batch, NSamples) -> feats (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        """Extract bottleneck features from a mono waveform given as a 1-D numpy array."""
        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    # input_size=80 matches the 80-dim mel features produced by DefaultFrontend
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)
    return model


def load_model(model_file, device=None):
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device

    # Search for a config file next to the checkpoint
    model_file = Path(model_file)
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    config_file = model_config_fpaths[0]
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()
    ckpt_state_dict = torch.load(model_file, map_location=_device)
    # Only take encoder weights from the checkpoint; frontend and normalizer keep their defaults
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if "encoder" in k}
    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)
    _model = model.eval().to(_device)
    return _model
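

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of loading a checkpoint and extracting PPG features from a
# waveform. The checkpoint path, the use of librosa for audio loading, and the
# 16 kHz sample rate are assumptions for illustration; the YAML config found
# next to the checkpoint defines the actual frontend settings.
if __name__ == "__main__":
    import librosa  # assumed to be available for loading audio

    model = load_model("/path/to/ppg_model.pth")  # hypothetical checkpoint path

    # Load a mono waveform; 16 kHz is assumed to match the frontend configuration.
    wav, _ = librosa.load("example.wav", sr=16000, mono=True)
    with torch.no_grad():
        ppg = model.extract_from_wav(wav)  # -> (1, n_frames, 144) bottleneck features
    print(ppg.shape)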