jbetker committed
Commit f625a9e
1 Parent(s): b78ae92

Update API to have a more expressive interface for controlling various generation knobs

- Also adds typical decoding support; unfortunately, this does not work well with the current model.
api.py CHANGED
@@ -49,13 +49,13 @@ def download_models():
     print('Done.')
 
 
-def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
     """
     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
-                           conditioning_free=cond_free, conditioning_free_k=1)
+                           conditioning_free=cond_free, conditioning_free_k=cond_free_k)
 
 
 def load_conditioning(clip, cond_length=132300):
@@ -96,7 +96,7 @@ def fix_autoregressive_output(codes, stop_token):
     return codes
 
 
-def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
     """
     Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
     """
@@ -111,11 +111,10 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
 
     output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
     precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
-    if mean:
-        mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
-                                     model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
-    else:
-        mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
+
+    noise = torch.randn(output_shape, device=mel_codes.device) * temperature
+    mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
+                                 model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
     return denormalize_tacotron_mel(mel)[:,:,:msl*4]
 
 
@@ -150,7 +149,12 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
-    def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True):
+    def tts(self, text, voice_samples, k=1,
+            # autoregressive generation parameters follow
+            num_autoregressive_samples=512, temperature=.9, length_penalty=1, repetition_penalty=1.0, top_k=50, top_p=.95,
+            typical_sampling=False, typical_mass=.9,
+            # diffusion generation parameters follow
+            diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1,):
         text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
         text = F.pad(text, (0, 1)) # This may not be necessary.
 
@@ -167,7 +171,7 @@ class TextToSpeech:
         else:
             cond_diffusion = cond_diffusion[:, :88200]
 
-        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free)
+        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
 
         with torch.no_grad():
             samples = []
@@ -175,11 +179,16 @@ class TextToSpeech:
             stop_mel_token = self.autoregressive.stop_mel_token
             self.autoregressive = self.autoregressive.cuda()
             for b in tqdm(range(num_batches)):
-                codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True,
-                                                             top_k=50, top_p=.95,
-                                                             temperature=.9,
-                                                             num_return_sequences=self.autoregressive_batch_size,
-                                                             length_penalty=1)
+                codes = self.autoregressive.inference_speech(conds, text,
+                                                             do_sample=True,
+                                                             top_k=top_k,
+                                                             top_p=top_p,
+                                                             temperature=temperature,
+                                                             num_return_sequences=self.autoregressive_batch_size,
+                                                             length_penalty=length_penalty,
+                                                             repetition_penalty=repetition_penalty,
+                                                             typical_sampling=typical_sampling,
+                                                             typical_mass=typical_mass)
                 padding_needed = 250 - codes.shape[1]
                 codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                 samples.append(codes)
@@ -203,7 +212,7 @@ class TextToSpeech:
             self.vocoder = self.vocoder.cuda()
             for b in range(best_results.shape[0]):
                 code = best_results[b].unsqueeze(0)
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
                 wav = self.vocoder.inference(mel)
                 wav_candidates.append(wav.cpu())
             self.diffusion = self.diffusion.cpu()
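
A minimal usage sketch of the expanded tts() interface (not part of the commit). The keyword names come straight from the new signature above; the default TextToSpeech() construction, the "from api import TextToSpeech" import path, and the reference clip are assumptions modeled on how eval_multiple.py drives the class.

import torchaudio
from api import TextToSpeech                     # assumed import path for the class defined in api.py
from utils.audio import load_audio

tts = TextToSpeech()                             # assumed: default construction loads the .models/ checkpoints
cond = load_audio('reference_clip.wav', 22050)   # hypothetical 22.05 kHz conditioning clip

wav = tts.tts(
    "Hello from the expanded interface.", [cond, cond], k=1,
    # autoregressive (GPT) stage knobs
    num_autoregressive_samples=256, temperature=.9, length_penalty=1,
    repetition_penalty=1.0, top_k=50, top_p=.95,
    typical_sampling=False, typical_mass=.9,
    # diffusion (vocoder) stage knobs
    diffusion_iterations=100, cond_free=True, cond_free_k=1, diffusion_temperature=1)
torchaudio.save('out.wav', wav.squeeze(0).cpu(), 24000)  # tts() returns 24 kHz audio, per eval_multiple.py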
eval_multiple.py CHANGED
@@ -7,7 +7,7 @@ from utils.audio import load_audio
 
 if __name__ == '__main__':
     fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\redo_outlier'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath, exist_ok=True)
@@ -24,7 +24,8 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True)
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1, diffusion_iterations=200, cond_free=False,
+                         top_k=None, top_p=.95, typical_sampling=False, temperature=.7, length_penalty=.5, repetition_penalty=1)
         down = torchaudio.functional.resample(sample, 24000, 22050)
         fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
models/autoregressive.py CHANGED
@@ -3,11 +3,11 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GPT2Config, GPT2PreTrainedModel
+from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 from models.arch_util import AttentionBlock
-
+from utils.typical_sampling import TypicalLogitsWarper
 
 
 def null_position_embeddings(range, dim):
@@ -497,7 +497,7 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
-    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -530,8 +530,9 @@ class UnifiedVoice(nn.Module):
         fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token
 
+        logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=seq_length, **hf_generate_kwargs)
+                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
 
 
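
The typical-sampling hookup above is simply a LogitsWarper routed through the logits_processor argument of Hugging Face generate(). A sketch of the same wiring against a stock GPT-2 checkpoint follows, useful for exercising the warper in isolation; it assumes a standard transformers install with the 'gpt2' weights available and is not part of the UnifiedVoice pipeline.

from transformers import GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessorList
from utils.typical_sampling import TypicalLogitsWarper

tok = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
inputs = tok("Typical decoding keeps tokens whose surprisal is close to", return_tensors='pt')
# Same gating pattern as inference_speech: the warper only runs when placed in the processor list.
out = model.generate(**inputs, do_sample=True, max_length=40,
                     logits_processor=LogitsProcessorList([TypicalLogitsWarper(mass=.9)]))
print(tok.decode(out[0], skip_special_tokens=True))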
 
utils/typical_sampling.py ADDED
@@ -0,0 +1,33 @@
+import torch
+from transformers import LogitsWarper
+
+
+class TypicalLogitsWarper(LogitsWarper):
+    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
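
A toy check of the new warper on hand-made logits (not part of the commit); shapes and values are illustrative only. Tokens whose information content sits far from the distribution's entropy are masked to filter_value once the typical mass has been collected.

import torch
from utils.typical_sampling import TypicalLogitsWarper

warper = TypicalLogitsWarper(mass=0.9)
scores = torch.tensor([[2.0, 1.5, 0.2, -1.0, -3.0]])  # fake logits over a 5-token vocabulary
dummy_ids = torch.zeros(1, 1, dtype=torch.long)       # the warper never reads input_ids
filtered = warper(dummy_ids, scores)
print(filtered)  # the most atypical entries come back as -inf and drop out of sampling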