KevinGeng committed on
Commit b33c328
1 Parent(s): 76f50e9

test python files

Files changed (5)
  1. app.py +77 -0
  2. app_record.py +65 -0
  3. app_record_streaming.py +63 -0
  4. lightning_module.py +41 -0
  5. model.py +191 -0
app.py ADDED
@@ -0,0 +1,77 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer

# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

# WER part
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linear interpolation between neighbouring samples
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)

    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)

    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3

    return predic_mos, trans, wer

description = """
MOS prediction demo using the UTMOS-strong model (w/o phoneme encoder), trained on the main-track dataset.
This demo only accepts .wav format and works best at a 16 kHz sampling rate.

The paper is available [here](https://arxiv.org/abs/2204.02152).
"""

iface = gr.Interface(
    fn=calc_mos,
    inputs=[gr.Audio(type='filepath'), gr.Textbox(placeholder="Insert reference here", label="Reference")],
    outputs=[gr.Textbox(label="Predicted MOS"), gr.Textbox(label="Hypothesis"), gr.Textbox(label="WER")],
    title="UTMOS Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()
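
For quick local testing outside the Gradio UI, a minimal sketch could call calc_mos directly; it assumes the checkpoint above is present, and "sample.wav" plus the reference string are placeholders.

# Hypothetical smoke test; "sample.wav" and the reference are placeholders.
mos, hypothesis, wer = calc_mos("sample.wav", "the quick brown fox jumps over the lazy dog")
print(f"Predicted MOS: {float(mos):.2f}")   # frame-level scores averaged, rescaled to the 1-5 MOS range
print(f"ASR hypothesis: {hypothesis}")
print(f"WER vs. reference: {wer:.2%}")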
app_record.py ADDED
@@ -0,0 +1,65 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb

# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)

    transcription = p(audio_path)["text"]
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    return output.mean(dim=1).squeeze().detach().numpy() * 2 + 3, transcription


description = """
MOS prediction demo using the UTMOS-strong model (w/o phoneme encoder), trained on the main-track dataset.
This demo only accepts .wav format and works best at a 16 kHz sampling rate.

The paper is available [here](https://arxiv.org/abs/2204.02152).
"""

# inputs=gr.inputs.Audio(type='filepath'),
iface = gr.Interface(
    fn=calc_mos,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs=["text", "textbox"],
    title="UTMOS Demo",
    description=description,
    allow_flagging="auto",  # current Gradio expects "never"/"auto"/"manual" rather than a boolean
).launch()
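
ChangeSampleRate above is a simple linear-interpolation resampler; as a sanity check it can be compared against torchaudio's built-in windowed-sinc resampler. A sketch, assuming a placeholder 1-channel 44.1 kHz input:

import torch
import torchaudio

wav = torch.randn(1, 44_100)                     # 1 s of placeholder audio at 44.1 kHz
csr_out = ChangeSampleRate(44_100, 16_000)(wav)  # linear interpolation
ta_out = torchaudio.transforms.Resample(orig_freq=44_100, new_freq=16_000)(wav)
print(csr_out.shape, ta_out.shape)               # both resample to [1, 16000]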
app_record_streaming.py ADDED
@@ -0,0 +1,63 @@
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb

# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    transcription = p(audio_path)["text"]
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    return output.mean(dim=1).squeeze().detach().numpy() * 2 + 3, transcription


description = """
MOS prediction demo using the UTMOS-strong model (w/o phoneme encoder), trained on the main-track dataset.
This demo only accepts .wav format and works best at a 16 kHz sampling rate.

The paper is available [here](https://arxiv.org/abs/2204.02152).
"""

# inputs=gr.inputs.Audio(type='filepath'),
iface = gr.Interface(
    fn=calc_mos,
    inputs=gr.Audio(source="microphone", type="filepath", streaming=True),
    outputs=["text", "textbox"],
    title="UTMOS Demo",
    description=description,
    allow_flagging="never",  # current Gradio expects "never"/"auto"/"manual" rather than a boolean
).launch()
lightning_module.py ADDED
@@ -0,0 +1,41 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra
from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ])
        output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
        )

        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
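
For reference, a minimal sketch of the batch dictionary BaselineLightningModule expects and how the demo apps turn its frame-level output into a single MOS value; the checkpoint name matches the one used above, and the waveform is a random placeholder.

import torch
import lightning_module

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
batch = {
    'wav': torch.randn(1, 16_000),     # [batch, samples]: 1 s of placeholder 16 kHz audio
    'domains': torch.tensor([0]),      # domain id used by the demo apps
    'judge_id': torch.tensor([288]),   # listener id used by the demo apps
}
with torch.no_grad():
    frame_scores = model(batch)        # [batch, frames, 1]
mos = float(frame_scores.mean(dim=1).squeeze()) * 2 + 3   # rescale to the 1-5 MOS range
print(mos)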
model.py ADDED
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
import fairseq
import os
import hydra


def load_ssl_model(cp_path):
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        # NOTE: this branch assumes WavLM and WavLMConfig are importable from a local
        # WavLM implementation; only the fairseq wav2vec branch is exercised by this demo.
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
            exit()
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)


class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super(SSL_model, self).__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim


class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim, n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError("reference_seq and reference_lens should not be None when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            # Encode the reference sequence (the original code re-encoded `emb` here).
            _, (ht_ref, _) = self.encoder(reference_emb)
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim


class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}

    def get_output_dim(self):
        return self.output_dim


class LDConditioner(nn.Module):
    '''
    Conditions the SSL output on a listener (judge) embedding.
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]

        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        if 'phoneme-feature' in x.keys():
            concatenated_feature = torch.cat(
                (x['ssl-feature'],
                 x['phoneme-feature'].unsqueeze(1).expand(-1, x['ssl-feature'].size(1), -1)),
                dim=2)
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output


class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)

        # range clipping
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim
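
To illustrate how the feature dictionaries flow between the modules above, a small sketch with random stand-ins for the wav2vec features (2 utterances, 50 frames, 768 dims, matching wav2vec_small); the dimensions mirror those used in lightning_module.py.

import torch
from model import DomainEmbedding, LDConditioner, Projection

features = {
    "ssl-feature": torch.randn(2, 50, 768),   # stand-in for SSL_model output
    "domain-feature": DomainEmbedding(3, 128)({'domains': torch.tensor([0, 0])})["domain-feature"],
}
batch = {'judge_id': torch.tensor([288, 288])}

conditioner = LDConditioner(input_dim=768 + 128, judge_dim=128, num_judges=3000)
projection = Projection(input_dim=conditioner.get_output_dim(), hidden_dim=2048,
                        activation=torch.nn.ReLU(), range_clipping=False)

x = conditioner(features, batch)   # [2, 50, 1024]: bidirectional LSTM, 512 units per direction
y = projection(x, batch)           # [2, 50, 1]: frame-level quality scores
print(y.shape)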