Spaces:
Build error
Build error
import numpy as np | |
import torch | |
import torchaudio | |
from coqpit import Coqpit | |
from torch import nn | |
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss | |
from TTS.utils.generic_utils import set_init_dict | |
from TTS.utils.io import load_fsspec | |
class PreEmphasis(nn.Module):
    """First-order pre-emphasis filter applied along the last axis.

    Computes ``y[t] = x[t] - coefficient * x[t - 1]``. The missing ``x[-1]``
    for the first sample is taken from a one-sample reflect pad, so
    ``y[0] = x[0] - coefficient * x[1]``.

    Args:
        coefficient: weight of the previous sample that is subtracted.
    """

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        taps = torch.FloatTensor([-self.coefficient, 1.0])
        # conv1d weights are (out_channels, in_channels, kernel_size) = (1, 1, 2);
        # kept as a buffer so it follows .to()/.cuda() and lands in the state_dict.
        self.register_buffer("filter", taps.view(1, 1, 2))

    def forward(self, x):
        """Filter a 2-D input; dim 1 is treated as time (presumably (batch, samples))."""
        assert x.dim() == 2
        # Add a channel axis for conv1d, then reflect-pad one sample on the left
        # so the output keeps the input length.
        padded = torch.nn.functional.pad(x[:, None, :], (1, 0), "reflect")
        filtered = torch.nn.functional.conv1d(padded, self.filter)
        return filtered[:, 0, :]
class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions: building a torch mel
    front-end, sliding-window embedding extraction, loss construction and
    checkpoint restoration.

    Subclasses must implement ``forward(x, l2_norm)``. Before calling
    ``compute_embedding`` they must set ``self.use_torch_spec``, plus
    ``self.audio_config`` when that flag is true; this class reads both
    attributes but never assigns them.
    """

    # pylint: disable=W0102
    def __init__(self):
        super().__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        """Build a waveform -> mel-spectrogram front-end.

        Args:
            audio_config: mapping providing ``preemphasis``, ``sample_rate``,
                ``fft_size``, ``win_length``, ``hop_length`` and ``num_mels``.

        Returns:
            ``torch.nn.Sequential`` of :class:`PreEmphasis` followed by
            ``torchaudio.transforms.MelSpectrogram`` with a Hamming window.
        """
        # NOTE: a TorchSTFT-based alternative (hamming window, use_htk=True,
        # power=2.0, no amp-to-dB) used to sit here commented out; torchaudio's
        # MelSpectrogram is the implementation actually in use.
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    def inference(self, x, l2_norm=True):
        """Alias for ``self.forward(x, l2_norm)``."""
        return self.forward(x, l2_norm)

    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """Generate embeddings from evenly spaced windows of one utterance.

        ``num_eval`` windows of ``num_frames`` steps are cut from ``x``. Their
        start offsets are spread evenly (``np.linspace``) from 0 to the last
        position where a full window fits. The windows are concatenated along
        dim 0 and run through :meth:`inference` as one batch.

        Args:
            x: input whose dim 1 is the time axis. Presumably ``1 x T x D``
                features, or ``1 x T`` raw samples when ``self.use_torch_spec``
                is set (TODO confirm against callers).
            num_frames: window length in frames. When ``self.use_torch_spec``
                is true it is multiplied by ``self.audio_config["hop_length"]``
                to get a length in samples. It is clamped to ``x.shape[1]``.
            num_eval: number of windows; must be at least 1.
            return_mean: average the per-window embeddings over dim 0
                (keepdim=True), giving a single row.
            l2_norm: forwarded to :meth:`inference`.

        Returns:
            Tensor with one embedding per window, or their mean if
            ``return_mean``.

        Raises:
            ValueError: if ``num_eval`` < 1 (nothing to embed).
        """
        if num_eval < 1:
            raise ValueError(f"num_eval must be >= 1, got {num_eval}")

        # map to the waveform size
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]
        # Short utterances: use the whole input as every window.
        num_frames = min(num_frames, max_len)

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
        windows = []
        for offset in offsets:
            # linspace yields floats; truncate to valid slice indices.
            start = int(offset)
            windows.append(x[:, start : int(start + num_frames)])

        frames_batch = torch.cat(windows, dim=0)
        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: "Coqpit", num_classes=None):
        """Instantiate the loss selected by ``c.loss``.

        Args:
            c: config exposing ``loss``. For ``"softmaxproto"`` it must also
                expose ``model_params["proj_dim"]``.
            num_classes: number of classes; only used by ``"softmaxproto"``.

        Returns:
            The loss module.

        Raises:
            ValueError: (an ``Exception`` subclass) if ``c.loss`` is not one of
                ``"ge2e"``, ``"angleproto"`` or ``"softmaxproto"``.
        """
        if c.loss == "ge2e":
            return GE2ELoss(loss_method="softmax")
        if c.loss == "angleproto":
            return AngleProtoLoss()
        if c.loss == "softmaxproto":
            return SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        raise ValueError("The %s not is a loss supported" % c.loss)

    def load_checkpoint(
        self,
        config: "Coqpit",
        checkpoint_path: str,
        eval: bool = False,
        use_cuda: bool = False,
        criterion=None,
        cache=False,
    ):
        """Restore model (and optionally criterion) weights from a checkpoint.

        Args:
            config: model config. It is passed to ``set_init_dict`` on a
                partial restore, and to :meth:`get_criterion` when the
                classifier criterion is rebuilt for inference.
            checkpoint_path: path handed to ``load_fsspec``.
            eval: inference mode. A checkpoint that does not load cleanly
                raises instead of being partially merged, and the module is
                switched to ``eval()``.
            use_cuda: move the model (and any criterion) to CUDA.
            criterion: optional loss module to restore from ``state["criterion"]``.
            cache: forwarded to ``load_fsspec``.

        Returns:
            ``(criterion, state["step"])`` when ``eval`` is false. Otherwise
            ``criterion``, which may be ``None``.
        """
        # Always deserialize onto CPU; moved to GPU below if requested.
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        try:
            self.load_state_dict(state["model"])
            print(" > Model fully restored. ")
        except (KeyError, RuntimeError):
            # If eval raise the error: inference must not run on a patched-up model.
            if eval:
                raise
            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            # Presumably merges the compatible checkpoint entries into the
            # current state dict (see set_init_dict).
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # load the criterion for restore_path (best effort: a mismatch is reported, not fatal)
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # instance and load the criterion for the encoder classifier in inference time
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training
            return criterion
        return criterion, state["step"]