d22cs051's picture
add app for speaker verification
fc8786d
raw
history blame
No virus
2.57 kB
import torch
import fairseq
from packaging import version
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import OmegaConf
from s3prl.upstream.interfaces import UpstreamBase
from torch.nn.utils.rnn import pad_sequence
def load_model(filepath):
    """Build a fairseq model and task from the config stored in a checkpoint.

    The checkpoint is deserialized onto the CPU. The configuration is taken
    from ``state["args"]`` (converted with ``convert_namespace_to_omegaconf``)
    when it is present and not None, otherwise from ``state["cfg"]``
    (wrapped with ``OmegaConf.create``).

    Note that this function only *builds* the model via
    ``task.build_model``; it does not copy any weights from the checkpoint
    into the returned model.

    Args:
        filepath: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        A ``(model, cfg, task)`` tuple.

    Raises:
        RuntimeError: If the checkpoint has neither a usable ``args`` nor a
            usable ``cfg`` entry.
    """
    # NOTE: torch.load unpickles the file, so only trusted checkpoints should
    # be passed here. The identity map_location keeps every storage on CPU.
    state = torch.load(filepath, map_location=lambda storage, loc: storage)

    # Only touch "cfg" in the branch that needs it: reading it up front would
    # raise KeyError for checkpoints that carry "args" but no "cfg" key and
    # would make the RuntimeError below unreachable.
    if state.get("args") is not None:
        cfg = convert_namespace_to_omegaconf(state["args"])
    elif state.get("cfg") is not None:
        cfg = OmegaConf.create(state["cfg"])
    else:
        raise RuntimeError(
            f"Neither args nor cfg exist in state keys = {state.keys()}"
        )

    task = tasks.setup_task(cfg.task)
    if "task_state" in state:
        task.load_state_dict(state["task_state"])

    model = task.build_model(cfg.model)
    return model, cfg, task
###################
# UPSTREAM EXPERT #
###################
class UpstreamExpert(UpstreamBase):
    """Upstream wrapper that exposes a fairseq model built by ``load_model``.

    On construction the model and task are built from ``ckpt``. If the base
    class did not receive any hooks, one hook is registered per entry of
    ``self.model.encoder.layers`` plus one on ``self.model.encoder``.
    """

    def __init__(self, ckpt, **kwargs):
        """Build the model from ``ckpt`` and register the default hooks.

        Args:
            ckpt: Checkpoint path forwarded to ``load_model``.
            **kwargs: Forwarded unchanged to ``UpstreamBase.__init__``.

        Raises:
            AssertionError: If the installed fairseq is not newer than 0.10.2.
        """
        super().__init__(**kwargs)
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        # The config returned by load_model is not needed here.
        self.model, _, self.task = load_model(ckpt)

        if len(self.hooks) == 0:
            # add_hook takes the module path as a string, so the path is kept
            # as text; the layer count is read directly instead of via eval().
            module_name = "self.model.encoder.layers"
            for module_id in range(len(self.model.encoder.layers)):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    # Takes each layer's first input and swaps its first two
                    # axes (presumably time-major -> batch-major; confirm).
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

    def forward(self, wavs):
        """Pad a batch of waveforms and run the model's feature extractor.

        Args:
            wavs: Non-empty sequence of waveform tensors on one device
                (assumed to be 1-D per utterance; not enforced here).

        Returns:
            A dict ``{"default": features}`` with the features returned by
            ``self.model.extract_features``.
        """
        if self.task.cfg.normalize:
            # Normalize each waveform over its full shape.
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        # Keep the lengths as Python ints so the maximum is computed without
        # iterating over (and synchronizing on) a device tensor.
        lengths = [len(wav) for wav in wavs]
        wav_lengths = torch.tensor(lengths, dtype=torch.long, device=device)

        # True at positions at or beyond each waveform's length (the padding).
        positions = torch.arange(max(lengths), device=device).unsqueeze(0)
        wav_padding_mask = positions >= wav_lengths.unsqueeze(1)

        padded_wav = pad_sequence(wavs, batch_first=True)
        result = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )

        # NOTE(review): newer fairseq wav2vec2 models are believed to return a
        # dict with the features under "x" instead of a
        # (features, padding_mask) tuple -- confirm for the model in use.
        if isinstance(result, dict):
            features = result["x"]
        else:
            features = result[0]

        return {
            "default": features,
        }