Plachta commited on
Commit
452888f
1 Parent(s): e2f2ce0

Update ONNXVITS_infer.py

Browse files
Files changed (1) hide show
  1. ONNXVITS_infer.py +81 -130
ONNXVITS_infer.py CHANGED
@@ -13,17 +13,18 @@ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
13
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
  from commons import init_weights, get_padding
15
 
 
16
  class TextEncoder(nn.Module):
17
  def __init__(self,
18
- n_vocab,
19
- out_channels,
20
- hidden_channels,
21
- filter_channels,
22
- n_heads,
23
- n_layers,
24
- kernel_size,
25
- p_dropout,
26
- emotion_embedding):
27
  super().__init__()
28
  self.n_vocab = n_vocab
29
  self.out_channels = out_channels
@@ -34,12 +35,12 @@ class TextEncoder(nn.Module):
34
  self.kernel_size = kernel_size
35
  self.p_dropout = p_dropout
36
  self.emotion_embedding = emotion_embedding
37
-
38
- if self.n_vocab!=0:
39
  self.emb = nn.Embedding(n_vocab, hidden_channels)
40
  if emotion_embedding:
41
  self.emo_proj = nn.Linear(1024, hidden_channels)
42
- nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
43
 
44
  self.encoder = attentions.Encoder(
45
  hidden_channels,
@@ -48,15 +49,15 @@ class TextEncoder(nn.Module):
48
  n_layers,
49
  kernel_size,
50
  p_dropout)
51
- self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
52
 
53
  def forward(self, x, x_lengths, emotion_embedding=None):
54
- if self.n_vocab!=0:
55
- x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
56
  if emotion_embedding is not None:
57
  print("emotion added")
58
  x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
59
- x = torch.transpose(x, 1, -1) # [b, h, t]
60
  x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
61
 
62
  x = self.encoder(x * x_mask, x_mask)
@@ -65,15 +66,16 @@ class TextEncoder(nn.Module):
65
  m, logs = torch.split(stats, self.out_channels, dim=1)
66
  return x, m, logs, x_mask
67
 
 
68
  class PosteriorEncoder(nn.Module):
69
  def __init__(self,
70
- in_channels,
71
- out_channels,
72
- hidden_channels,
73
- kernel_size,
74
- dilation_rate,
75
- n_layers,
76
- gin_channels=0):
77
  super().__init__()
78
  self.in_channels = in_channels
79
  self.out_channels = out_channels
@@ -96,35 +98,36 @@ class PosteriorEncoder(nn.Module):
96
  z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
97
  return z, m, logs, x_mask
98
 
 
99
  class SynthesizerTrn(models.SynthesizerTrn):
100
  """
101
  Synthesizer for Training
102
  """
103
 
104
- def __init__(self,
105
- n_vocab,
106
- spec_channels,
107
- segment_size,
108
- inter_channels,
109
- hidden_channels,
110
- filter_channels,
111
- n_heads,
112
- n_layers,
113
- kernel_size,
114
- p_dropout,
115
- resblock,
116
- resblock_kernel_sizes,
117
- resblock_dilation_sizes,
118
- upsample_rates,
119
- upsample_initial_channel,
120
- upsample_kernel_sizes,
121
- n_speakers=0,
122
- gin_channels=0,
123
- use_sdp=True,
124
- emotion_embedding=False,
125
- **kwargs):
126
-
127
- super().__init__(
128
  n_vocab,
129
  spec_channels,
130
  segment_size,
@@ -135,11 +138,11 @@ class SynthesizerTrn(models.SynthesizerTrn):
135
  n_layers,
136
  kernel_size,
137
  p_dropout,
138
- resblock,
139
- resblock_kernel_sizes,
140
- resblock_dilation_sizes,
141
- upsample_rates,
142
- upsample_initial_channel,
143
  upsample_kernel_sizes,
144
  n_speakers=n_speakers,
145
  gin_channels=gin_channels,
@@ -147,27 +150,28 @@ class SynthesizerTrn(models.SynthesizerTrn):
147
  **kwargs
148
  )
149
  self.enc_p = TextEncoder(n_vocab,
150
- inter_channels,
151
- hidden_channels,
152
- filter_channels,
153
- n_heads,
154
- n_layers,
155
- kernel_size,
156
- p_dropout,
157
- emotion_embedding)
158
  self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
159
 
160
- def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None):
 
161
  from ONNXVITS_utils import runonnx
162
-
163
- x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
164
 
165
  if self.n_speakers > 0:
166
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
167
  else:
168
  g = None
169
 
170
- #logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
171
  logw = runonnx("ONNX_net/dp.onnx", x=x.numpy(), x_mask=x_mask.numpy(), g=g.numpy())
172
  logw = torch.from_numpy(logw[0])
173
 
@@ -178,26 +182,27 @@ class SynthesizerTrn(models.SynthesizerTrn):
178
  attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
179
  attn = commons.generate_path(w_ceil, attn_mask)
180
 
181
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
182
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
 
183
 
184
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
185
-
186
- #z = self.flow(z_p, y_mask, g=g, reverse=True)
187
  z = runonnx("ONNX_net/flow.onnx", z_p=z_p.numpy(), y_mask=y_mask.numpy(), g=g.numpy())
188
  z = torch.from_numpy(z[0])
189
 
190
- #o = self.dec((z * y_mask)[:,:,:max_len], g=g)
191
- o = runonnx("ONNX_net/dec.onnx", z_in=(z * y_mask)[:,:,:max_len].numpy(), g=g.numpy())
192
  o = torch.from_numpy(o[0])
193
 
194
  return o, attn, y_mask, (z, z_p, m_p, logs_p)
195
 
196
  def predict_duration(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
197
- emotion_embedding=None):
198
  from ONNXVITS_utils import runonnx
199
 
200
- #x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
201
  x, m_p, logs_p, x_mask = runonnx("ONNX_net/enc_p.onnx", x=x.numpy(), x_lengths=x_lengths.numpy())
202
  x = torch.from_numpy(x)
203
  m_p = torch.from_numpy(m_p)
@@ -205,68 +210,14 @@ class SynthesizerTrn(models.SynthesizerTrn):
205
  x_mask = torch.from_numpy(x_mask)
206
 
207
  if self.n_speakers > 0:
208
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
209
  else:
210
  g = None
211
 
212
- #logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
213
  logw = runonnx("ONNX_net/dp.onnx", x=x.numpy(), x_mask=x_mask.numpy(), g=g.numpy())
214
  logw = torch.from_numpy(logw[0])
215
 
216
  w = torch.exp(logw) * x_mask * length_scale
217
  w_ceil = torch.ceil(w)
218
- return list(w_ceil.squeeze())
219
-
220
- def infer_with_duration(self, x, x_lengths, w_ceil, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
221
- emotion_embedding=None):
222
- from ONNXVITS_utils import runonnx
223
-
224
- #x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
225
- x, m_p, logs_p, x_mask = runonnx("ONNX_net/enc_p.onnx", x=x.numpy(), x_lengths=x_lengths.numpy())
226
- x = torch.from_numpy(x)
227
- m_p = torch.from_numpy(m_p)
228
- logs_p = torch.from_numpy(logs_p)
229
- x_mask = torch.from_numpy(x_mask)
230
-
231
- if self.n_speakers > 0:
232
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
233
- else:
234
- g = None
235
- assert len(w_ceil) == x.shape[2]
236
- w_ceil = torch.FloatTensor(w_ceil).reshape(1, 1, -1)
237
- y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
238
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
239
- attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
240
- attn = commons.generate_path(w_ceil, attn_mask)
241
-
242
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
243
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
244
-
245
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
246
-
247
- #z = self.flow(z_p, y_mask, g=g, reverse=True)
248
- z = runonnx("ONNX_net/flow.onnx", z_p=z_p.numpy(), y_mask=y_mask.numpy(), g=g.numpy())
249
- z = torch.from_numpy(z[0])
250
-
251
- #o = self.dec((z * y_mask)[:,:,:max_len], g=g)
252
- o = runonnx("ONNX_net/dec.onnx", z_in=(z * y_mask)[:,:,:max_len].numpy(), g=g.numpy())
253
- o = torch.from_numpy(o[0])
254
-
255
- return o, attn, y_mask, (z, z_p, m_p, logs_p)
256
-
257
- def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
258
- from ONNXVITS_utils import runonnx
259
- assert self.n_speakers > 0, "n_speakers have to be larger than 0."
260
- g_src = self.emb_g(sid_src).unsqueeze(-1)
261
- g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
262
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
263
- # z_p = self.flow(z, y_mask, g=g_src)
264
- z_p = runonnx("ONNX_net/flow.onnx", z_p=z.numpy(), y_mask=y_mask.numpy(), g=g_src.numpy())
265
- z_p = torch.from_numpy(z_p[0])
266
- # z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
267
- z_hat = runonnx("ONNX_net/flow.onnx", z_p=z_p.numpy(), y_mask=y_mask.numpy(), g=g_tgt.numpy())
268
- z_hat = torch.from_numpy(z_hat[0])
269
- # o_hat = self.dec(z_hat * y_mask, g=g_tgt)
270
- o_hat = runonnx("ONNX_net/dec.onnx", z_in=(z_hat * y_mask).numpy(), g=g_tgt.numpy())
271
- o_hat = torch.from_numpy(o_hat[0])
272
- return o_hat, y_mask, (z, z_p, z_hat)
 
13
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
  from commons import init_weights, get_padding
15
 
16
+
17
  class TextEncoder(nn.Module):
18
  def __init__(self,
19
+ n_vocab,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ emotion_embedding):
28
  super().__init__()
29
  self.n_vocab = n_vocab
30
  self.out_channels = out_channels
 
35
  self.kernel_size = kernel_size
36
  self.p_dropout = p_dropout
37
  self.emotion_embedding = emotion_embedding
38
+
39
+ if self.n_vocab != 0:
40
  self.emb = nn.Embedding(n_vocab, hidden_channels)
41
  if emotion_embedding:
42
  self.emo_proj = nn.Linear(1024, hidden_channels)
43
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
44
 
45
  self.encoder = attentions.Encoder(
46
  hidden_channels,
 
49
  n_layers,
50
  kernel_size,
51
  p_dropout)
52
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
53
 
54
  def forward(self, x, x_lengths, emotion_embedding=None):
55
+ if self.n_vocab != 0:
56
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
57
  if emotion_embedding is not None:
58
  print("emotion added")
59
  x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
60
+ x = torch.transpose(x, 1, -1) # [b, h, t]
61
  x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
62
 
63
  x = self.encoder(x * x_mask, x_mask)
 
66
  m, logs = torch.split(stats, self.out_channels, dim=1)
67
  return x, m, logs, x_mask
68
 
69
+
70
  class PosteriorEncoder(nn.Module):
71
  def __init__(self,
72
+ in_channels,
73
+ out_channels,
74
+ hidden_channels,
75
+ kernel_size,
76
+ dilation_rate,
77
+ n_layers,
78
+ gin_channels=0):
79
  super().__init__()
80
  self.in_channels = in_channels
81
  self.out_channels = out_channels
 
98
  z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
99
  return z, m, logs, x_mask
100
 
101
+
102
  class SynthesizerTrn(models.SynthesizerTrn):
103
  """
104
  Synthesizer for Training
105
  """
106
 
107
+ def __init__(self,
108
+ n_vocab,
109
+ spec_channels,
110
+ segment_size,
111
+ inter_channels,
112
+ hidden_channels,
113
+ filter_channels,
114
+ n_heads,
115
+ n_layers,
116
+ kernel_size,
117
+ p_dropout,
118
+ resblock,
119
+ resblock_kernel_sizes,
120
+ resblock_dilation_sizes,
121
+ upsample_rates,
122
+ upsample_initial_channel,
123
+ upsample_kernel_sizes,
124
+ n_speakers=0,
125
+ gin_channels=0,
126
+ use_sdp=True,
127
+ emotion_embedding=False,
128
+ **kwargs):
129
+
130
+ super().__init__(
131
  n_vocab,
132
  spec_channels,
133
  segment_size,
 
138
  n_layers,
139
  kernel_size,
140
  p_dropout,
141
+ resblock,
142
+ resblock_kernel_sizes,
143
+ resblock_dilation_sizes,
144
+ upsample_rates,
145
+ upsample_initial_channel,
146
  upsample_kernel_sizes,
147
  n_speakers=n_speakers,
148
  gin_channels=gin_channels,
 
150
  **kwargs
151
  )
152
  self.enc_p = TextEncoder(n_vocab,
153
+ inter_channels,
154
+ hidden_channels,
155
+ filter_channels,
156
+ n_heads,
157
+ n_layers,
158
+ kernel_size,
159
+ p_dropout,
160
+ emotion_embedding)
161
  self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
162
 
163
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
164
+ emotion_embedding=None):
165
  from ONNXVITS_utils import runonnx
166
+ with torch.no_grad():
167
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
168
 
169
  if self.n_speakers > 0:
170
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
171
  else:
172
  g = None
173
 
174
+ # logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
175
  logw = runonnx("ONNX_net/dp.onnx", x=x.numpy(), x_mask=x_mask.numpy(), g=g.numpy())
176
  logw = torch.from_numpy(logw[0])
177
 
 
182
  attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
183
  attn = commons.generate_path(w_ceil, attn_mask)
184
 
185
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
186
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1,
187
+ 2) # [b, t', t], [b, t, d] -> [b, d, t']
188
 
189
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
190
+
191
+ # z = self.flow(z_p, y_mask, g=g, reverse=True)
192
  z = runonnx("ONNX_net/flow.onnx", z_p=z_p.numpy(), y_mask=y_mask.numpy(), g=g.numpy())
193
  z = torch.from_numpy(z[0])
194
 
195
+ # o = self.dec((z * y_mask)[:,:,:max_len], g=g)
196
+ o = runonnx("ONNX_net/dec.onnx", z_in=(z * y_mask)[:, :, :max_len].numpy(), g=g.numpy())
197
  o = torch.from_numpy(o[0])
198
 
199
  return o, attn, y_mask, (z, z_p, m_p, logs_p)
200
 
201
  def predict_duration(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
202
+ emotion_embedding=None):
203
  from ONNXVITS_utils import runonnx
204
 
205
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
206
  x, m_p, logs_p, x_mask = runonnx("ONNX_net/enc_p.onnx", x=x.numpy(), x_lengths=x_lengths.numpy())
207
  x = torch.from_numpy(x)
208
  m_p = torch.from_numpy(m_p)
 
210
  x_mask = torch.from_numpy(x_mask)
211
 
212
  if self.n_speakers > 0:
213
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
214
  else:
215
  g = None
216
 
217
+ # logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
218
  logw = runonnx("ONNX_net/dp.onnx", x=x.numpy(), x_mask=x_mask.numpy(), g=g.numpy())
219
  logw = torch.from_numpy(logw[0])
220
 
221
  w = torch.exp(logw) * x_mask * length_scale
222
  w_ceil = torch.ceil(w)
223
+ return list(w_ceil.squeeze())