Spaces: unilight
committed · Commit 052c3ff · 1 Parent(s): 781dc92
init

Files changed:
- app.py +184 -0
- models/modules.py +59 -0
- models/sslmos.py +212 -0

app.py ADDED (@@ -0,0 +1,184 @@)
```python
import os
from glob import glob

import torch.nn.functional as F
import torchaudio
from loguru import logger
import soundfile as sf
import librosa
import gradio as gr

from huggingface_hub import hf_hub_download
import time
import torch
import yaml

# from s3prl_vc.upstream.interface import get_upstream
# from s3prl.nn import Featurizer
# import s3prl_vc.models
# from s3prl_vc.utils import read_hdf5
# from s3prl_vc.vocoder import Vocoder


# ---------- Settings ----------
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cuda' if GPU_ID != '-1' else 'cpu'

SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'

FS = 16000
resamplers = {}
MIN_REQUIRED_WAV_LENGTH = 1040

# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))

# TRGSPKS = ["TEF1", "TEF2", "TEM1", "TEM2"]

# ref_samples = {
#     trgspk: sorted(glob(os.path.join("./ref_samples", trgspk, '*.wav')))
#     for trgspk in TRGSPKS
# }

# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Download models ----------
logger.info('============================= Download models ===========================')

model_paths = {
    "SSL-MOS, all training sets": {
        "ckpt": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl"),
        "config": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml"),
    }
}

# ---------- Model ----------
models = {}
configs = {}  # keep each model's config so predict() can look it up by name
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS
        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        ).to(DEVICE)
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")

    models[name] = model
    configs[name] = config


def read_wav(wav_path):
    # read waveform
    waveform, sample_rate = torchaudio.load(
        wav_path, channels_first=False
    )  # waveform: [T, 1]

    # resample if needed
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        # Resample operates along the last axis, so work on the [1, T] view
        waveform = resamplers[resampler_key](waveform.T).T
        sample_rate = FS

    waveform = waveform.squeeze(-1)

    # always pad to a minimum length (round up so odd deficits still reach it)
    if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH:
        to_pad = (MIN_REQUIRED_WAV_LENGTH - waveform.shape[0] + 1) // 2
        waveform = F.pad(waveform, (to_pad, to_pad), "constant", 0)

    return waveform, sample_rate


def predict(model_name, wav_file):
    config = configs[model_name]
    x, fs = read_wav(wav_file)
    logger.info('wav file loaded')

    # set up model input
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        config["model_input"]: model_input,
        config["model_input"] + "_lengths": model_lengths,
    }

    with torch.no_grad():
        # model forward
        if config["inference_mode"] == "mean_listener":
            outputs = models[model_name].mean_listener_inference(inputs)
        elif config["inference_mode"] == "mean_net":
            outputs = models[model_name].mean_net_inference(inputs)

    pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]

    return pred_mean_scores

with gr.Blocks(title="SSL-MOS: Automatic speech quality assessment demo") as demo:
    gr.Markdown(
        """
        # SSL-MOS: Automatic speech quality assessment demo
        **SSL-MOS** is a speech quality predictor that fine-tunes a self-supervised speech representation (S3R) model to estimate the mean opinion score (MOS) of an utterance. The implementation is modified from [mos-finetune-ssl](https://github.com/nii-yamagishilab/mos-finetune-ssl). The checkpoint served here was trained on a combination of the BVCC, NISQA, PSTN, SingMOS, SOMOS, Tencent, and TMHINT-QI datasets.
        In this demo, you can record your voice, and the model will predict a quality score on the usual 1-5 MOS scale.
        """
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", source='microphone', type='filepath')

            gr.Markdown("## Select a model!")
            model_name = gr.Radio(label="Model", choices=list(model_paths.keys()))

            evaluate_btn = gr.Button(value="Evaluate!")
            # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
            # gr.Markdown("I recorded the samples using my MacBook Pro, so there might be some noise.")
            # gr.Examples(
            #     examples=en_examples,
            #     inputs=input_wav,
            #     label="English examples"
            # )
            # gr.Examples(
            #     examples=jp_examples,
            #     inputs=input_wav,
            #     label="Japanese examples"
            # )
            # gr.Examples(
            #     examples=zh_examples,
            #     inputs=input_wav,
            #     label="Mandarin examples"
            # )

        with gr.Column():
            gr.Markdown("## The predicted score appears here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)

    evaluate_btn.click(predict, [model_name, input_wav], output_score)

if __name__ == '__main__':
    try:
        demo.launch(debug=True,
                    enable_queue=True,
                    )
    except KeyboardInterrupt as e:
        print(e)

    finally:
        demo.close()
```
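
For a quick sanity check without launching the Gradio UI, `predict` can be called directly. A minimal sketch, assuming a local recording saved as `sample.wav` (a hypothetical file name); note that importing `app` runs the model download and setup at import time:

```python
# smoke_test.py -- minimal sketch; "sample.wav" is a hypothetical local file.
from app import model_paths, predict

model_name = list(model_paths.keys())[0]  # "SSL-MOS, all training sets"
score = predict(model_name, "sample.wav")  # read_wav() resamples to 16 kHz as needed
print(f"{model_name}: predicted score = {float(score):.2f}")
```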

models/modules.py ADDED (@@ -0,0 +1,59 @@)
```python
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

# LDNet modules
# taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself)

import torch
from torch import nn

STRIDE = 3


class Projection(nn.Module):
    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
    ):
        super(Projection, self).__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        if output_type == "scalar":
            output_dim = 1
            if range_clipping:
                self.proj = nn.Tanh()
        elif output_type == "categorical":
            output_dim = _output_dim
            self.output_step = output_step
        else:
            raise NotImplementedError("wrong output_type: {}".format(output_type))

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        output = self.net(x)

        # scalar / categorical
        if self.output_type == "scalar":
            # range clipping
            if self.range_clipping:
                return self.proj(output) * 2.0 + 3
            else:
                return output
        else:
            if inference:
                return torch.argmax(output, dim=-1) * self.output_step + 1
            else:
                return output
```
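
As a rough illustration of the two output heads, here is a minimal shape check; the dimensions (768-d features, 50 frames) are assumptions for the example, not values from this commit:

```python
import torch
from torch import nn

from models.modules import Projection

feats = torch.randn(2, 50, 768)  # (batch, time, feat_dim); sizes are made up

# Scalar head with range clipping: tanh output mapped into the MOS range [1, 5].
scalar_head = Projection(768, 64, nn.ReLU, "scalar", None, range_clipping=True)
print(scalar_head(feats).shape)  # torch.Size([2, 50, 1])

# Categorical head: 5 classes; at inference time the argmax is mapped back to a score.
cat_head = Projection(768, 64, nn.ReLU, "categorical", 5, output_step=1.0)
print(cat_head(feats).shape)                  # torch.Size([2, 50, 5])
print(cat_head(feats, inference=True).shape)  # torch.Size([2, 50])
```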

models/sslmos.py ADDED (@@ -0,0 +1,212 @@)
```python
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

# SSLMOS model
# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper)

import torch
import torch.nn as nn

from .modules import Projection


class SSLMOS(torch.nn.Module):
    def __init__(
        self,
        # NOTE: `model_input` and `num_domains` are accepted (but unused here)
        # so that the constructor matches the call in app.py.
        model_input: str,
        # model related
        ssl_module: str,
        s3prl_name: str,
        ssl_model_output_dim: int,
        ssl_model_layer_idx: int,
        # mean net related
        mean_net_dnn_dim: int = 64,
        mean_net_output_type: str = "scalar",
        mean_net_output_dim: int = 5,
        mean_net_output_step: float = 0.25,
        mean_net_range_clipping: bool = True,
        # listener related
        use_listener_modeling: bool = False,
        num_listeners: int = None,
        listener_emb_dim: int = None,
        use_mean_listener: bool = True,
        # domain related
        num_domains: int = None,
        # decoder related
        decoder_type: str = "ffn",
        decoder_dnn_dim: int = 64,
        output_type: str = "scalar",
        range_clipping: bool = True,
    ):
        super().__init__()  # this is needed, or else registering submodules will error
        self.use_mean_listener = use_mean_listener
        self.output_type = output_type

        # define listener embedding
        self.use_listener_modeling = use_listener_modeling

        # define ssl model
        if ssl_module == "s3prl":
            from s3prl.nn import S3PRLUpstream

            if s3prl_name in S3PRLUpstream.available_names():
                self.ssl_model = S3PRLUpstream(s3prl_name)
                self.ssl_model_layer_idx = ssl_model_layer_idx
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        # default uses ffn type mean net
        self.mean_net_dnn = Projection(
            ssl_model_output_dim,
            mean_net_dnn_dim,
            nn.ReLU,
            mean_net_output_type,
            mean_net_output_dim,
            mean_net_output_step,
            mean_net_range_clipping,
        )

        # listener modeling related
        self.use_listener_modeling = use_listener_modeling
        if use_listener_modeling:
            self.num_listeners = num_listeners
            self.listener_embeddings = nn.Embedding(
                num_embeddings=num_listeners, embedding_dim=listener_emb_dim
            )
            # define decoder
            self.decoder_type = decoder_type
            if decoder_type == "ffn":
                decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim
            else:
                raise NotImplementedError
            # there is always dnn
            self.decoder_dnn = Projection(
                decoder_dnn_input_dim,
                decoder_dnn_dim,
                nn.ReLU,  # NOTE: the original passed `self.activation`, which was never defined
                output_type,
                None,  # _output_dim: unused for the "scalar" output type
                range_clipping=range_clipping,
            )

    def get_num_params(self):
        return sum(p.numel() for n, p in self.named_parameters())

    def forward(self, inputs):
        """Calculate forward propagation.

        Args:
            waveform has shape (batch, time)
            waveform_lengths has shape (batch)
            listener_ids has shape (batch)
        """
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]

        batch, time = waveform.shape

        # get listener embedding
        if self.use_listener_modeling:
            listener_ids = inputs["listener_idxs"]
            # NOTE(unilight): not tested yet
            listener_embs = self.listener_embeddings(listener_ids)  # (batch, emb_dim)
            listener_embs = torch.stack(
                [listener_embs for i in range(time)], dim=1
            )  # (batch, time, feat_dim)

        # ssl model forward
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx]

        # inject listener embedding
        if self.use_listener_modeling:
            # NOTE(unilight): not tested yet
            encoder_outputs = encoder_outputs.view(
                (batch, time, -1)
            )  # (batch, time, feat_dim)
            decoder_inputs = torch.cat(
                [encoder_outputs, listener_embs], dim=-1
            )  # concat along feature dimension
        else:
            decoder_inputs = encoder_outputs

        # masked mean pooling
        # masks = make_non_pad_mask(encoder_outputs_lens)
        # masks = masks.unsqueeze(-1).to(decoder_inputs.device)  # [B, max_time, 1]
        # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1)

        # mean net
        mean_net_outputs = self.mean_net_dnn(
            decoder_inputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]

        # decoder
        if self.use_listener_modeling:
            if self.decoder_type == "rnn":
                decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs)
            else:
                decoder_outputs = decoder_inputs
            decoder_outputs = self.decoder_dnn(
                decoder_outputs
            )  # [batch, time, 1 (scalar) / 5 (categorical)]

        # set outputs
        # return lengths for masked loss calculation
        ret = {
            "waveform_lengths": waveform_lengths,
            "frame_lengths": encoder_outputs_lens,
        }

        # define scores
        ret["mean_scores"] = mean_net_outputs
        ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None

        return ret

    def mean_net_inference(self, inputs):
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]

        # ssl model forward
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]

        # mean net
        decoder_inputs = encoder_outputs
        mean_net_outputs = self.mean_net_dnn(
            decoder_inputs, inference=True
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        scores = torch.mean(mean_net_outputs, dim=1)  # [batch]

        return {
            "ssl_embeddings": encoder_outputs,
            "scores": scores,
        }

    def mean_net_inference_p1(self, waveform, waveform_lengths):
        # ssl model forward
        all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths)
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        return encoder_outputs

    def mean_net_inference_p2(self, encoder_outputs):
        # mean net
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        scores = torch.mean(mean_net_outputs, dim=1)

        return scores

    def get_ssl_embeddings(self, inputs):
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]

        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        return encoder_outputs
```
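
To show how the `inputs` dictionary contract lines up with `mean_net_inference`, here is a minimal sketch with dummy audio; the `s3prl_name`, output dimension, and layer index are assumptions for the example (in the demo they come from the downloaded `config.yml`):

```python
import torch

from models.sslmos import SSLMOS

model = SSLMOS(
    "waveform",                # model_input
    ssl_module="s3prl",
    s3prl_name="wav2vec2",     # assumed; must be in S3PRLUpstream.available_names()
    ssl_model_output_dim=768,  # assumed; hidden size of the chosen upstream
    ssl_model_layer_idx=-1,    # assumed; which layer's features to use
).eval()

wav = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
inputs = {"waveform": wav, "waveform_lengths": torch.LongTensor([wav.size(1)])}

with torch.no_grad():
    outputs = model.mean_net_inference(inputs)
print(outputs["scores"].shape)  # torch.Size([1]) -- one score per utterance
```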