mrfakename committed
Commit 5cf7b18
1 Parent(s): 6eb9ea3

[Experimental] Gruut support

Files changed (4)
  1. app.py +4 -3
  2. gruut_phonemize.py +10 -0
  3. requirements.txt +2 -1
  4. styletts2importable.py +72 -58
app.py CHANGED
@@ -16,13 +16,13 @@ voices = {}
  # else:
  for v in voicelist:
      voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
- def synthesize(text, voice):
+ def synthesize(text, voice, use_gruut):
      if text.strip() == "":
          raise gr.Error("You must enter some text")
      if len(text) > 300:
          raise gr.Error("Text must be under 300 characters")
      v = voice.lower()
-     return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
+     return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1, use_gruut=use_gruut))
  def clsynthesize(text, voice):
      if text.strip() == "":
          raise gr.Error("You must enter some text")
@@ -43,10 +43,11 @@ with gr.Blocks() as vctk:
          with gr.Column(scale=1):
              inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
              voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
+             use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
          with gr.Column(scale=1):
              btn = gr.Button("Synthesize", variant="primary")
              audio = gr.Audio(interactive=False, label="Synthesized Audio")
-     btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
+     btn.click(synthesize, inputs=[inp, voice, use_gruut], outputs=[audio], concurrency_limit=4)
  with gr.Blocks() as clone:
      with gr.Row():
          with gr.Column(scale=1):
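For reference, a minimal self-contained sketch (not part of this commit; component choices and voice names are illustrative) of the same wiring pattern: the gr.Checkbox value is passed to the click handler as an extra positional argument, exactly as the new use_gruut flag is above.

    import gradio as gr

    def synthesize(text, voice, use_gruut):
        # The real handler calls styletts2importable.inference(...); this stub just echoes its inputs.
        return f"voice={voice}, gruut={use_gruut}: {text}"

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Text")
        voice = gr.Dropdown(["m-us-1", "f-us-1"], value="m-us-1", label="Voice")
        use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
        btn = gr.Button("Synthesize")
        out = gr.Textbox(label="Result")
        btn.click(synthesize, inputs=[inp, voice, use_gruut], outputs=[out])

    demo.launch()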
gruut_phonemize.py ADDED
@@ -0,0 +1,10 @@
+ from gruut import sentences
+
+
+ def gphonemize(text):
+     phonemes = ''
+     for sent in sentences(text, lang="en-us"):
+         for word in sent:
+             if word.phonemes:
+                 phonemes += ''.join(word.phonemes)
+     return phonemes
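The new helper walks gruut's sentence and word structure and concatenates each word's phonemes into one string with no separators. A small usage sketch (not part of the commit; the exact phoneme output depends on the installed gruut version and its en-us lexicon):

    from gruut_phonemize import gphonemize

    # Prints a single unbroken phoneme string for the whole sentence
    # (no spaces between words, since word phonemes are joined with '').
    print(gphonemize("Hello world."))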
requirements.txt CHANGED
@@ -18,4 +18,5 @@ git+https://github.com/resemble-ai/monotonic_align.git
  scipy
  phonemizer
  cached-path
- gradio
+ gradio
+ gruut
styletts2importable.py CHANGED
@@ -1,4 +1,6 @@
  from cached_path import cached_path
+ print("GRUUT")
+ from gruut_phonemize import gphonemize

  # from dp.phonemizer import Phonemizer
  print("NLTK")
@@ -131,9 +133,12 @@ sampler = DiffusionSampler(
      clamp=False
  )

- def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
+ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
      text = text.strip()
-     ps = global_phonemizer.phonemize([text])
+     if use_gruut:
+         ps = gphonemize(text)
+     else:
+         ps = global_phonemizer.phonemize([text])
      ps = word_tokenize(ps[0])
      ps = ' '.join(ps)
      tokens = textclenaer(ps)
@@ -200,86 +205,92 @@ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding

      return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later

- def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
-     text = text.strip()
-     ps = global_phonemizer.phonemize([text])
-     ps = word_tokenize(ps[0])
-     ps = ' '.join(ps)
-     ps = ps.replace('``', '"')
-     ps = ps.replace("''", '"')
+ def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
+     text = text.strip()
+     if use_gruut:
+         ps = gphonemize(text)
+     else:
+         ps = global_phonemizer.phonemize([text])
+     ps = word_tokenize(ps[0])
+     ps = ' '.join(ps)
+     ps = ps.replace('``', '"')
+     ps = ps.replace("''", '"')

-     tokens = textclenaer(ps)
-     tokens.insert(0, 0)
-     tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+     tokens = textclenaer(ps)
+     tokens.insert(0, 0)
+     tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

-     with torch.no_grad():
-         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-         text_mask = length_to_mask(input_lengths).to(device)
+     with torch.no_grad():
+         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+         text_mask = length_to_mask(input_lengths).to(device)

-         t_en = model.text_encoder(tokens, input_lengths, text_mask)
-         bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+         t_en = model.text_encoder(tokens, input_lengths, text_mask)
+         bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

-         s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
+         s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                           embedding=bert_dur,
                           embedding_scale=embedding_scale,
-                          features=ref_s, # reference from the same speaker as the embedding
+                          features=ref_s, # reference from the same speaker as the embedding
                           num_steps=diffusion_steps).squeeze(1)

-         if s_prev is not None:
-             # convex combination of previous and current style
-             s_pred = t * s_prev + (1 - t) * s_pred
+         if s_prev is not None:
+             # convex combination of previous and current style
+             s_pred = t * s_prev + (1 - t) * s_pred

-         s = s_pred[:, 128:]
-         ref = s_pred[:, :128]
+         s = s_pred[:, 128:]
+         ref = s_pred[:, :128]

-         ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
-         s = beta * s + (1 - beta) * ref_s[:, 128:]
+         ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
+         s = beta * s + (1 - beta) * ref_s[:, 128:]

-         s_pred = torch.cat([ref, s], dim=-1)
+         s_pred = torch.cat([ref, s], dim=-1)

-         d = model.predictor.text_encoder(d_en,
+         d = model.predictor.text_encoder(d_en,
                                           s, input_lengths, text_mask)

-         x, _ = model.predictor.lstm(d)
-         duration = model.predictor.duration_proj(x)
+         x, _ = model.predictor.lstm(d)
+         duration = model.predictor.duration_proj(x)

-         duration = torch.sigmoid(duration).sum(axis=-1)
-         pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+         duration = torch.sigmoid(duration).sum(axis=-1)
+         pred_dur = torch.round(duration.squeeze()).clamp(min=1)


-         pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
-         c_frame = 0
-         for i in range(pred_aln_trg.size(0)):
-             pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
-             c_frame += int(pred_dur[i].data)
+         pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+         c_frame = 0
+         for i in range(pred_aln_trg.size(0)):
+             pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+             c_frame += int(pred_dur[i].data)

-         # encode prosody
-         en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-         if model_params.decoder.type == "hifigan":
-             asr_new = torch.zeros_like(en)
-             asr_new[:, :, 0] = en[:, :, 0]
-             asr_new[:, :, 1:] = en[:, :, 0:-1]
-             en = asr_new
+         # encode prosody
+         en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+         if model_params.decoder.type == "hifigan":
+             asr_new = torch.zeros_like(en)
+             asr_new[:, :, 0] = en[:, :, 0]
+             asr_new[:, :, 1:] = en[:, :, 0:-1]
+             en = asr_new

-         F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+         F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

-         asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-         if model_params.decoder.type == "hifigan":
-             asr_new = torch.zeros_like(asr)
-             asr_new[:, :, 0] = asr[:, :, 0]
-             asr_new[:, :, 1:] = asr[:, :, 0:-1]
-             asr = asr_new
+         asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
+         if model_params.decoder.type == "hifigan":
+             asr_new = torch.zeros_like(asr)
+             asr_new[:, :, 0] = asr[:, :, 0]
+             asr_new[:, :, 1:] = asr[:, :, 0:-1]
+             asr = asr_new

-         out = model.decoder(asr,
-                             F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+         out = model.decoder(asr,
+                             F0_pred, N_pred, ref.squeeze().unsqueeze(0))


-     return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
+     return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later

- def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
+ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
      text = text.strip()
-     ps = global_phonemizer.phonemize([text])
+     if use_gruut:
+         ps = gphonemize(text)
+     else:
+         ps = global_phonemizer.phonemize([text])
      ps = word_tokenize(ps[0])
      ps = ' '.join(ps)

@@ -288,7 +299,10 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
      tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

      ref_text = ref_text.strip()
-     ps = global_phonemizer.phonemize([ref_text])
+     if use_gruut:
+         ps = gphonemize(ref_text)
+     else:
+         ps = global_phonemizer.phonemize([ref_text])
      ps = word_tokenize(ps[0])
      ps = ' '.join(ps)

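One detail worth noting when reading the inference hunks: global_phonemizer.phonemize([text]) returns a list of strings (which is why the existing code indexes ps[0] before word_tokenize), while the new gphonemize helper returns a plain string. A hypothetical normalizing wrapper (not part of this commit; the name phonemize_text is illustrative) that would let the downstream ps = word_tokenize(ps[0]) line work identically for both backends could look like:

    from gruut_phonemize import gphonemize

    # global_phonemizer is the phonemizer instance already defined in styletts2importable.py.
    def phonemize_text(text, use_gruut=False):
        # Always return a one-element list so the caller's `ps = word_tokenize(ps[0])`
        # sees the full phoneme string regardless of which backend was chosen.
        if use_gruut:
            return [gphonemize(text)]                    # gruut helper returns a plain str
        return global_phonemizer.phonemize([text])       # phonemizer backend returns a list of str

Each inference function could then call ps = phonemize_text(text, use_gruut) (or phonemize_text(ref_text, use_gruut) for the reference text in STinference) in place of the repeated if/else block.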