jbetker committed
Commit: 979ff6e
Parent: 60d363f

implement clip-guided generation (and never use it...)

Files changed (5):
  1. api.py +30 -6
  2. eval_multiple.py +1 -1
  3. models/autoregressive.py +17 -6
  4. read.py +4 -3
  5. sweep.py +8 -9
api.py CHANGED
@@ -76,7 +76,30 @@ def load_conditioning(clip, cond_length=132300):
     return mel_clip.unsqueeze(0).cuda()
 
 
-def fix_autoregressive_output(codes, stop_token):
+def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
+                           tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
+    """
+    Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
+    every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
+    This is a hybrid between beam search and sampling.
+    """
+    token_goal = tokens_per_clip_inference
+    finished = False
+    while not finished and token_goal < autoregressive_model.max_mel_tokens:
+        samples = []
+        for b in tqdm(range(num_batches)):
+            codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
+            samples.append(codes)
+        for batch in samples:
+            for i in range(batch.shape[0]):
+                batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
+            clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
+        clip_results = torch.cat(clip_results, dim=0)
+        samples = torch.cat(samples, dim=0)
+        best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
+
+
+def fix_autoregressive_output(codes, stop_token, complain=True):
     """
     This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
     trained on and what the autoregressive code generator creates (which has no padding or end).
@@ -89,7 +112,8 @@ def fix_autoregressive_output(codes, stop_token):
     # Strip off the autoregressive stop token and add padding.
     stop_token_indices = (codes == stop_token).nonzero()
     if len(stop_token_indices) == 0:
-        print("No stop tokens found, enjoy that output of yours!")
+        if complain:
+            print("No stop tokens found, enjoy that output of yours!")
         return codes
     else:
         codes[stop_token_indices] = 83
@@ -136,14 +160,14 @@ class TextToSpeech:
                                           heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False,
                                           average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))
 
         self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                                        model_dim=1024,
                                                        heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                                        train_solo_embeddings=False,
                                                        average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))
 
         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                               text_seq_len=350, text_heads=8,
@@ -154,7 +178,7 @@ class TextToSpeech:
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                       layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder_audiobooks.pth'))
 
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
@@ -170,7 +194,7 @@ class TextToSpeech:
         presets = {
             'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7},
             'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8},
-            'realistic': {'temperature': .9, 'length_penalty': 1.0, 'repetition_penalty': 1.3, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
+            'realistic': {'temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
         }
         kwargs.update(presets[preset])
         return self.tts(text, voice_samples, **kwargs)
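
Note: as committed, clip_guided_generation is unfinished (which matches the commit title). clip_results is appended to before it is ever initialized, token_goal and finished are never updated, and best_results is never fed back into the next round, so a call would raise a NameError and could otherwise loop indefinitely. Below is a minimal sketch of how the loop might be closed using the input_tokens and max_generate_length arguments that this same commit adds to inference_speech; everything beyond the committed code is an illustrative assumption, not repository code.

    import torch

    def clip_guided_generation_sketch(autoregressive_model, clip_model, conditioning_input, text_input, num_batches,
                                      stop_mel_token, tokens_per_clip_inference=10, clip_results_to_reduce_to=8,
                                      **generation_kwargs):
        # Hypothetical completion of the committed function; fix_autoregressive_output is the helper defined in api.py.
        token_goal = tokens_per_clip_inference
        finished = False
        best_results = None  # survivors from the previous round, fed back in via input_tokens
        while not finished and token_goal < autoregressive_model.max_mel_tokens:
            samples, clip_results = [], []
            for b in range(num_batches):
                # Only generate up to the next CLVP checkpoint, continuing from last round's survivors.
                codes = autoregressive_model.inference_speech(
                    conditioning_input, text_input,
                    input_tokens=best_results,
                    num_return_sequences=clip_results_to_reduce_to if best_results is not None else 1,
                    max_generate_length=token_goal,
                    **generation_kwargs)
                samples.append(codes)
            # If every candidate has already emitted a stop token, this round can be the last one.
            finished = all((batch == stop_mel_token).any(dim=-1).all().item() for batch in samples)
            for batch in samples:
                # Score a *fixed* copy with CLVP, but keep the raw codes so they can be extended next round.
                fixed = batch.clone()
                for i in range(fixed.shape[0]):
                    fixed[i] = fix_autoregressive_output(fixed[i], stop_mel_token, complain=False)
                clip_results.append(clip_model(text_input.repeat(fixed.shape[0], 1), fixed, return_loss=False))
            clip_results = torch.cat(clip_results, dim=0)
            samples = torch.cat(samples, dim=0)
            best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
            token_goal += tokens_per_clip_inference
        return best_results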
eval_multiple.py CHANGED
@@ -8,7 +8,7 @@ from utils.audio import load_audio
 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
     stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\diverse'
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath_real, exist_ok=True)
models/autoregressive.py CHANGED
@@ -511,7 +511,8 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
-    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -541,13 +542,23 @@
         emb = torch.cat([conds, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
 
-        fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
-        fake_inputs[:,-1] = self.start_mel_token
+        fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
+                                 device=text_inputs.device)
+        fake_inputs[:, -1] = self.start_mel_token
+        trunc_index = fake_inputs.shape[1]
+        if input_tokens is None:
+            inputs = fake_inputs
+        else:
+            assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
+            fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
+            inputs = torch.cat([fake_inputs, input_tokens], dim=1)
 
         logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
-        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
-        return gen[:, fake_inputs.shape[1]:]
+        max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+        gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
+                                            max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs)
+        return gen[:, trunc_index:]
 
 
 if __name__ == '__main__':
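
Note: the two new arguments are what make partial-sequence continuation possible. input_tokens seeds the generator with previously produced mel codes, max_generate_length caps the total number of mel tokens, and trunc_index guarantees the returned tensor always starts at the first mel position. A hedged usage sketch follows; the model handle and tensor shapes are assumptions for illustration, not code from the commit.

    import torch

    # 'model' is assumed to be a UnifiedVoice instance loaded as in api.py.
    cond_mel = torch.randn(1, 80, 500)            # assumed conditioning mel, e.g. from load_conditioning()
    text_tokens = torch.randint(0, 255, (1, 50))  # assumed tokenized text

    # First pass: sample a short prefix of at most 10 mel tokens.
    prefix = model.inference_speech(cond_mel, text_tokens, max_generate_length=10,
                                    do_sample=True, top_p=0.9)

    # Second pass: branch 4 continuations off that prefix, each capped at 20 total mel tokens.
    continuations = model.inference_speech(cond_mel, text_tokens, input_tokens=prefix,
                                           num_return_sequences=4, max_generate_length=20,
                                           do_sample=True, top_p=0.9)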
read.py CHANGED
@@ -32,15 +32,16 @@ if __name__ == '__main__':
     preselected_cond_voices = {
         'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'],
         'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'],
+        'patrick_stewart': ['voices/patrick_stewart/1.wav','voices/patrick_stewart/2.wav','voices/patrick_stewart/3.wav','voices/patrick_stewart/4.wav'],
     }
 
     parser = argparse.ArgumentParser()
     parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
-    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='emma_stone')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
+    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='patrick_stewart')
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
-    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='intelligible')
+    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='realistic')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
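Note: with the new defaults, the long-form script reads data/riding_hood.txt in the patrick_stewart voice under the 'realistic' preset with 128 autoregressive samples. An equivalent explicit invocation (illustrative) would be:

    python read.py -textfile data/riding_hood.txt -voice patrick_stewart -num_samples 128 -generation_preset realistic
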
sweep.py CHANGED
@@ -25,16 +25,15 @@ def permutations(args):
 
 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
-    stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep'
+    stop_after = 512
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     arg_ranges = {
-        'top_p': [.5, 1],
-        'temperature': [.5, 1],
-        'diffusion_temperature': [.6, 1],
-        'cond_free_k': [0, 1, 4],
-        'repetition_penalty': [1.0, 2.0]
+        'top_p': [.8,1],
+        'temperature': [.8,.9,1],
+        'diffusion_temperature': [.8,1],
+        'cond_free_k': [1,2,5,10],
     }
     cfgs = permutations(arg_ranges)
     shuffle(cfgs)
@@ -56,8 +55,8 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256,
-                         k=1, diffusion_iterations=70, length_penalty=1.0, **cfg)
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0,
+                         k=1, diffusion_iterations=32, length_penalty=1.0, **cfg)
         down = torchaudio.functional.resample(sample, 24000, 22050)
         fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
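
Note: the sweep expands arg_ranges with a permutations() helper (its body is not part of this diff) into one config per combination of values, and shuffle() randomizes the visit order; repetition_penalty has moved from the swept ranges into the tts() call as a fixed 2.0. The call site implies a cartesian-product expansion, which might look like the following sketch (an assumption, not the repository's implementation):

    import itertools
    from random import shuffle

    def permutations(args):
        # Expand {'top_p': [.8, 1], 'cond_free_k': [1, 2, 5, 10], ...} into a list of
        # {'top_p': ..., 'cond_free_k': ..., ...} dicts covering every combination.
        keys = list(args.keys())
        return [dict(zip(keys, values)) for values in itertools.product(*(args[k] for k in keys))]

    cfgs = permutations({'top_p': [.8, 1], 'temperature': [.8, .9, 1],
                         'diffusion_temperature': [.8, 1], 'cond_free_k': [1, 2, 5, 10]})
    shuffle(cfgs)  # 2 * 3 * 2 * 4 = 48 configurations, visited in random order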