unilight commited on
Commit
052c3ff
·
1 Parent(s): 781dc92
Files changed (3) hide show
  1. app.py +184 -0
  2. models/modules.py +59 -0
  3. models/sslmos.py +212 -0
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ from loguru import logger
7
+ import soundfile as sf
8
+ import librosa
9
+ import gradio as gr
10
+
11
+ from huggingface_hub import hf_hub_download
12
+ import time
13
+ import torch
14
+ import yaml
15
+
16
+ # from s3prl_vc.upstream.interface import get_upstream
17
+ # from s3prl.nn import Featurizer
18
+ # import s3prl_vc.models
19
+ # from s3prl_vc.utils import read_hdf5
20
+ # from s3prl_vc.vocoder import Vocoder
21
+
22
+
23
+ # ---------- Settings ----------
24
+ GPU_ID = '-1'
25
+ os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
26
+ DEVICE = 'cuda' if GPU_ID != '-1' else 'cpu'
27
+
28
+ SERVER_PORT = 42208
29
+ SERVER_NAME = "0.0.0.0"
30
+ SSL_DIR = './keyble_ssl'
31
+
32
+ FS = 16000
33
+ resamplers = {}
34
+ MIN_REQUIRED_WAV_LENGTH = 1040
35
+
36
+ # EXAMPLE_DIR = './examples'
37
+ # en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
38
+ # jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
39
+ # zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))
40
+
41
+ # TRGSPKS = ["TEF1", "TEF2", "TEM1", "TEM2"]
42
+
43
+ # ref_samples = {
44
+ # trgspk: sorted(glob(os.path.join("./ref_samples", trgspk, '*.wav')))
45
+ # for trgspk in TRGSPKS
46
+ # }
47
+
48
+ # ---------- Logging ----------
49
+ logger.add('app.log', mode='a')
50
+ logger.info('============================= App restarted =============================')
51
+
52
+ # ---------- Download models ----------
53
+ logger.info('============================= Download models ===========================')
54
+
55
+ model_paths = {
56
+ "SSL-MOS, all training sets": {
57
+ "ckpt": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl"),
58
+ "config": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml"),
59
+ }
60
+ }
61
+
62
+ # ---------- Model ----------
63
+ models = {}
64
+ for name, path_dict in model_paths.items():
65
+ logger.info(f'============================= Setting up model for {name} =============')
66
+ checkpoint_path = path_dict["ckpt"]
67
+ config_path = path_dict["config"]
68
+ with open(config_path) as f:
69
+ config = yaml.load(f, Loader=yaml.Loader)
70
+
71
+ if config["model_type"] == "SSLMOS":
72
+ from models.sslmos import SSLMOS
73
+ model = SSLMOS(
74
+ config["model_input"],
75
+ num_listeners=config.get("num_listeners", None),
76
+ num_domains=config.get("num_domains", None),
77
+ **config["model_params"],
78
+ ).to(DEVICE)
79
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
80
+ model = model.eval().to(DEVICE)
81
+ logger.info(f"Loaded model parameters from {checkpoint_path}.")
82
+
83
+ models[name] = model
84
+
85
+ def read_wav(wav_path):
86
+ # read waveform
87
+ waveform, sample_rate = torchaudio.load(
88
+ wav_path, channels_first=False
89
+ ) # waveform: [T, 1]
90
+
91
+ # resample if needed
92
+ if sample_rate != FS:
93
+ resampler_key = f"{sample_rate}-{FS}"
94
+ if resampler_key not in resamplers:
95
+ resamplers[resampler_key] = torchaudio.transforms.Resample(
96
+ sample_rate, FS, dtype=waveform.dtype
97
+ )
98
+ waveform = resamplers[resampler_key](waveform)
99
+
100
+ waveform = waveform.squeeze(-1)
101
+
102
+ # always pad to a minumum length
103
+ if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH:
104
+ to_pad = (MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]) // 2
105
+ waveform = F.pad(waveform, (to_pad, to_pad), "constant", 0)
106
+
107
+ return waveform, sample_rate
108
+
109
+ def predict(model_name, wav_file):
110
+ x, fs = read_wav(wav_file)
111
+ logger.info('wav file loaded')
112
+
113
+ # set up model input
114
+ model_input = x.unsqueeze(0).to(DEVICE)
115
+ model_lengths = model_input.new_tensor([model_input.size(1)]).long()
116
+ inputs = {
117
+ config["model_input"]: model_input,
118
+ config["model_input"] + "_lengths": model_lengths,
119
+ }
120
+
121
+ with torch.no_grad():
122
+ # model forward
123
+ if config["inference_mode"] == "mean_listener":
124
+ outputs = models[model_name].mean_listener_inference(inputs)
125
+ elif config["inference_mode"] == "mean_net":
126
+ outputs = models[model_name].mean_net_inference(inputs)
127
+
128
+ pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]
129
+
130
+ return pred_mean_scores
131
+
132
+ with gr.Blocks(title="S3PRL-VC: Any-to-one voice conversion demo on VCC2020") as demo:
133
+ gr.Markdown(
134
+ """
135
+ # S3PRL-VC: Any-to-one voice conversion demo on VCC2020
136
+ ### [[Paper (ICASSP2023)]](https://arxiv.org/abs/2110.06280) [[Paper(JSTSP)]](https://arxiv.org/abs/2207.04356) [[Code]](https://github.com/unilight/s3prl-vc)
137
+ **S3PRL-VC** is a voice conversion (VC) toolkit for benchmarking self-supervised speech representations (S3Rs). The term **any-to-one** means that the system can convert from any unseen speaker to a pre-defined speaker given in training.
138
+ In this demo, you can record your voice, and the model will convert your voice to one of the four pre-defined speakers. These four speakers come from the **voice conversion challenge (VCC) 2020**. You can listen to the samples to get a sense of what these speakers sound like.
139
+ The **RTF** of the system is around **1.5~2.5**, i.e. if you recorded a 5 second long audio, it will take 5 * (1.5~2.5) = 7.5~12.5 seconds to generate the output.
140
+ """
141
+ )
142
+
143
+ with gr.Row():
144
+ with gr.Column():
145
+ gr.Markdown("## Record your speech here!")
146
+ input_wav = gr.Audio(label="Input speech", source='microphone', type='filepath')
147
+
148
+ gr.Markdown("## Select a model!")
149
+ model_name = gr.Radio(label="Model", choices=list(model_paths.keys()))
150
+
151
+ evaluate_btn = gr.Button(value="Evaluate!")
152
+ # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
153
+ # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.")
154
+ # gr.Examples(
155
+ # examples=en_examples,
156
+ # inputs=input_wav,
157
+ # label="English examples"
158
+ # )
159
+ # gr.Examples(
160
+ # examples=jp_examples,
161
+ # inputs=input_wav,
162
+ # label="Japanese examples"
163
+ # )
164
+ # gr.Examples(
165
+ # examples=zh_examples,
166
+ # inputs=input_wav,
167
+ # label="Mandarin examples"
168
+ # )
169
+
170
+ with gr.Column():
171
+ gr.Markdown("## The predicted scores is here:")
172
+ output_score = gr.Textbox(label="Prediction", interactive=False)
173
+ evaluate_btn.click(predict, [model_name, input_wav], output_score)
174
+
175
+ if __name__ == '__main__':
176
+ try:
177
+ demo.launch(debug=True,
178
+ enable_queue=True,
179
+ )
180
+ except KeyboardInterrupt as e:
181
+ print(e)
182
+
183
+ finally:
184
+ demo.close()
models/modules.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2024 Wen-Chin Huang
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ # LDNet modules
7
+ # taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself)
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ STRIDE = 3
13
+
14
+ class Projection(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_dim,
18
+ hidden_dim,
19
+ activation,
20
+ output_type,
21
+ _output_dim,
22
+ output_step=1.0,
23
+ range_clipping=False,
24
+ ):
25
+ super(Projection, self).__init__()
26
+ self.output_type = output_type
27
+ self.range_clipping = range_clipping
28
+ if output_type == "scalar":
29
+ output_dim = 1
30
+ if range_clipping:
31
+ self.proj = nn.Tanh()
32
+ elif output_type == "categorical":
33
+ output_dim = _output_dim
34
+ self.output_step = output_step
35
+ else:
36
+ raise NotImplementedError("wrong output_type: {}".format(output_type))
37
+
38
+ self.net = nn.Sequential(
39
+ nn.Linear(in_dim, hidden_dim),
40
+ activation(),
41
+ nn.Dropout(0.3),
42
+ nn.Linear(hidden_dim, output_dim),
43
+ )
44
+
45
+ def forward(self, x, inference=False):
46
+ output = self.net(x)
47
+
48
+ # scalar / categorical
49
+ if self.output_type == "scalar":
50
+ # range clipping
51
+ if self.range_clipping:
52
+ return self.proj(output) * 2.0 + 3
53
+ else:
54
+ return output
55
+ else:
56
+ if inference:
57
+ return torch.argmax(output, dim=-1) * self.output_step + 1
58
+ else:
59
+ return output
models/sslmos.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2024 Wen-Chin Huang
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ # SSLMOS model
7
+ # modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper)
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from .modules import Projection
12
+
13
+
14
+ class SSLMOS(torch.nn.Module):
15
+ def __init__(
16
+ self,
17
+ # model related
18
+ ssl_module: str,
19
+ s3prl_name: str,
20
+ ssl_model_output_dim: int,
21
+ ssl_model_layer_idx: int,
22
+ # mean net related
23
+ mean_net_dnn_dim: int = 64,
24
+ mean_net_output_type: str = "scalar",
25
+ mean_net_output_dim: int = 5,
26
+ mean_net_output_step: float = 0.25,
27
+ mean_net_range_clipping: bool = True,
28
+ # listener related
29
+ use_listener_modeling: bool = False,
30
+ num_listeners: int = None,
31
+ listener_emb_dim: int = None,
32
+ use_mean_listener: bool = True,
33
+ # decoder related
34
+ decoder_type: str = "ffn",
35
+ decoder_dnn_dim: int = 64,
36
+ output_type: str = "scalar",
37
+ range_clipping: bool = True,
38
+ ):
39
+ super().__init__() # this is needed! or else there will be an error.
40
+ self.use_mean_listener = use_mean_listener
41
+ self.output_type = output_type
42
+
43
+ # define listener embedding
44
+ self.use_listener_modeling = use_listener_modeling
45
+
46
+ # define ssl model
47
+ if ssl_module == "s3prl":
48
+ from s3prl.nn import S3PRLUpstream
49
+
50
+ if s3prl_name in S3PRLUpstream.available_names():
51
+ self.ssl_model = S3PRLUpstream(s3prl_name)
52
+ self.ssl_model_layer_idx = ssl_model_layer_idx
53
+ else:
54
+ raise NotImplementedError
55
+
56
+ # default uses ffn type mean net
57
+ self.mean_net_dnn = Projection(
58
+ ssl_model_output_dim,
59
+ mean_net_dnn_dim,
60
+ nn.ReLU,
61
+ mean_net_output_type,
62
+ mean_net_output_dim,
63
+ mean_net_output_step,
64
+ mean_net_range_clipping,
65
+ )
66
+
67
+ # listener modeling related
68
+ self.use_listener_modeling = use_listener_modeling
69
+ if use_listener_modeling:
70
+ self.num_listeners = num_listeners
71
+ self.listener_embeddings = nn.Embedding(
72
+ num_embeddings=num_listeners, embedding_dim=listener_emb_dim
73
+ )
74
+ # define decoder
75
+ self.decoder_type = decoder_type
76
+ if decoder_type == "ffn":
77
+ decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim
78
+ else:
79
+ raise NotImplementedError
80
+ # there is always dnn
81
+ self.decoder_dnn = Projection(
82
+ decoder_dnn_input_dim,
83
+ decoder_dnn_dim,
84
+ self.activation,
85
+ output_type,
86
+ range_clipping,
87
+ )
88
+
89
+ def get_num_params(self):
90
+ return sum(p.numel() for n, p in self.named_parameters())
91
+
92
+ def forward(self, inputs):
93
+ """Calculate forward propagation.
94
+ Args:
95
+ waveform has shape (batch, time)
96
+ waveform_lengths has shape (batch)
97
+ listener_ids has shape (batch)
98
+ """
99
+ waveform = inputs["waveform"]
100
+ waveform_lengths = inputs["waveform_lengths"]
101
+
102
+ batch, time = waveform.shape
103
+
104
+ # get listener embedding
105
+ if self.use_listener_modeling:
106
+ listener_ids = inputs["listener_idxs"]
107
+ # NOTE(unlight): not tested yet
108
+ listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim)
109
+ listener_embs = torch.stack(
110
+ [listener_embs for i in range(time)], dim=1
111
+ ) # (batch, time, feat_dim)
112
+
113
+ # ssl model forward
114
+ all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
115
+ waveform, waveform_lengths
116
+ )
117
+ encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
118
+ encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx]
119
+
120
+ # inject listener embedding
121
+ if self.use_listener_modeling:
122
+ # NOTE(unlight): not tested yet
123
+ encoder_outputs = encoder_outputs.view(
124
+ (batch, time, -1)
125
+ ) # (batch, time, feat_dim)
126
+ decoder_inputs = torch.cat(
127
+ [encoder_outputs, listener_embs], dim=-1
128
+ ) # concat along feature dimension
129
+ else:
130
+ decoder_inputs = encoder_outputs
131
+
132
+ # masked mean pooling
133
+ # masks = make_non_pad_mask(encoder_outputs_lens)
134
+ # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1]
135
+ # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1)
136
+
137
+ # mean net
138
+ mean_net_outputs = self.mean_net_dnn(
139
+ decoder_inputs
140
+ ) # [batch, time, 1 (scalar) / 5 (categorical)]
141
+
142
+ # decoder
143
+ if self.use_listener_modeling:
144
+ if self.decoder_type == "rnn":
145
+ decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs)
146
+ else:
147
+ decoder_outputs = decoder_inputs
148
+ decoder_outputs = self.decoder_dnn(
149
+ decoder_outputs
150
+ ) # [batch, time, 1 (scalar) / 5 (categorical)]
151
+
152
+ # set outputs
153
+ # return lengths for masked loss calculation
154
+ ret = {
155
+ "waveform_lengths": waveform_lengths,
156
+ "frame_lengths": encoder_outputs_lens,
157
+ }
158
+
159
+ # define scores
160
+ ret["mean_scores"] = mean_net_outputs
161
+ ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None
162
+
163
+ return ret
164
+
165
+ def mean_net_inference(self, inputs):
166
+ waveform = inputs["waveform"]
167
+ waveform_lengths = inputs["waveform_lengths"]
168
+
169
+ # ssl model forward
170
+ all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
171
+ waveform, waveform_lengths
172
+ )
173
+ encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
174
+
175
+ # mean net
176
+ decoder_inputs = encoder_outputs
177
+ mean_net_outputs = self.mean_net_dnn(
178
+ decoder_inputs, inference=True
179
+ ) # [batch, time, 1 (scalar) / 5 (categorical)]
180
+ mean_net_outputs = mean_net_outputs.squeeze(-1)
181
+ scores = torch.mean(mean_net_outputs, dim=1) # [batch]
182
+
183
+ return {
184
+ "ssl_embeddings": encoder_outputs,
185
+ "scores": scores
186
+ }
187
+
188
+ def mean_net_inference_p1(self, waveform, waveform_lengths):
189
+ # ssl model forward
190
+ all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths)
191
+ encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
192
+ return encoder_outputs
193
+
194
+ def mean_net_inference_p2(self, encoder_outputs):
195
+ # mean net
196
+ mean_net_outputs = self.mean_net_dnn(
197
+ encoder_outputs
198
+ ) # [batch, time, 1 (scalar) / 5 (categorical)]
199
+ mean_net_outputs = mean_net_outputs.squeeze(-1)
200
+ scores = torch.mean(mean_net_outputs, dim=1)
201
+
202
+ return scores
203
+
204
+ def get_ssl_embeddings(self, inputs):
205
+ waveform = inputs["waveform"]
206
+ waveform_lengths = inputs["waveform_lengths"]
207
+
208
+ all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
209
+ waveform, waveform_lengths
210
+ )
211
+ encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
212
+ return encoder_outputs