jbetker committed
Commit c66954b
1 Parent(s): 9ad0f0e

Add in ASR filtration

do_tts.py CHANGED
@@ -7,6 +7,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 import progressbar
+import ocotillo
 
 from models.diffusion_decoder import DiffusionTts
 from models.autoregressive import UnifiedVoice
@@ -17,7 +18,7 @@ from models.text_voice_clip import VoiceCLIP
 from models.vocoder import UnivNetGenerator
 from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
 from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
-from utils.tokenizer import VoiceBpeTokenizer
+from utils.tokenizer import VoiceBpeTokenizer, lev_distance
 
 pbar = None
 def download_models():
@@ -47,13 +48,13 @@ def download_models():
     print('Done.')
 
 
-def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
     """
     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
-                           conditioning_free=True, conditioning_free_k=1)
+                           conditioning_free=cond_free, conditioning_free_k=1)
 
 
 def load_conditioning(path, sample_rate=22050, cond_length=132300):
@@ -109,11 +110,12 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
             mel = torch.nn.functional.pad(mel_codes, (0, gap))
 
         output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
+        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
         if mean:
             mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
-                                         model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
+                                         model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
         else:
-            mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
+            mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
         return denormalize_tacotron_mel(mel)[:,:,:msl*4]
 
 
@@ -136,9 +138,9 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
-    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
-    parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=1024)
+    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=32)
+    parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
     args = parser.parse_args()
 
@@ -192,7 +194,7 @@ if __name__ == '__main__':
                                      return_loss=False))
            clip_results = torch.cat(clip_results, dim=0)
            samples = torch.cat(samples, dim=0)
-           best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
+           best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices]
 
            # Delete the autoregressive and clip models to free up GPU memory
            del samples, clip
@@ -210,12 +212,32 @@ if __name__ == '__main__':
            vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
            vocoder = vocoder.cuda()
            vocoder.eval(inference=True)
-           diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
+           initial_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=40, cond_free=False)
+           final_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=500)
 
            print("Performing vocoding..")
-           # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
+           wav_candidates = []
            for b in range(best_results.shape[0]):
                code = best_results[b].unsqueeze(0)
-               mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
+               mel = do_spectrogram_diffusion(diffusion, initial_diffuser, code, cond_diffusion, mean=False)
                wav = vocoder.inference(mel)
-               torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)
+               wav_candidates.append(wav.cpu())
+
+           # Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
+           transcriber = ocotillo.Transcriber(on_cuda=True)
+           transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
+           best = 99999999
+           for i, transcription in enumerate(transcriptions):
+               dist = lev_distance(transcription, args.text.lower())
+               if dist < best:
+                   best = dist
+                   best_codes = best_results[i].unsqueeze(0)
+                   best_wav = wav_candidates[i]
+           del transcriber
+           torchaudio.save(os.path.join(args.output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
+
+           # Perform diffusion again with the high-quality diffuser.
+           mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
+           wav = vocoder.inference(mel)
+           torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
+
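The selection logic added above works in two passes: every CLIP-filtered candidate is vocoded with a fast 40-step diffuser (cond_free=False), transcribed with ocotillo, and scored against the prompt with lev_distance; only the winner is re-diffused with the slow 500-step diffuser. A minimal sketch of just the scoring step (pick_most_intelligible is a hypothetical helper for illustration, not code from this commit):

from utils.tokenizer import lev_distance

def pick_most_intelligible(prompt, transcriptions):
    # Return the index of the candidate whose ASR transcript is closest to the prompt text.
    best_idx, best_dist = 0, float('inf')
    for i, transcription in enumerate(transcriptions):
        dist = lev_distance(transcription, prompt.lower())
        if dist < best_dist:
            best_idx, best_dist = i, dist
    return best_idx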
models/diffusion_decoder.py CHANGED
@@ -486,66 +486,40 @@ class DiffusionTts(nn.Module):
             aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
         return x, aligned_conditioning
 
-    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
-        """
-        Apply the model to an input batch.
-
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
-        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
-        :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
-        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        assert conditioning_input is not None
-        if self.super_sampling_enabled:
-            assert lr_input is not None
-            if self.training and self.super_sampling_max_noising_factor > 0:
-                noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
-                lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
-            lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
-            x = torch.cat([x, lr_input], dim=1)
-
+    def timestep_independent(self, aligned_conditioning, conditioning_input):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
 
-        # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
-        orig_x_shape = x.shape[-1]
-        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
+        with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16):
+            cond_emb = self.contextual_embedder(conditioning_input)
+            if len(cond_emb.shape) == 3: # Just take the first element.
+                cond_emb = cond_emb[:, :, 0]
+            if is_latent(aligned_conditioning):
+                code_emb = self.latent_converter(aligned_conditioning)
+            else:
+                code_emb = self.code_converter(aligned_conditioning)
+            cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
+            code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
+        return code_emb
 
-        with autocast(x.device.type, enabled=self.enable_fp16):
-            hs = []
-            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+    def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
+        assert x.shape[-1] % self.alignment_size == 0
 
-            # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
+        with autocast(x.device.type, enabled=self.enable_fp16):
             if conditioning_free:
                 code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
             else:
-                cond_emb = self.contextual_embedder(conditioning_input)
-                if len(cond_emb.shape) == 3: # Just take the first element.
-                    cond_emb = cond_emb[:, :, 0]
-                if is_latent(aligned_conditioning):
-                    code_emb = self.latent_converter(aligned_conditioning)
-                else:
-                    code_emb = self.code_converter(aligned_conditioning)
-                cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
-                code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
-            # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
-            if self.training and self.unconditioned_percentage > 0:
-                unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
-                                                   device=code_emb.device) < self.unconditioned_percentage
-                code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
-                                       code_emb)
-
-            # Everything after this comment is timestep dependent.
+                code_emb = precomputed_aligned_embeddings
+
+            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
             code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
             code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
 
             first = True
             time_emb = time_emb.float()
             h = x
+            hs = []
            for k, module in enumerate(self.input_blocks):
                if isinstance(module, nn.Conv1d):
                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
@@ -565,14 +539,7 @@ class DiffusionTts(nn.Module):
         h = h.float()
         out = self.out(h)
 
-        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
-        extraneous_addition = 0
-        params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
-        for p in params:
-            extraneous_addition = extraneous_addition + p.mean()
-        out = out + extraneous_addition * 0
-
-        return out[:, :, :orig_x_shape]
+        return out
 
 
 if __name__ == '__main__':
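The decoder change above splits the old forward() in two: timestep_independent() builds the conditioning embedding once per sample, and the new forward() consumes it through precomputed_aligned_embeddings, so that work is no longer repeated at every diffusion step. A rough sketch of how a caller uses the split (this mirrors do_spectrogram_diffusion in do_tts.py; the function name and arguments here are illustrative, not part of the commit):

def diffuse_with_cached_conditioning(model, diffuser, mel_codes, cond_mel, output_shape):
    # Built once, outside the sampling loop: everything that does not depend on the timestep.
    precomputed = model.timestep_independent(mel_codes, cond_mel)
    # p_sample_loop then calls model(x, t, precomputed_aligned_embeddings=precomputed) at each step,
    # so only the timestep-dependent part of the network runs per step.
    return diffuser.p_sample_loop(model, output_shape,
                                  model_kwargs={'precomputed_aligned_embeddings': precomputed})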
requirements.txt CHANGED
@@ -7,4 +7,5 @@ inflect
 progressbar
 einops
 unidecode
-x-transformers
+x-transformers
+ocotillo
utils/__init__.py ADDED
File without changes
utils/tokenizer.py CHANGED
@@ -148,6 +148,20 @@ def english_cleaners(text):
     text = text.replace('"', '')
     return text
 
+def lev_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
 
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file='data/tokenizer.json'):
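The new lev_distance() helper is a standard two-row Levenshtein edit distance over characters. A few hand-checked examples of how it behaves (assuming the import path introduced by this commit):

from utils.tokenizer import lev_distance

assert lev_distance("kitten", "sitting") == 3  # substitute k->s and e->i, insert g
assert lev_distance("", "abc") == 3            # three insertions
assert lev_distance("same", "same") == 0       # identical strings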