CodingBillionaire committed
Commit ee04bc2 · 1 Parent(s): 579b2c4

Upload 132 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. tortoise/__init__.py +0 -0
  3. tortoise/api.py +752 -0
  4. tortoise/data/got.txt +276 -0
  5. tortoise/data/layman.txt +0 -0
  6. tortoise/data/mel_norms.pth +3 -0
  7. tortoise/data/riding_hood.txt +54 -0
  8. tortoise/data/seal_copypasta.txt +1 -0
  9. tortoise/data/tokenizer.json +1 -0
  10. tortoise/do_tts.py +102 -0
  11. tortoise/eval.py +44 -0
  12. tortoise/get_conditioning_latents.py +39 -0
  13. tortoise/is_this_from_tortoise.py +21 -0
  14. tortoise/models/__init__.py +0 -0
  15. tortoise/models/arch_util.py +424 -0
  16. tortoise/models/autoregressive.py +704 -0
  17. tortoise/models/classifier.py +166 -0
  18. tortoise/models/clvp.py +173 -0
  19. tortoise/models/cvvp.py +156 -0
  20. tortoise/models/diffusion_decoder.py +445 -0
  21. tortoise/models/random_latent_generator.py +56 -0
  22. tortoise/models/transformer.py +241 -0
  23. tortoise/models/vocoder.py +400 -0
  24. tortoise/models/xtransformers.py +1432 -0
  25. tortoise/read.py +157 -0
  26. tortoise/utils/__init__.py +0 -0
  27. tortoise/utils/audio.py +216 -0
  28. tortoise/utils/diffusion.py +1277 -0
  29. tortoise/utils/stft.py +215 -0
  30. tortoise/utils/text.py +144 -0
  31. tortoise/utils/tokenizer.py +202 -0
  32. tortoise/utils/typical_sampling.py +44 -0
  33. tortoise/utils/wav2vec_alignment.py +173 -0
  34. tortoise/voices/angie/1.wav +0 -0
  35. tortoise/voices/angie/2.wav +3 -0
  36. tortoise/voices/angie/3.wav +0 -0
  37. tortoise/voices/applejack/1.wav +0 -0
  38. tortoise/voices/applejack/2.wav +0 -0
  39. tortoise/voices/applejack/3.wav +0 -0
  40. tortoise/voices/cond_latent_example/pat.pth +3 -0
  41. tortoise/voices/daniel/1.wav +0 -0
  42. tortoise/voices/daniel/2.wav +0 -0
  43. tortoise/voices/daniel/3.wav +0 -0
  44. tortoise/voices/daniel/4.wav +0 -0
  45. tortoise/voices/deniro/1.wav +0 -0
  46. tortoise/voices/deniro/2.wav +3 -0
  47. tortoise/voices/deniro/3.wav +0 -0
  48. tortoise/voices/deniro/4.wav +0 -0
  49. tortoise/voices/deutsch/de_speaker_2.mp3 +0 -0
  50. tortoise/voices/deutsch/de_speaker_3.mp3 +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tortoise/voices/angie/2.wav filter=lfs diff=lfs merge=lfs -text
37
+ tortoise/voices/deniro/2.wav filter=lfs diff=lfs merge=lfs -text
38
+ tortoise/voices/train_lescault/lescault_new4.wav filter=lfs diff=lfs merge=lfs -text
tortoise/__init__.py ADDED
File without changes
tortoise/api.py ADDED
@@ -0,0 +1,752 @@
1
+ import os
2
+ import random
3
+ import uuid
4
+ from time import time
5
+ from urllib import request
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import progressbar
10
+ import torchaudio
11
+
12
+ from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
13
+ from tortoise.models.diffusion_decoder import DiffusionTts
14
+ from tortoise.models.autoregressive import UnifiedVoice
15
+ from tqdm import tqdm
16
+
17
+ from tortoise.models.arch_util import TorchMelSpectrogram
18
+ from tortoise.models.clvp import CLVP
19
+ from tortoise.models.cvvp import CVVP
20
+ from tortoise.models.random_latent_generator import RandomLatentConverter
21
+ from tortoise.models.vocoder import UnivNetGenerator
22
+ from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
23
+ from tortoise.utils.diffusion import (
24
+ SpacedDiffusion,
25
+ space_timesteps,
26
+ get_named_beta_schedule,
27
+ )
28
+ from tortoise.utils.tokenizer import VoiceBpeTokenizer
29
+ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
30
+
31
+ pbar = None
32
+
33
+ DEFAULT_MODELS_DIR = os.path.join(
34
+ os.path.expanduser("~"), ".cache", "tortoise", "models"
35
+ )
36
+ MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
37
+ MODELS = {
38
+ "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
39
+ "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
40
+ "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
41
+ "cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth",
42
+ "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
43
+ "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
44
+ "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
45
+ "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
46
+ }
47
+
48
+
49
+ def download_models(specific_models=None):
50
+ """
51
+ Call to download all the models that Tortoise uses.
52
+ """
53
+ os.makedirs(MODELS_DIR, exist_ok=True)
54
+
55
+ def show_progress(block_num, block_size, total_size):
56
+ global pbar
57
+ if pbar is None:
58
+ pbar = progressbar.ProgressBar(maxval=total_size)
59
+ pbar.start()
60
+
61
+ downloaded = block_num * block_size
62
+ if downloaded < total_size:
63
+ pbar.update(downloaded)
64
+ else:
65
+ pbar.finish()
66
+ pbar = None
67
+
68
+ for model_name, url in MODELS.items():
69
+ if specific_models is not None and model_name not in specific_models:
70
+ continue
71
+ model_path = os.path.join(MODELS_DIR, model_name)
72
+ if os.path.exists(model_path):
73
+ continue
74
+ print(f"Downloading {model_name} from {url}...")
75
+ request.urlretrieve(url, model_path, show_progress)
76
+ print("Done.")
77
+
78
+
79
+ def get_model_path(model_name, models_dir=MODELS_DIR):
80
+ """
81
+ Get path to given model, download it if it doesn't exist.
82
+ """
83
+ if model_name not in MODELS:
84
+ raise ValueError(f"Model {model_name} not found in available models.")
85
+ model_path = os.path.join(models_dir, model_name)
86
+ if not os.path.exists(model_path) and models_dir == MODELS_DIR:
87
+ download_models([model_name])
88
+ return model_path
89
+
90
+
91
+ def pad_or_truncate(t, length):
92
+ """
93
+ Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
94
+ """
95
+ if t.shape[-1] == length:
96
+ return t
97
+ elif t.shape[-1] < length:
98
+ return F.pad(t, (0, length - t.shape[-1]))
99
+ else:
100
+ return t[..., :length]
101
+
102
+
103
+ def load_discrete_vocoder_diffuser(
104
+ trained_diffusion_steps=4000,
105
+ desired_diffusion_steps=200,
106
+ cond_free=True,
107
+ cond_free_k=1,
108
+ ):
109
+ """
110
+ Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
111
+ """
112
+ return SpacedDiffusion(
113
+ use_timesteps=space_timesteps(
114
+ trained_diffusion_steps, [desired_diffusion_steps]
115
+ ),
116
+ model_mean_type="epsilon",
117
+ model_var_type="learned_range",
118
+ loss_type="mse",
119
+ betas=get_named_beta_schedule("linear", trained_diffusion_steps),
120
+ conditioning_free=cond_free,
121
+ conditioning_free_k=cond_free_k,
122
+ )
123
+
124
+
125
+ def format_conditioning(clip, cond_length=132300, device="cuda"):
126
+ """
127
+ Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
128
+ """
129
+ gap = clip.shape[-1] - cond_length
130
+ if gap < 0:
131
+ clip = F.pad(clip, pad=(0, abs(gap)))
132
+ elif gap > 0:
133
+ rand_start = random.randint(0, gap)
134
+ clip = clip[:, rand_start : rand_start + cond_length]
135
+ mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
136
+ return mel_clip.unsqueeze(0).to(device)
137
+
138
+
139
+ def fix_autoregressive_output(codes, stop_token, complain=True):
140
+ """
141
+ This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
142
+ trained on and what the autoregressive code generator creates (which has no padding or end).
143
+ This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
144
+ a different DVAE. This can be inferred by feeding an audio clip padded with lots of zeros on the end through the DVAE
145
+ and copying out the last few codes.
146
+
147
+ Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
148
+ """
149
+ # Strip off the autoregressive stop token and add padding.
150
+ stop_token_indices = (codes == stop_token).nonzero()
151
+ if len(stop_token_indices) == 0:
152
+ if complain:
153
+ print(
154
+ "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
155
+ "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
156
+ "try breaking up your input text."
157
+ )
158
+ return codes
159
+ else:
160
+ codes[stop_token_indices] = 83
161
+ stm = stop_token_indices.min().item()
162
+ codes[stm:] = 83
163
+ if stm - 3 < codes.shape[0]:
164
+ codes[-3] = 45
165
+ codes[-2] = 45
166
+ codes[-1] = 248
167
+
168
+ return codes
169
+
170
+
171
+ def do_spectrogram_diffusion(
172
+ diffusion_model,
173
+ diffuser,
174
+ latents,
175
+ conditioning_latents,
176
+ temperature=1,
177
+ verbose=True,
178
+ ):
179
+ """
180
+ Uses the specified diffusion model to convert discrete codes into a spectrogram.
181
+ """
182
+ with torch.no_grad():
183
+ output_seq_len = (
184
+ latents.shape[1] * 4 * 24000 // 22050
185
+ ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
186
+ output_shape = (latents.shape[0], 100, output_seq_len)
187
+ precomputed_embeddings = diffusion_model.timestep_independent(
188
+ latents, conditioning_latents, output_seq_len, False
189
+ )
190
+
191
+ noise = torch.randn(output_shape, device=latents.device) * temperature
192
+ mel = diffuser.p_sample_loop(
193
+ diffusion_model,
194
+ output_shape,
195
+ noise=noise,
196
+ model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
197
+ progress=verbose,
198
+ )
199
+ return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
200
+
201
+
202
+ def classify_audio_clip(clip):
203
+ """
204
+ Returns whether or not Tortoise's classifier thinks the given clip came from Tortoise.
205
+ :param clip: torch tensor containing audio waveform data (get it from load_audio)
206
+ :return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
207
+ """
208
+ classifier = AudioMiniEncoderWithClassifierHead(
209
+ 2,
210
+ spec_dim=1,
211
+ embedding_dim=512,
212
+ depth=5,
213
+ downsample_factor=4,
214
+ resnet_blocks=2,
215
+ attn_blocks=4,
216
+ num_attn_heads=4,
217
+ base_channels=32,
218
+ dropout=0,
219
+ kernel_size=5,
220
+ distribute_zero_label=False,
221
+ )
222
+ classifier.load_state_dict(
223
+ torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
224
+ )
225
+ clip = clip.cpu().unsqueeze(0)
226
+ results = F.softmax(classifier(clip), dim=-1)
227
+ return results[0][0]
228
+
229
+
230
+ def pick_best_batch_size_for_gpu():
231
+ """
232
+ Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
233
+ you a good shot.
234
+ """
235
+ if torch.cuda.is_available():
236
+ _, available = torch.cuda.mem_get_info()
237
+ availableGb = available / (1024**3)
238
+ if availableGb > 14:
239
+ return 16
240
+ elif availableGb > 10:
241
+ return 8
242
+ elif availableGb > 7:
243
+ return 4
244
+ return 1
245
+
246
+
247
+ class TextToSpeech:
248
+ """
249
+ Main entry point into Tortoise.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ autoregressive_batch_size=None,
255
+ models_dir=MODELS_DIR,
256
+ enable_redaction=True,
257
+ device=None,
258
+ ):
259
+ """
260
+ Constructor
261
+ :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
262
+ GPU OOM errors. Larger numbers generate slightly faster.
263
+ :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
264
+ models, otherwise use the defaults.
265
+ :param enable_redaction: When true, text enclosed in brackets is automatically redacted from the spoken output
266
+ (but are still rendered by the model). This can be used for prompt engineering.
267
+ Default is true.
268
+ :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
269
+ """
270
+ self.models_dir = models_dir
271
+ self.autoregressive_batch_size = (
272
+ pick_best_batch_size_for_gpu()
273
+ if autoregressive_batch_size is None
274
+ else autoregressive_batch_size
275
+ )
276
+ self.enable_redaction = enable_redaction
277
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
278
+ if self.enable_redaction:
279
+ self.aligner = Wav2VecAlignment()
280
+
281
+ self.tokenizer = VoiceBpeTokenizer()
282
+
283
+ if os.path.exists(f"{models_dir}/autoregressive.ptt"):
284
+ # Assume this is a traced directory.
285
+ self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
286
+ self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
287
+ else:
288
+ self.autoregressive = (
289
+ UnifiedVoice(
290
+ max_mel_tokens=604,
291
+ max_text_tokens=402,
292
+ max_conditioning_inputs=2,
293
+ layers=30,
294
+ model_dim=1024,
295
+ heads=16,
296
+ number_text_tokens=255,
297
+ start_text_token=255,
298
+ checkpointing=False,
299
+ train_solo_embeddings=False,
300
+ )
301
+ .cpu()
302
+ .eval()
303
+ )
304
+ self.autoregressive.load_state_dict(
305
+ torch.load(get_model_path("autoregressive.pth", models_dir))
306
+ )
307
+
308
+ self.diffusion = (
309
+ DiffusionTts(
310
+ model_channels=1024,
311
+ num_layers=10,
312
+ in_channels=100,
313
+ out_channels=200,
314
+ in_latent_channels=1024,
315
+ in_tokens=8193,
316
+ dropout=0,
317
+ use_fp16=False,
318
+ num_heads=16,
319
+ layer_drop=0,
320
+ unconditioned_percentage=0,
321
+ )
322
+ .cpu()
323
+ .eval()
324
+ )
325
+ self.diffusion.load_state_dict(
326
+ torch.load(get_model_path("diffusion_decoder.pth", models_dir))
327
+ )
328
+
329
+ self.clvp = (
330
+ CLVP(
331
+ dim_text=768,
332
+ dim_speech=768,
333
+ dim_latent=768,
334
+ num_text_tokens=256,
335
+ text_enc_depth=20,
336
+ text_seq_len=350,
337
+ text_heads=12,
338
+ num_speech_tokens=8192,
339
+ speech_enc_depth=20,
340
+ speech_heads=12,
341
+ speech_seq_len=430,
342
+ use_xformers=True,
343
+ )
344
+ .cpu()
345
+ .eval()
346
+ )
347
+ self.clvp.load_state_dict(torch.load(get_model_path("clvp2.pth", models_dir)))
348
+ self.cvvp = None # CVVP model is only loaded if used.
349
+
350
+ self.vocoder = UnivNetGenerator().cpu()
351
+ self.vocoder.load_state_dict(
352
+ torch.load(
353
+ get_model_path("vocoder.pth", models_dir),
354
+ map_location=torch.device("cpu"),
355
+ )["model_g"]
356
+ )
357
+ self.vocoder.eval(inference=True)
358
+
359
+ # Random latent generators (RLGs) are loaded lazily.
360
+ self.rlg_auto = None
361
+ self.rlg_diffusion = None
362
+
363
+ def load_cvvp(self):
364
+ """Load CVVP model."""
365
+ self.cvvp = (
366
+ CVVP(
367
+ model_dim=512,
368
+ transformer_heads=8,
369
+ dropout=0,
370
+ mel_codes=8192,
371
+ conditioning_enc_depth=8,
372
+ cond_mask_percentage=0,
373
+ speech_enc_depth=8,
374
+ speech_mask_percentage=0,
375
+ latent_multiplier=1,
376
+ )
377
+ .cpu()
378
+ .eval()
379
+ )
380
+ self.cvvp.load_state_dict(
381
+ torch.load(get_model_path("cvvp.pth", self.models_dir))
382
+ )
383
+
384
+ def get_conditioning_latents(self, voice_samples, return_mels=False):
385
+ """
386
+ Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
387
+ These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
388
+ properties.
389
+ :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
390
+ """
391
+ with torch.no_grad():
392
+ voice_samples = [v.to(self.device) for v in voice_samples]
393
+
394
+ auto_conds = []
395
+ if not isinstance(voice_samples, list):
396
+ voice_samples = [voice_samples]
397
+ for vs in voice_samples:
398
+ auto_conds.append(format_conditioning(vs, device=self.device))
399
+ auto_conds = torch.stack(auto_conds, dim=1)
400
+ self.autoregressive = self.autoregressive.to(self.device)
401
+ auto_latent = self.autoregressive.get_conditioning(auto_conds)
402
+ self.autoregressive = self.autoregressive.cpu()
403
+
404
+ diffusion_conds = []
405
+ for sample in voice_samples:
406
+ # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
407
+ sample = torchaudio.functional.resample(sample, 22050, 24000)
408
+ sample = pad_or_truncate(sample, 102400)
409
+ cond_mel = wav_to_univnet_mel(
410
+ sample.to(self.device), do_normalization=False, device=self.device
411
+ )
412
+ diffusion_conds.append(cond_mel)
413
+ diffusion_conds = torch.stack(diffusion_conds, dim=1)
414
+
415
+ self.diffusion = self.diffusion.to(self.device)
416
+ diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
417
+ self.diffusion = self.diffusion.cpu()
418
+
419
+ if return_mels:
420
+ return auto_latent, diffusion_latent, auto_conds, diffusion_conds
421
+ else:
422
+ return auto_latent, diffusion_latent
423
+
424
+ def get_random_conditioning_latents(self):
425
+ # Lazy-load the RLG models.
426
+ if self.rlg_auto is None:
427
+ self.rlg_auto = RandomLatentConverter(1024).eval()
428
+ self.rlg_auto.load_state_dict(
429
+ torch.load(
430
+ get_model_path("rlg_auto.pth", self.models_dir),
431
+ map_location=torch.device("cpu"),
432
+ )
433
+ )
434
+ self.rlg_diffusion = RandomLatentConverter(2048).eval()
435
+ self.rlg_diffusion.load_state_dict(
436
+ torch.load(
437
+ get_model_path("rlg_diffuser.pth", self.models_dir),
438
+ map_location=torch.device("cpu"),
439
+ )
440
+ )
441
+ with torch.no_grad():
442
+ return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(
443
+ torch.tensor([0.0])
444
+ )
445
+
446
+ def tts_with_preset(self, text, preset="fast", **kwargs):
447
+ """
448
+ Calls TTS with one of a set of preset generation parameters. Options:
449
+ 'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
450
+ 'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
451
+ 'standard': Very good quality. This is generally about as good as you are going to get.
452
+ 'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
453
+ """
454
+ # Use generally found best tuning knobs for generation.
455
+ settings = {
456
+ "temperature": 0.8,
457
+ "length_penalty": 1.0,
458
+ "repetition_penalty": 2.0,
459
+ "top_p": 0.8,
460
+ "cond_free_k": 2.0,
461
+ "diffusion_temperature": 1.0,
462
+ }
463
+ # Presets are defined here.
464
+ presets = {
465
+ "ultra_fast": {
466
+ "num_autoregressive_samples": 16,
467
+ "diffusion_iterations": 30,
468
+ "cond_free": False,
469
+ },
470
+ "fast": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
471
+ "standard": {
472
+ "num_autoregressive_samples": 256,
473
+ "diffusion_iterations": 200,
474
+ },
475
+ "high_quality": {
476
+ "num_autoregressive_samples": 256,
477
+ "diffusion_iterations": 400,
478
+ },
479
+ }
480
+ settings.update(presets[preset])
481
+ settings.update(kwargs) # allow overriding of preset settings with kwargs
482
+ return self.tts(text, **settings)
483
+
484
+ def tts(
485
+ self,
486
+ text,
487
+ voice_samples=None,
488
+ conditioning_latents=None,
489
+ k=1,
490
+ verbose=True,
491
+ use_deterministic_seed=None,
492
+ return_deterministic_state=False,
493
+ # autoregressive generation parameters follow
494
+ num_autoregressive_samples=512,
495
+ temperature=0.8,
496
+ length_penalty=1,
497
+ repetition_penalty=2.0,
498
+ top_p=0.8,
499
+ max_mel_tokens=500,
500
+ # CVVP parameters follow
501
+ cvvp_amount=0.0,
502
+ # diffusion generation parameters follow
503
+ diffusion_iterations=100,
504
+ cond_free=True,
505
+ cond_free_k=2,
506
+ diffusion_temperature=1.0,
507
+ **hf_generate_kwargs,
508
+ ):
509
+ """
510
+ Produces an audio clip of the given text being spoken with the given reference voice.
511
+ :param text: Text to be spoken.
512
+ :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
513
+ :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
514
+ can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
515
+ Conditioning latents can be retrieved via get_conditioning_latents().
516
+ :param k: The number of returned clips. The most likely (as determined by Tortoise's CLVP model) clips are returned.
517
+ :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
518
+ ~~AUTOREGRESSIVE KNOBS~~
519
+ :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
520
+ As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
521
+ :param temperature: The softmax temperature of the autoregressive model.
522
+ :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings cause the model to produce more terse outputs.
523
+ :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence
524
+ of long silences or "uhhhhhhs", etc.
525
+ :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
526
+ :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
527
+ :param typical_sampling: Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666
528
+ I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but
529
+ could use some tuning.
530
+ :param typical_mass: The typical_mass parameter from the typical_sampling algorithm.
531
+ ~~CLVP-CVVP KNOBS~~
532
+ :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model.
533
+ [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model.
534
+ ~~DIFFUSION KNOBS~~
535
+ :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
536
+ the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
537
+ however.
538
+ :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
539
+ each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output
540
+ of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and
541
+ dramatically improves realism.
542
+ :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
543
+ As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
544
+ Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k
545
+ :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
546
+ are the "mean" prediction of the diffusion network and will sound bland and smeared.
547
+ ~~OTHER STUFF~~
548
+ :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer.
549
+ Extra keyword args fed to this function get forwarded directly to that API. Documentation
550
+ here: https://huggingface.co/docs/transformers/internal/generation_utils
551
+ :return: Generated audio clip(s) as a torch tensor. Shape (1,S) if k=1, else (k,1,S), where S is the sample length.
552
+ Sample rate is 24kHz.
553
+ """
554
+ deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
555
+
556
+ text_tokens = (
557
+ torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
558
+ )
559
+ text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
560
+ assert (
561
+ text_tokens.shape[-1] < 400
562
+ ), "Too much text provided. Break the text up into separate segments and re-try inference."
563
+
564
+ auto_conds = None
565
+ if voice_samples is not None:
566
+ (
567
+ auto_conditioning,
568
+ diffusion_conditioning,
569
+ auto_conds,
570
+ _,
571
+ ) = self.get_conditioning_latents(voice_samples, return_mels=True)
572
+ elif conditioning_latents is not None:
573
+ auto_conditioning, diffusion_conditioning = conditioning_latents
574
+ else:
575
+ (
576
+ auto_conditioning,
577
+ diffusion_conditioning,
578
+ ) = self.get_random_conditioning_latents()
579
+ auto_conditioning = auto_conditioning.to(self.device)
580
+ diffusion_conditioning = diffusion_conditioning.to(self.device)
581
+
582
+ diffuser = load_discrete_vocoder_diffuser(
583
+ desired_diffusion_steps=diffusion_iterations,
584
+ cond_free=cond_free,
585
+ cond_free_k=cond_free_k,
586
+ )
587
+
588
+ with torch.no_grad():
589
+ samples = []
590
+ num_batches = num_autoregressive_samples // self.autoregressive_batch_size
591
+ stop_mel_token = self.autoregressive.stop_mel_token
592
+ calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
593
+ self.autoregressive = self.autoregressive.to(self.device)
594
+ if verbose:
595
+ print("Generating autoregressive samples..")
596
+ for b in tqdm(range(num_batches), disable=not verbose):
597
+ codes = self.autoregressive.inference_speech(
598
+ auto_conditioning,
599
+ text_tokens,
600
+ do_sample=True,
601
+ top_p=top_p,
602
+ temperature=temperature,
603
+ num_return_sequences=self.autoregressive_batch_size,
604
+ length_penalty=length_penalty,
605
+ repetition_penalty=repetition_penalty,
606
+ max_generate_length=max_mel_tokens,
607
+ **hf_generate_kwargs,
608
+ )
609
+ padding_needed = max_mel_tokens - codes.shape[1]
610
+ codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
611
+ samples.append(codes)
612
+ self.autoregressive = self.autoregressive.cpu()
613
+
614
+ clip_results = []
615
+ self.clvp = self.clvp.to(self.device)
616
+ if cvvp_amount > 0:
617
+ if self.cvvp is None:
618
+ self.load_cvvp()
619
+ self.cvvp = self.cvvp.to(self.device)
620
+ if verbose:
621
+ if self.cvvp is None:
622
+ print("Computing best candidates using CLVP")
623
+ else:
624
+ print(
625
+ f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
626
+ )
627
+ for batch in tqdm(samples, disable=not verbose):
628
+ for i in range(batch.shape[0]):
629
+ batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
630
+ if cvvp_amount != 1:
631
+ clvp = self.clvp(
632
+ text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
633
+ )
634
+ if auto_conds is not None and cvvp_amount > 0:
635
+ cvvp_accumulator = 0
636
+ for cl in range(auto_conds.shape[1]):
637
+ cvvp_accumulator = cvvp_accumulator + self.cvvp(
638
+ auto_conds[:, cl].repeat(batch.shape[0], 1, 1),
639
+ batch,
640
+ return_loss=False,
641
+ )
642
+ cvvp = cvvp_accumulator / auto_conds.shape[1]
643
+ if cvvp_amount == 1:
644
+ clip_results.append(cvvp)
645
+ else:
646
+ clip_results.append(
647
+ cvvp * cvvp_amount + clvp * (1 - cvvp_amount)
648
+ )
649
+ else:
650
+ clip_results.append(clvp)
651
+ clip_results = torch.cat(clip_results, dim=0)
652
+ samples = torch.cat(samples, dim=0)
653
+ best_results = samples[torch.topk(clip_results, k=k).indices]
654
+ self.clvp = self.clvp.cpu()
655
+ if self.cvvp is not None:
656
+ self.cvvp = self.cvvp.cpu()
657
+ del samples
658
+
659
+ # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
660
+ # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
661
+ # results, but will increase memory usage.
662
+ self.autoregressive = self.autoregressive.to(self.device)
663
+ best_latents = self.autoregressive(
664
+ auto_conditioning.repeat(k, 1),
665
+ text_tokens.repeat(k, 1),
666
+ torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
667
+ best_results,
668
+ torch.tensor(
669
+ [
670
+ best_results.shape[-1]
671
+ * self.autoregressive.mel_length_compression
672
+ ],
673
+ device=text_tokens.device,
674
+ ),
675
+ return_latent=True,
676
+ clip_inputs=False,
677
+ )
678
+ self.autoregressive = self.autoregressive.cpu()
679
+ del auto_conditioning
680
+
681
+ if verbose:
682
+ print("Transforming autoregressive outputs into audio..")
683
+ wav_candidates = []
684
+ self.diffusion = self.diffusion.to(self.device)
685
+ self.vocoder = self.vocoder.to(self.device)
686
+ for b in range(best_results.shape[0]):
687
+ codes = best_results[b].unsqueeze(0)
688
+ latents = best_latents[b].unsqueeze(0)
689
+
690
+ # Find the first occurrence of the "calm" token and trim the codes to that.
691
+ ctokens = 0
692
+ for k in range(codes.shape[-1]):
693
+ if codes[0, k] == calm_token:
694
+ ctokens += 1
695
+ else:
696
+ ctokens = 0
697
+ if (
698
+ ctokens > 8
699
+ ): # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
700
+ latents = latents[:, :k]
701
+ break
702
+
703
+ mel = do_spectrogram_diffusion(
704
+ self.diffusion,
705
+ diffuser,
706
+ latents,
707
+ diffusion_conditioning,
708
+ temperature=diffusion_temperature,
709
+ verbose=verbose,
710
+ )
711
+ wav = self.vocoder.inference(mel)
712
+ wav_candidates.append(wav.cpu())
713
+ self.diffusion = self.diffusion.cpu()
714
+ self.vocoder = self.vocoder.cpu()
715
+
716
+ def potentially_redact(clip, text):
717
+ if self.enable_redaction:
718
+ return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
719
+ return clip
720
+
721
+ wav_candidates = [
722
+ potentially_redact(wav_candidate, text)
723
+ for wav_candidate in wav_candidates
724
+ ]
725
+
726
+ if len(wav_candidates) > 1:
727
+ res = wav_candidates
728
+ else:
729
+ res = wav_candidates[0]
730
+
731
+ if return_deterministic_state:
732
+ return res, (
733
+ deterministic_seed,
734
+ text,
735
+ voice_samples,
736
+ conditioning_latents,
737
+ )
738
+ else:
739
+ return res
740
+
741
+ def deterministic_state(self, seed=None):
742
+ """
743
+ Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
744
+ reproduced.
745
+ """
746
+ seed = int(time()) if seed is None else seed
747
+ torch.manual_seed(seed)
748
+ random.seed(seed)
749
+ # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
750
+ # torch.use_deterministic_algorithms(True)
751
+
752
+ return seed
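
A minimal usage sketch of the TextToSpeech API added above, assuming load_audio(path, sampling_rate) from tortoise/utils/audio.py (referenced in classify_audio_clip's docstring but not shown in this 50-file view) and following the save pattern of tortoise/do_tts.py from the same upload; the reference clips are the daniel voice files from the changed-files list.

import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio  # assumed signature: load_audio(path, sampling_rate)

tts = TextToSpeech()  # missing weights are downloaded into ~/.cache/tortoise/models on first use

# Two or more ~10 second reference clips as 22.05 kHz waveforms, per the tts() docstring.
voice_samples = [
    load_audio("tortoise/voices/daniel/1.wav", 22050),
    load_audio("tortoise/voices/daniel/2.wav", 22050),
]

# tts_with_preset() forwards extra kwargs to tts(), so any preset or default knob can be overridden.
gen = tts.tts_with_preset(
    "Joining two modalities results in a surprising increase in generalization.",
    voice_samples=voice_samples,
    preset="fast",              # ultra_fast | fast | standard | high_quality
    diffusion_temperature=0.8,  # example override of a default knob
)

# k=1 returns a single 24 kHz clip; squeeze the leading batch dimension before saving,
# following tortoise/do_tts.py.
torchaudio.save("generated.wav", gen.squeeze(0).cpu(), 24000)

Per the docstrings above, k>1 returns the top-k clips instead, and a pre-computed pair from get_conditioning_latents() can be passed as conditioning_latents in place of voice_samples.
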
tortoise/data/got.txt ADDED
@@ -0,0 +1,276 @@
1
+ Chapter One
2
+
3
+
4
+ Bran
5
+
6
+
7
+ The morning had dawned clear and cold, with a crispness that hinted at the end of summer. They set forth at daybreak to see a man beheaded, twenty in all, and Bran rode among them, nervous with excitement. This was the first time he had been deemed old enough to go with his lord father and his brothers to see the king's justice done. It was the ninth year of summer, and the seventh of Bran's life.
8
+
9
+
10
+ The man had been taken outside a small holdfast in the hills. Robb thought he was a wildling, his sword sworn to Mance Rayder, the King-beyond-the-Wall. It made Bran's skin prickle to think of it. He remembered the hearth tales Old Nan told them. The wildlings were cruel men, she said, slavers and slayers and thieves. They consorted with giants and ghouls, stole girl children in the dead of night, and drank blood from polished horns. And their women lay with the Others in the Long Night to sire terrible half-human children.
11
+
12
+
13
+ But the man they found bound hand and foot to the holdfast wall awaiting the king's justice was old and scrawny, not much taller than Robb. He had lost both ears and a finger to frostbite, and he dressed all in black, the same as a brother of the Night's Watch, except that his furs were ragged and greasy.
14
+
15
+
16
+ The breath of man and horse mingled, steaming, in the cold morning air as his lord father had the man cut down from the wall and dragged before them. Robb and Jon sat tall and still on their horses, with Bran between them on his pony, trying to seem older than seven, trying to pretend that he'd seen all this before. A faint wind blew through the holdfast gate. Over their heads flapped the banner of the Starks of Winterfell: a grey direwolf racing across an ice-white field.
17
+
18
+ Bran's father sat solemnly on his horse, long brown hair stirring in the wind. His closely trimmed beard was shot with white, making him look older than his thirty-five years. He had a grim cast to his grey eyes this day, and he seemed not at all the man who would sit before the fire in the evening and talk softly of the age of heroes and the children of the forest. He had taken off Father's face, Bran thought, and donned the face of Lord Stark of Winterfell.
19
+
20
+
21
+ There were questions asked and answers given there in the chill of morning, but afterward Bran could not recall much of what had been said. Finally his lord father gave a command, and two of his guardsmen dragged the ragged man to the ironwood stump in the center of the square. They forced his head down onto the hard black wood. Lord Eddard Stark dismounted and his ward Theon Greyjoy brought forth the sword. "Ice," that sword was called. It was as wide across as a man's hand, and taller even than Robb. The blade was Valyrian steel, spell-forged and dark as smoke. Nothing held an edge like Valyrian steel.
22
+
23
+
24
+ His father peeled off his gloves and handed them to Jory Cassel, the captain of his household guard. He took hold of Ice with both hands and said, "In the name of Robert of the House Baratheon, the First of his Name, King of the Andals and the Rhoynar and the First Men, Lord of the Seven Kingdoms and Protector of the Realm, by the word of Eddard of the House Stark, Lord of Winterfell and Warden of the North, I do sentence you to die." He lifted the greatsword high above his head.
25
+
26
+
27
+ Bran's bastard brother Jon Snow moved closer. "Keep the pony well in hand," he whispered. "And don't look away. Father will know if you do."
28
+
29
+
30
+ Bran kept his pony well in hand, and did not look away.
31
+
32
+
33
+ His father took off the man's head with a single sure stroke. Blood sprayed out across the snow, as red as summerwine. One of the horses reared and had to be restrained to keep from bolting. Bran could not take his eyes off the blood. The snows around the stump drank it eagerly, reddening as he watched.
34
+
35
+ The head bounced off a thick root and rolled. It came up near Greyjoy's feet. Theon was a lean, dark youth of nineteen who found everything amusing. He laughed, put his boot on the head, and kicked it away.
36
+
37
+
38
+ "Ass," Jon muttered, low enough so Greyjoy did not hear. He put a hand on Bran's shoulder, and Bran looked over at his bastard brother. "You did well," Jon told him solemnly. Jon was fourteen, an old hand at justice.
39
+
40
+
41
+ It seemed colder on the long ride back to Winterfell, though the wind had died by then and the sun was higher in the sky. Bran rode with his brothers, well ahead of the main party, his pony struggling hard to keep up with their horses.
42
+
43
+
44
+ "The deserter died bravely," Robb said. He was big and broad and growing every day, with his mother's coloring, the fair skin, red-brown hair, and blue eyes of the Tullys of Riverrun. "He had courage, at the least."
45
+
46
+
47
+ "No," Jon Snow said quietly. "It was not courage. This one was dead of fear. You could see it in his eyes, Stark." Jon's eyes were a grey so dark they seemed almost black, but there was little they did not see. He was of an age with Robb, but they did not look alike. Jon was slender where Robb was muscular, dark where Robb was fair, graceful and quick where his half brother was strong and fast.
48
+
49
+
50
+ Robb was not impressed. "The Others take his eyes," he swore. "He died well. Race you to the bridge?"
51
+
52
+
53
+ "Done," Jon said, kicking his horse forward. Robb cursed and followed, and they galloped off down the trail, Robb laughing and hooting, Jon silent and intent. The hooves of their horses kicked up showers of snow as they went.
54
+
55
+ Bran did not try to follow. His pony could not keep up. He had seen the ragged man's eyes, and he was thinking of them now. After a while, the sound of Robb's laughter receded, and the woods grew silent again.
56
+
57
+
58
+ So deep in thought was he that he never heard the rest of the party until his father moved up to ride beside him. "Are you well, Bran?" he asked, not unkindly.
59
+
60
+
61
+ "Yes, Father," Bran told him. He looked up. Wrapped in his furs and leathers, mounted on his great warhorse, his lord father loomed over him like a giant. "Robb says the man died bravely, but Jon says he was afraid."
62
+
63
+
64
+ "What do you think?" his father asked.
65
+
66
+
67
+ Bran thought about it. "Can a man still be brave if he's afraid?"
68
+
69
+
70
+ "That is the only time a man can be brave," his father told him. "Do you understand why I did it?"
71
+
72
+
73
+ "He was a wildling," Bran said. "They carry off women and sell them to the Others."
74
+
75
+
76
+ His lord father smiled. "Old Nan has been telling you stories again. In truth, the man was an oathbreaker, a deserter from the Night's Watch. No man is more dangerous. The deserter knows his life is forfeit if he is taken, so he will not flinch from any crime, no matter how vile. But you mistake me. The question was not why the man had to die, but why I must do it."
77
+
78
+
79
+ Bran had no answer for that. "King Robert has a headsman," he said, uncertainly.
80
+
81
+
82
+ "He does," his father admitted. "As did the Targaryen kings before him. Yet our way is the older way. The blood of the First Men still flows in the veins of the Starks, and we hold to the belief that the man who passes the sentence should swing the sword. If you would take a man's life, you owe it to him to look into his eyes and hear his final words. And if you cannot bear to do that, then perhaps the man does not deserve to die.
83
+
84
+
85
+ "One day, Bran, you will be Robb's bannerman, holding a keep of your own for your brother and your king, and justice will fall to you. When that day comes, you must take no pleasure in the task, but neither must you look away. A ruler who hides behind paid executioners soon forgets what death is."
86
+
87
+
88
+ That was when Jon reappeared on the crest of the hill before them. He waved and shouted down at them. "Father, Bran, come quickly, see what Robb has found!" Then he was gone again.
89
+
90
+
91
+ Jory rode up beside them. "Trouble, my lord?"
92
+
93
+
94
+ "Beyond a doubt," his lord father said. "Come, let us see what mischief my sons have rooted out now." He sent his horse into a trot. Jory and Bran and the rest came after.
95
+
96
+
97
+ They found Robb on the riverbank north of the bridge, with Jon still mounted beside him. The late summer snows had been heavy this moonturn. Robb stood knee-deep in white, his hood pulled back so the sun shone in his hair. He was cradling something in his arm, while the boys talked in hushed, excited voices.
98
+
99
+
100
+ The riders picked their way carefully through the drifts, groping for solid footing on the hidden, uneven ground. Jory Cassel and Theon Greyjoy were the first to reach the boys. Greyjoy was laughing and joking as he rode. Bran heard the breath go out of him. "Gods!" he exclaimed, struggling to keep control of his horse as he reached for his sword.
101
+
102
+
103
+ Jory's sword was already out. "Robb, get away from it!" he called as his horse reared under him.
104
+
105
+
106
+ Robb grinned and looked up from the bundle in his arms. "She can't hurt you," he said. "She's dead, Jory."
107
+
108
+
109
+ Bran was afire with curiosity by then. He would have spurred the pony faster, but his father made them dismount beside the bridge and approach on foot. Bran jumped off and ran.
110
+
111
+
112
+ By then Jon, Jory, and Theon Greyjoy had all dismounted as well. "What in the seven hells is it?" Greyjoy was saying.
113
+
114
+
115
+ "A wolf," Robb told him.
116
+
117
+
118
+ "A freak," Greyjoy said. "Look at the size of it."
119
+
120
+
121
+ Bran's heart was thumping in his chest as he pushed through a waist-high drift to his brothers' side.
122
+
123
+
124
+ Half-buried in bloodstained snow, a huge dark shape slumped in death. Ice had formed in its shaggy grey fur, and the faint smell of corruption clung to it like a woman's perfume. Bran glimpsed blind eyes crawling with maggots, a wide mouth full of yellowed teeth. But it was the size of it that made him gasp. It was bigger than his pony, twice the size of the largest hound in his father's kennel.
125
+
126
+
127
+ "It's no freak," Jon said calmly. "That's a direwolf. They grow larger than the other kind."
128
+
129
+
130
+ Theon Greyjoy said, "There's not been a direwolf sighted south of the Wall in two hundred years."
131
+
132
+
133
+ "I see one now," Jon replied.
134
+
135
+
136
+ Bran tore his eyes away from the monster. That was when he noticed the bundle in Robb's arms. He gave a cry of delight and moved closer. The pup was a tiny ball of grey-black fur, its eyes still closed. It nuzzled blindly against Robb's chest as he cradled it, searching for milk among his leathers, making a sad little whimpery sound. Bran reached out hesitantly. "Go on," Robb told him. "You can touch him."
137
+
138
+
139
+ Bran gave the pup a quick nervous stroke, then turned as Jon said, "Here you go." His half brother put a second pup into his arms. "There are five of them." Bran sat down in the snow and hugged the wolf pup to his face. Its fur was soft and warm against his cheek.
140
+
141
+
142
+ "Direwolves loose in the realm, after so many years," muttered Hullen, the master of horse. "I like it not."
143
+
144
+
145
+ "It is a sign," Jory said.
146
+
147
+
148
+ Father frowned. "This is only a dead animal, Jory," he said. Yet he seemed troubled. Snow crunched under his boots as he moved around the body. "Do we know what killed her?"
149
+
150
+
151
+ "There's something in the throat," Robb told him, proud to have found the answer before his father even asked. "There, just under the jaw."
152
+
153
+
154
+ His father knelt and groped under the beast's head with his hand. He gave a yank and held it up for all to see. A foot of shattered antler, tines snapped off, all wet with blood.
155
+
156
+
157
+ A sudden silence descended over the party. The men looked at the antler uneasily, and no one dared to speak. Even Bran could sense their fear, though he did not understand.
158
+
159
+
160
+ His father tossed the antler to the side and cleansed his hands in the snow. "I'm surprised she lived long enough to whelp," he said. His voice broke the spell.
161
+
162
+
163
+ "Maybe she didn't," Jory said. "I've heard tales . . . maybe the bitch was already dead when the pups came."
164
+
165
+
166
+ "Born with the dead," another man put in. "Worse luck."
167
+
168
+
169
+ "No matter," said Hullen. "They be dead soon enough too."
170
+
171
+
172
+ Bran gave a wordless cry of dismay.
173
+
174
+
175
+ "The sooner the better," Theon Greyjoy agreed. He drew his sword. "Give the beast here, Bran."
176
+
177
+
178
+ The little thing squirmed against him, as if it heard and understood. "No!" Bran cried out fiercely. "It's mine."
179
+
180
+
181
+ "Put away your sword, Greyjoy," Robb said. For a moment he sounded as commanding as their father, like the lord he would someday be. "We will keep these pups."
182
+
183
+
184
+ "You cannot do that, boy," said Harwin, who was Hullen's son.
185
+
186
+
187
+ "It be a mercy to kill them," Hullen said.
188
+
189
+
190
+ Bran looked to his lord father for rescue, but got only a frown, a furrowed brow. "Hullen speaks truly, son. Better a swift death than a hard one from cold and starvation."
191
+
192
+
193
+ "No!" He could feel tears welling in his eyes, and he looked away. He did not want to cry in front of his father.
194
+
195
+
196
+ Robb resisted stubbornly. "Ser Rodrik's red bitch whelped again last week," he said. "It was a small litter, only two live pups. She'll have milk enough."
197
+
198
+
199
+ "She'll rip them apart when they try to nurse."
200
+
201
+
202
+ "Lord Stark," Jon said. It was strange to hear him call Father that, so formal. Bran looked at him with desperate hope. "There are five pups," he told Father. "Three male, two female."
203
+
204
+
205
+ "What of it, Jon?"
206
+
207
+
208
+ "You have five trueborn children," Jon said. "Three sons, two daughters. The direwolf is the sigil of your House. Your children were meant to have these pups, my lord."
209
+
210
+
211
+ Bran saw his father's face change, saw the other men exchange glances. He loved Jon with all his heart at that moment. Even at seven, Bran understood what his brother had done. The count had come right only because Jon had omitted himself. He had included the girls, included even Rickon, the baby, but not the bastard who bore the surname Snow, the name that custom decreed be given to all those in the north unlucky enough to be born with no name of their own.
212
+
213
+
214
+ Their father understood as well. "You want no pup for yourself, Jon?" he asked softly.
215
+
216
+
217
+ "The direwolf graces the banners of House Stark," Jon pointed out. "I am no Stark, Father."
218
+
219
+
220
+ Their lord father regarded Jon thoughtfully. Robb rushed into the silence he left. "I will nurse him myself, Father," he promised. "I will soak a towel with warm milk, and give him suck from that."
221
+
222
+
223
+ "Me too!" Bran echoed.
224
+
225
+
226
+ The lord weighed his sons long and carefully with his eyes. "Easy to say, and harder to do. I will not have you wasting the servants' time with this. If you want these pups, you will feed them yourselves. Is that understood?"
227
+
228
+
229
+ Bran nodded eagerly. The pup squirmed in his grasp, licked at his face with a warm tongue.
230
+
231
+
232
+ "You must train them as well," their father said. "You must train them. The kennelmaster will have nothing to do with these monsters, I promise you that. And the gods help you if you neglect them, or brutalize them, or train them badly. These are not dogs to beg for treats and slink off at a kick. A direwolf will rip a man's arm off his shoulder as easily as a dog will kill a rat. Are you sure you want this?"
233
+
234
+ "Yes, Father," Bran said.
235
+
236
+
237
+ "Yes," Robb agreed.
238
+
239
+
240
+ "The pups may die anyway, despite all you do."
241
+
242
+
243
+ "They won't die," Robb said. "We won't let them die."
244
+
245
+
246
+ "Keep them, then. Jory, Desmond, gather up the other pups. It's time we were back to Winterfell."
247
+
248
+
249
+ It was not until they were mounted and on their way that Bran allowed himself to taste the sweet air of victory. By then, his pup was snuggled inside his leathers, warm against him, safe for the long ride home. Bran was wondering what to name him.
250
+
251
+
252
+ Halfway across the bridge, Jon pulled up suddenly.
253
+
254
+
255
+ "What is it, Jon?" their lord father asked.
256
+
257
+
258
+ "Can't you hear it?"
259
+
260
+
261
+ Bran could hear the wind in the trees, the clatter of their hooves on the ironwood planks, the whimpering of his hungry pup, but Jon was listening to something else.
262
+
263
+
264
+ "There," Jon said. He swung his horse around and galloped back across the bridge. They watched him dismount where the direwolf lay dead in the snow, watched him kneel. A moment later he was riding back to them, smiling.
265
+
266
+
267
+ "He must have crawled away from the others," Jon said.
268
+
269
+
270
+ "Or been driven away," their father said, looking at the sixth pup. His fur was white, where the rest of the litter was grey. His eyes were as red as the blood of the ragged man who had died that morning. Bran thought it curious that this pup alone would have opened his eyes while the others were still blind.
271
+
272
+
273
+ "An albino," Theon Greyjoy said with wry amusement. "This one will die even faster than the others."
274
+
275
+
276
+ Jon Snow gave his father's ward a long, chilling look. "I think not, Greyjoy," he said. "This one belongs to me."
tortoise/data/layman.txt ADDED
File without changes
tortoise/data/mel_norms.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f69422a8a8f344c4fca2f0c6b8d41d2151d6615b7321e48e6bb15ae949b119c
3
+ size 1067
tortoise/data/riding_hood.txt ADDED
@@ -0,0 +1,54 @@
1
+ Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood.
2
+ One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."
3
+
4
+ Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village.
5
+
6
+ As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother."
7
+
8
+ "Does she live far off?" said the wolf
9
+
10
+ "Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village."
11
+
12
+ "Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first."
13
+
14
+ The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap.
15
+
16
+ "Who's there?"
17
+
18
+ "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."
19
+
20
+ The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."
21
+
22
+ The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it had been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap.
23
+
24
+ "Who's there?"
25
+
26
+ Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."
27
+
28
+ The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up."
29
+
30
+ Little Red Riding Hood pulled the bobbin, and the door opened.
31
+
32
+ The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me."
33
+
34
+ Little Red Riding Hood took off her clothes and got into bed. She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!"
35
+
36
+ "All the better to hug you with, my dear."
37
+
38
+ "Grandmother, what big legs you have!"
39
+
40
+ "All the better to run with, my child."
41
+
42
+ "Grandmother, what big ears you have!"
43
+
44
+ "All the better to hear with, my child."
45
+
46
+ "Grandmother, what big eyes you have!"
47
+
48
+ "All the better to see with, my child."
49
+
50
+ "Grandmother, what big teeth you have got!"
51
+
52
+ "All the better to eat you up with."
53
+
54
+ And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.
tortoise/data/seal_copypasta.txt ADDED
@@ -0,0 +1 @@
1
+ What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al kayda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire U S armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the U S A and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.
tortoise/data/tokenizer.json ADDED
@@ -0,0 +1 @@
1
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
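Note on the tokenizer above: it is a standard Hugging Face tokenizers BPE definition (whitespace pre-tokenization, a lower-case character/merge vocabulary, and [STOP]/[UNK]/[SPACE] specials). A minimal loading sketch, assuming the tokenizers package is installed and that text is lower-cased before encoding (the sample sentence is illustrative):

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("tortoise/data/tokenizer.json")
    enc = tok.encode("the quick brown fox")  # whitespace split, then BPE merges
    print(enc.tokens)  # sub-word pieces drawn from the vocab above
    print(enc.ids)     # the matching integer ids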
tortoise/do_tts.py ADDED
@@ -0,0 +1,102 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import torchaudio
6
+
7
+ from api import TextToSpeech, MODELS_DIR
8
+ from utils.audio import load_voices
9
+
10
+ if __name__ == "__main__":
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument(
13
+ "--text",
14
+ type=str,
15
+ help="Text to speak.",
16
+ default="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
17
+ )
18
+ parser.add_argument(
19
+ "--voice",
20
+ type=str,
21
+ help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
22
+ "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
23
+ default="random",
24
+ )
25
+ parser.add_argument(
26
+ "--preset", type=str, help="Which voice preset to use.", default="fast"
27
+ )
28
+ parser.add_argument(
29
+ "--output_path", type=str, help="Where to store outputs.", default="results/"
30
+ )
31
+ parser.add_argument(
32
+ "--model_dir",
33
+ type=str,
34
+ help="Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this "
35
+ "should only be specified if you have custom checkpoints.",
36
+ default=MODELS_DIR,
37
+ )
38
+ parser.add_argument(
39
+ "--candidates",
40
+ type=int,
41
+ help="How many output candidates to produce per-voice.",
42
+ default=3,
43
+ )
44
+ parser.add_argument(
45
+ "--seed",
46
+ type=int,
47
+ help="Random seed which can be used to reproduce results.",
48
+ default=None,
49
+ )
50
+ parser.add_argument(
51
+ "--produce_debug_state",
52
+ type=bool,
53
+ help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
54
+ default=True,
55
+ )
56
+ parser.add_argument(
57
+ "--cvvp_amount",
58
+ type=float,
59
+ help="How much the CVVP model should influence the output. "
60
+ "Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled).",
61
+ default=0.0,
62
+ )
63
+ args = parser.parse_args()
64
+ os.makedirs(args.output_path, exist_ok=True)
65
+
66
+ tts = TextToSpeech(models_dir=args.model_dir)
67
+
68
+ selected_voices = args.voice.split(",")
69
+ for k, selected_voice in enumerate(selected_voices):
70
+ if "&" in selected_voice:
71
+ voice_sel = selected_voice.split("&")
72
+ else:
73
+ voice_sel = [selected_voice]
74
+ voice_samples, conditioning_latents = load_voices(voice_sel)
75
+
76
+ gen, dbg_state = tts.tts_with_preset(
77
+ args.text,
78
+ k=args.candidates,
79
+ voice_samples=voice_samples,
80
+ conditioning_latents=conditioning_latents,
81
+ preset=args.preset,
82
+ use_deterministic_seed=args.seed,
83
+ return_deterministic_state=True,
84
+ cvvp_amount=args.cvvp_amount,
85
+ )
86
+ if isinstance(gen, list):
87
+ for j, g in enumerate(gen):
88
+ torchaudio.save(
89
+ os.path.join(args.output_path, f"{selected_voice}_{k}_{j}.wav"),
90
+ g.squeeze(0).cpu(),
91
+ 24000,
92
+ )
93
+ else:
94
+ torchaudio.save(
95
+ os.path.join(args.output_path, f"{selected_voice}_{k}.wav"),
96
+ gen.squeeze(0).cpu(),
97
+ 24000,
98
+ )
99
+
100
+ if args.produce_debug_state:
101
+ os.makedirs("debug_states", exist_ok=True)
102
+ torch.save(dbg_state, f"debug_states/do_tts_debug_{selected_voice}.pth")
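Usage note for do_tts.py: a hypothetical invocation (the text is illustrative; run it from inside the tortoise/ directory so the local api and utils imports resolve, and model checkpoints are fetched automatically unless --model_dir points at custom ones):

    python do_tts.py --text "Hello from Tortoise." --voice random --preset fast --candidates 1

Per the torchaudio.save calls above, outputs land in results/ as <voice>_<k>.wav, or <voice>_<k>_<j>.wav when several candidates are kept, written at 24 kHz.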
tortoise/eval.py ADDED
@@ -0,0 +1,44 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torchaudio
5
+
6
+ from api import TextToSpeech
7
+ from tortoise.utils.audio import load_audio
8
+
9
+ if __name__ == "__main__":
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument(
12
+ "--eval_path",
13
+ type=str,
14
+ help="Path to TSV test file",
15
+ default="D:\\tmp\\tortoise-tts-eval\\test.tsv",
16
+ )
17
+ parser.add_argument(
18
+ "--output_path",
19
+ type=str,
20
+ help="Where to put results",
21
+ default="D:\\tmp\\tortoise-tts-eval\\baseline",
22
+ )
23
+ parser.add_argument(
24
+ "--preset", type=str, help="Rendering preset.", default="standard"
25
+ )
26
+ args = parser.parse_args()
27
+ os.makedirs(args.output_path, exist_ok=True)
28
+
29
+ tts = TextToSpeech()
30
+
31
+ with open(args.eval_path, "r", encoding="utf-8") as f:
32
+ lines = f.readlines()
33
+
34
+ for line in lines:
35
+ text, real = line.strip().split("\t")
36
+ conds = [load_audio(real, 22050)]
37
+ gen = tts.tts_with_preset(
38
+ text, voice_samples=conds, conditioning_latents=None, preset=args.preset
39
+ )
40
+ torchaudio.save(
41
+ os.path.join(args.output_path, os.path.basename(real)),
42
+ gen.squeeze(0).cpu(),
43
+ 24000,
44
+ )
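Usage note for eval.py: the --eval_path file is a UTF-8 TSV in which each line is <text><TAB><path to a reference wav>; the reference clip is loaded at 22.05 kHz as the conditioning voice, and the synthesized audio is written to --output_path under the reference's basename. A hypothetical line of such a file (<TAB> stands for a literal tab character; the path is illustrative):

    A quick test sentence for evaluation.<TAB>D:\tmp\tortoise-tts-eval\real\clip_0001.wav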
tortoise/get_conditioning_latents.py ADDED
@@ -0,0 +1,39 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+
5
+ from api import TextToSpeech
6
+ from tortoise.utils.audio import load_audio, get_voices
7
+
8
+ """
9
+ Dumps the conditioning latents for the specified voice to disk. These are expressive latents which can be used for
10
+ other ML models, or can be augmented manually and fed back into Tortoise to affect vocal qualities.
11
+ """
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument(
15
+ "--voice",
16
+ type=str,
17
+ help="Selects the voice to convert to conditioning latents",
18
+ default="pat2",
19
+ )
20
+ parser.add_argument(
21
+ "--output_path",
22
+ type=str,
23
+ help="Where to store outputs.",
24
+ default="../results/conditioning_latents",
25
+ )
26
+ args = parser.parse_args()
27
+ os.makedirs(args.output_path, exist_ok=True)
28
+
29
+ tts = TextToSpeech()
30
+ voices = get_voices()
31
+ selected_voices = args.voice.split(",")
32
+ for voice in selected_voices:
33
+ cond_paths = voices[voice]
34
+ conds = []
35
+ for cond_path in cond_paths:
36
+ c = load_audio(cond_path, 22050)
37
+ conds.append(c)
38
+ conditioning_latents = tts.get_conditioning_latents(conds)
39
+ torch.save(conditioning_latents, os.path.join(args.output_path, f"{voice}.pth"))
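Usage note for get_conditioning_latents.py: a hypothetical run (the voice name is illustrative and must exist as a folder of reference clips under the voices/ directory; run from inside tortoise/ for the same import reasons as above):

    python get_conditioning_latents.py --voice myvoice --output_path ../results/conditioning_latents

The dumped myvoice.pth can be reloaded with torch.load and passed back in as conditioning_latents (for example to tts_with_preset, as in do_tts.py above), which is what the module docstring means by augmenting the latents and feeding them back into Tortoise.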
tortoise/is_this_from_tortoise.py ADDED
@@ -0,0 +1,21 @@
1
+ import argparse
2
+
3
+ from api import classify_audio_clip
4
+ from tortoise.utils.audio import load_audio
5
+
6
+ if __name__ == "__main__":
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument(
9
+ "--clip",
10
+ type=str,
11
+ help="Path to an audio clip to classify.",
12
+ default="../examples/favorite_riding_hood.mp3",
13
+ )
14
+ args = parser.parse_args()
15
+
16
+ clip = load_audio(args.clip, 24000)
17
+ clip = clip[:, :220000]
18
+ prob = classify_audio_clip(clip)
19
+ print(
20
+ f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise."
21
+ )
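Usage note for is_this_from_tortoise.py: a hypothetical invocation (the clip path is illustrative; run from inside tortoise/). The script loads the clip at 24 kHz and keeps only the first 220,000 samples, a little over nine seconds, before scoring it:

    python is_this_from_tortoise.py --clip path/to/some_clip.wav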
tortoise/models/__init__.py ADDED
File without changes
tortoise/models/arch_util.py ADDED
@@ -0,0 +1,424 @@
1
+ import os
2
+ import functools
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from tortoise.models.xtransformers import (
10
+ ContinuousTransformerWrapper,
11
+ RelativePositionBias,
12
+ )
13
+
14
+
15
+ def zero_module(module):
16
+ """
17
+ Zero out the parameters of a module and return it.
18
+ """
19
+ for p in module.parameters():
20
+ p.detach().zero_()
21
+ return module
22
+
23
+
24
+ class GroupNorm32(nn.GroupNorm):
25
+ def forward(self, x):
26
+ return super().forward(x.float()).type(x.dtype)
27
+
28
+
29
+ def normalization(channels):
30
+ """
31
+ Make a standard normalization layer.
32
+
33
+ :param channels: number of input channels.
34
+ :return: an nn.Module for normalization.
35
+ """
36
+ groups = 32
37
+ if channels <= 16:
38
+ groups = 8
39
+ elif channels <= 64:
40
+ groups = 16
41
+ while channels % groups != 0:
42
+ groups = int(groups / 2)
43
+ assert groups > 2
44
+ return GroupNorm32(groups, channels)
45
+
46
+
47
+ class QKVAttentionLegacy(nn.Module):
48
+ """
49
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
50
+ """
51
+
52
+ def __init__(self, n_heads):
53
+ super().__init__()
54
+ self.n_heads = n_heads
55
+
56
+ def forward(self, qkv, mask=None, rel_pos=None):
57
+ """
58
+ Apply QKV attention.
59
+
60
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
61
+ :return: an [N x (H * C) x T] tensor after attention.
62
+ """
63
+ bs, width, length = qkv.shape
64
+ assert width % (3 * self.n_heads) == 0
65
+ ch = width // (3 * self.n_heads)
66
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
67
+ scale = 1 / math.sqrt(math.sqrt(ch))
68
+ weight = torch.einsum(
69
+ "bct,bcs->bts", q * scale, k * scale
70
+ ) # More stable with f16 than dividing afterwards
71
+ if rel_pos is not None:
72
+ weight = rel_pos(
73
+ weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
74
+ ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
75
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
76
+ if mask is not None:
77
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
78
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
79
+ weight = weight * mask
80
+ a = torch.einsum("bts,bcs->bct", weight, v)
81
+
82
+ return a.reshape(bs, -1, length)
83
+
84
+
85
+ class AttentionBlock(nn.Module):
86
+ """
87
+ An attention block that allows spatial positions to attend to each other.
88
+
89
+ Originally ported from here, but adapted to the N-d case.
90
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ channels,
96
+ num_heads=1,
97
+ num_head_channels=-1,
98
+ do_checkpoint=True,
99
+ relative_pos_embeddings=False,
100
+ ):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.do_checkpoint = do_checkpoint
104
+ if num_head_channels == -1:
105
+ self.num_heads = num_heads
106
+ else:
107
+ assert (
108
+ channels % num_head_channels == 0
109
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
110
+ self.num_heads = channels // num_head_channels
111
+ self.norm = normalization(channels)
112
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
113
+ # split heads before split qkv
114
+ self.attention = QKVAttentionLegacy(self.num_heads)
115
+
116
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
117
+ if relative_pos_embeddings:
118
+ self.relative_pos_embeddings = RelativePositionBias(
119
+ scale=(channels // self.num_heads) ** 0.5,
120
+ causal=False,
121
+ heads=num_heads,
122
+ num_buckets=32,
123
+ max_distance=64,
124
+ )
125
+ else:
126
+ self.relative_pos_embeddings = None
127
+
128
+ def forward(self, x, mask=None):
129
+ b, c, *spatial = x.shape
130
+ x = x.reshape(b, c, -1)
131
+ qkv = self.qkv(self.norm(x))
132
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
133
+ h = self.proj_out(h)
134
+ return (x + h).reshape(b, c, *spatial)
135
+
136
+
137
+ class Upsample(nn.Module):
138
+ """
139
+ An upsampling layer with an optional convolution.
140
+
141
+ :param channels: channels in the inputs and outputs.
142
+ :param use_conv: a bool determining if a convolution is applied.
143
+ """
144
+
145
+ def __init__(self, channels, use_conv, out_channels=None, factor=4):
146
+ super().__init__()
147
+ self.channels = channels
148
+ self.out_channels = out_channels or channels
149
+ self.use_conv = use_conv
150
+ self.factor = factor
151
+ if use_conv:
152
+ ksize = 5
153
+ pad = 2
154
+ self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
155
+
156
+ def forward(self, x):
157
+ assert x.shape[1] == self.channels
158
+ x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
159
+ if self.use_conv:
160
+ x = self.conv(x)
161
+ return x
162
+
163
+
164
+ class Downsample(nn.Module):
165
+ """
166
+ A downsampling layer with an optional convolution.
167
+
168
+ :param channels: channels in the inputs and outputs.
169
+ :param use_conv: a bool determining if a convolution is applied.
170
+ """
171
+
172
+ def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
173
+ super().__init__()
174
+ self.channels = channels
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+
178
+ stride = factor
179
+ if use_conv:
180
+ self.op = nn.Conv1d(
181
+ self.channels, self.out_channels, ksize, stride=stride, padding=pad
182
+ )
183
+ else:
184
+ assert self.channels == self.out_channels
185
+ self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
186
+
187
+ def forward(self, x):
188
+ assert x.shape[1] == self.channels
189
+ return self.op(x)
190
+
191
+
192
+ class ResBlock(nn.Module):
193
+ def __init__(
194
+ self,
195
+ channels,
196
+ dropout,
197
+ out_channels=None,
198
+ use_conv=False,
199
+ use_scale_shift_norm=False,
200
+ up=False,
201
+ down=False,
202
+ kernel_size=3,
203
+ ):
204
+ super().__init__()
205
+ self.channels = channels
206
+ self.dropout = dropout
207
+ self.out_channels = out_channels or channels
208
+ self.use_conv = use_conv
209
+ self.use_scale_shift_norm = use_scale_shift_norm
210
+ padding = 1 if kernel_size == 3 else 2
211
+
212
+ self.in_layers = nn.Sequential(
213
+ normalization(channels),
214
+ nn.SiLU(),
215
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
216
+ )
217
+
218
+ self.updown = up or down
219
+
220
+ if up:
221
+ self.h_upd = Upsample(channels, False)
222
+ self.x_upd = Upsample(channels, False)
223
+ elif down:
224
+ self.h_upd = Downsample(channels, False)
225
+ self.x_upd = Downsample(channels, False)
226
+ else:
227
+ self.h_upd = self.x_upd = nn.Identity()
228
+
229
+ self.out_layers = nn.Sequential(
230
+ normalization(self.out_channels),
231
+ nn.SiLU(),
232
+ nn.Dropout(p=dropout),
233
+ zero_module(
234
+ nn.Conv1d(
235
+ self.out_channels, self.out_channels, kernel_size, padding=padding
236
+ )
237
+ ),
238
+ )
239
+
240
+ if self.out_channels == channels:
241
+ self.skip_connection = nn.Identity()
242
+ elif use_conv:
243
+ self.skip_connection = nn.Conv1d(
244
+ channels, self.out_channels, kernel_size, padding=padding
245
+ )
246
+ else:
247
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
248
+
249
+ def forward(self, x):
250
+ if self.updown:
251
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
252
+ h = in_rest(x)
253
+ h = self.h_upd(h)
254
+ x = self.x_upd(x)
255
+ h = in_conv(h)
256
+ else:
257
+ h = self.in_layers(x)
258
+ h = self.out_layers(h)
259
+ return self.skip_connection(x) + h
260
+
261
+
262
+ class AudioMiniEncoder(nn.Module):
263
+ def __init__(
264
+ self,
265
+ spec_dim,
266
+ embedding_dim,
267
+ base_channels=128,
268
+ depth=2,
269
+ resnet_blocks=2,
270
+ attn_blocks=4,
271
+ num_attn_heads=4,
272
+ dropout=0,
273
+ downsample_factor=2,
274
+ kernel_size=3,
275
+ ):
276
+ super().__init__()
277
+ self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
278
+ ch = base_channels
279
+ res = []
280
+ for l in range(depth):
281
+ for r in range(resnet_blocks):
282
+ res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
283
+ res.append(
284
+ Downsample(
285
+ ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
286
+ )
287
+ )
288
+ ch *= 2
289
+ self.res = nn.Sequential(*res)
290
+ self.final = nn.Sequential(
291
+ normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
292
+ )
293
+ attn = []
294
+ for a in range(attn_blocks):
295
+ attn.append(
296
+ AttentionBlock(
297
+ embedding_dim,
298
+ num_attn_heads,
299
+ )
300
+ )
301
+ self.attn = nn.Sequential(*attn)
302
+ self.dim = embedding_dim
303
+
304
+ def forward(self, x):
305
+ h = self.init(x)
306
+ h = self.res(h)
307
+ h = self.final(h)
308
+ h = self.attn(h)
309
+ return h[:, :, 0]
310
+
311
+
312
+ DEFAULT_MEL_NORM_FILE = os.path.join(
313
+ os.path.dirname(os.path.realpath(__file__)), "../data/mel_norms.pth"
314
+ )
315
+
316
+
317
+ class TorchMelSpectrogram(nn.Module):
318
+ def __init__(
319
+ self,
320
+ filter_length=1024,
321
+ hop_length=256,
322
+ win_length=1024,
323
+ n_mel_channels=80,
324
+ mel_fmin=0,
325
+ mel_fmax=8000,
326
+ sampling_rate=22050,
327
+ normalize=False,
328
+ mel_norm_file=DEFAULT_MEL_NORM_FILE,
329
+ ):
330
+ super().__init__()
331
+ # These are the default tacotron values for the MEL spectrogram.
332
+ self.filter_length = filter_length
333
+ self.hop_length = hop_length
334
+ self.win_length = win_length
335
+ self.n_mel_channels = n_mel_channels
336
+ self.mel_fmin = mel_fmin
337
+ self.mel_fmax = mel_fmax
338
+ self.sampling_rate = sampling_rate
339
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
340
+ n_fft=self.filter_length,
341
+ hop_length=self.hop_length,
342
+ win_length=self.win_length,
343
+ power=2,
344
+ normalized=normalize,
345
+ sample_rate=self.sampling_rate,
346
+ f_min=self.mel_fmin,
347
+ f_max=self.mel_fmax,
348
+ n_mels=self.n_mel_channels,
349
+ norm="slaney",
350
+ )
351
+ self.mel_norm_file = mel_norm_file
352
+ if self.mel_norm_file is not None:
353
+ self.mel_norms = torch.load(self.mel_norm_file)
354
+ else:
355
+ self.mel_norms = None
356
+
357
+ def forward(self, inp):
358
+ if (
359
+ len(inp.shape) == 3
360
+ ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
361
+ inp = inp.squeeze(1)
362
+ assert len(inp.shape) == 2
363
+ self.mel_stft = self.mel_stft.to(inp.device)
364
+ mel = self.mel_stft(inp)
365
+ # Perform dynamic range compression
366
+ mel = torch.log(torch.clamp(mel, min=1e-5))
367
+ if self.mel_norms is not None:
368
+ self.mel_norms = self.mel_norms.to(mel.device)
369
+ mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
370
+ return mel
371
+
372
+
373
+ class CheckpointedLayer(nn.Module):
374
+ """
375
+ Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
376
+ checkpoint for all other args.
377
+ """
378
+
379
+ def __init__(self, wrap):
380
+ super().__init__()
381
+ self.wrap = wrap
382
+
383
+ def forward(self, x, *args, **kwargs):
384
+ for k, v in kwargs.items():
385
+ assert not (
386
+ isinstance(v, torch.Tensor) and v.requires_grad
387
+ ) # This would screw up checkpointing.
388
+ partial = functools.partial(self.wrap, **kwargs)
389
+ return partial(x, *args)
390
+
391
+
392
+ class CheckpointedXTransformerEncoder(nn.Module):
393
+ """
394
+ Wraps a ContinuousTransformerWrapper, applies CheckpointedLayer to each layer, and permutes from the channels-mid
395
+ layout to the channels-last layout that XTransformer expects.
396
+ """
397
+
398
+ def __init__(
399
+ self,
400
+ needs_permute=True,
401
+ exit_permute=True,
402
+ checkpoint=True,
403
+ **xtransformer_kwargs,
404
+ ):
405
+ super().__init__()
406
+ self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
407
+ self.needs_permute = needs_permute
408
+ self.exit_permute = exit_permute
409
+
410
+ if not checkpoint:
411
+ return
412
+ for i in range(len(self.transformer.attn_layers.layers)):
413
+ n, b, r = self.transformer.attn_layers.layers[i]
414
+ self.transformer.attn_layers.layers[i] = nn.ModuleList(
415
+ [n, CheckpointedLayer(b), r]
416
+ )
417
+
418
+ def forward(self, x, **kwargs):
419
+ if self.needs_permute:
420
+ x = x.permute(0, 2, 1)
421
+ h = self.transformer(x, **kwargs)
422
+ if self.exit_permute:
423
+ h = h.permute(0, 2, 1)
424
+ return h
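A minimal shape-checking sketch of the TorchMelSpectrogram front-end defined above (the waveform is random noise used only to illustrate shapes; the default mel_norm_file, data/mel_norms.pth, is assumed to be present as in the default argument):

    import torch
    from tortoise.models.arch_util import TorchMelSpectrogram

    mel_fn = TorchMelSpectrogram()   # 1024-point STFT, hop 256, 80 mel bins at 22.05 kHz
    wav = torch.randn(1, 1, 22050)   # (batch, channels, samples): one second of fake mono audio
    mel = mel_fn(wav)                # channel dim squeezed, log-compressed, divided by the stored norms
    print(mel.shape)                 # roughly (1, 80, 87) frames for one second of input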
tortoise/models/autoregressive.py ADDED
@@ -0,0 +1,704 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
9
+ from tortoise.models.arch_util import AttentionBlock
10
+ from tortoise.utils.typical_sampling import TypicalLogitsWarper
11
+
12
+
13
+ def null_position_embeddings(range, dim):
14
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
15
+
16
+
17
+ class ResBlock(nn.Module):
18
+ """
19
+ Basic residual convolutional block that uses GroupNorm.
20
+ """
21
+
22
+ def __init__(self, chan):
23
+ super().__init__()
24
+ self.net = nn.Sequential(
25
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
26
+ nn.GroupNorm(chan // 8, chan),
27
+ nn.ReLU(),
28
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
29
+ nn.GroupNorm(chan // 8, chan),
30
+ )
31
+
32
+ def forward(self, x):
33
+ return F.relu(self.net(x) + x)
34
+
35
+
36
+ class GPT2InferenceModel(GPT2PreTrainedModel):
37
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
38
+ super().__init__(config)
39
+ self.transformer = gpt
40
+ self.text_pos_embedding = text_pos_emb
41
+ self.embeddings = embeddings
42
+ self.lm_head = nn.Sequential(norm, linear)
43
+
44
+ # Model parallel
45
+ self.model_parallel = False
46
+ self.device_map = None
47
+ self.cached_mel_emb = None
48
+
49
+ def parallelize(self, device_map=None):
50
+ self.device_map = (
51
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
52
+ if device_map is None
53
+ else device_map
54
+ )
55
+ assert_device_map(self.device_map, len(self.transformer.h))
56
+ self.transformer.parallelize(self.device_map)
57
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
58
+ self.model_parallel = True
59
+
60
+ def deparallelize(self):
61
+ self.transformer.deparallelize()
62
+ self.transformer = self.transformer.to("cpu")
63
+ self.lm_head = self.lm_head.to("cpu")
64
+ self.model_parallel = False
65
+ torch.cuda.empty_cache()
66
+
67
+ def get_output_embeddings(self):
68
+ return self.lm_head
69
+
70
+ def set_output_embeddings(self, new_embeddings):
71
+ self.lm_head = new_embeddings
72
+
73
+ def store_mel_emb(self, mel_emb):
74
+ self.cached_mel_emb = mel_emb
75
+
76
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
77
+
78
+ token_type_ids = kwargs.get("token_type_ids", None)
79
+ # only last token for inputs_ids if past is defined in kwargs
80
+ if past:
81
+ input_ids = input_ids[:, -1].unsqueeze(-1)
82
+ if token_type_ids is not None:
83
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
84
+
85
+ attention_mask = kwargs.get("attention_mask", None)
86
+ position_ids = kwargs.get("position_ids", None)
87
+
88
+ if attention_mask is not None and position_ids is None:
89
+ # create position_ids on the fly for batch generation
90
+ position_ids = attention_mask.long().cumsum(-1) - 1
91
+ position_ids.masked_fill_(attention_mask == 0, 1)
92
+ if past:
93
+ position_ids = position_ids[:, -1].unsqueeze(-1)
94
+ else:
95
+ position_ids = None
96
+ return {
97
+ "input_ids": input_ids,
98
+ "past_key_values": past,
99
+ "use_cache": kwargs.get("use_cache"),
100
+ "position_ids": position_ids,
101
+ "attention_mask": attention_mask,
102
+ "token_type_ids": token_type_ids,
103
+ }
104
+
105
+ def forward(
106
+ self,
107
+ input_ids=None,
108
+ past_key_values=None,
109
+ attention_mask=None,
110
+ token_type_ids=None,
111
+ position_ids=None,
112
+ head_mask=None,
113
+ inputs_embeds=None,
114
+ encoder_hidden_states=None,
115
+ encoder_attention_mask=None,
116
+ labels=None,
117
+ use_cache=None,
118
+ output_attentions=None,
119
+ output_hidden_states=None,
120
+ return_dict=None,
121
+ ):
122
+ assert self.cached_mel_emb is not None
123
+ assert inputs_embeds is None # Not supported by this inference model.
124
+ assert labels is None # Training not supported by this inference model.
125
+ return_dict = (
126
+ return_dict if return_dict is not None else self.config.use_return_dict
127
+ )
128
+
129
+ # Create embedding
130
+ mel_len = self.cached_mel_emb.shape[1]
131
+ if input_ids.shape[1] != 1:
132
+ text_inputs = input_ids[:, mel_len:]
133
+ text_emb = self.embeddings(text_inputs)
134
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
135
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
136
+ mel_emb = self.cached_mel_emb.repeat_interleave(
137
+ text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
138
+ )
139
+ else:
140
+ mel_emb = self.cached_mel_emb
141
+ emb = torch.cat([mel_emb, text_emb], dim=1)
142
+ else:
143
+ emb = self.embeddings(input_ids)
144
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(
145
+ attention_mask.shape[1] - mel_len, attention_mask.device
146
+ )
147
+
148
+ transformer_outputs = self.transformer(
149
+ inputs_embeds=emb,
150
+ past_key_values=past_key_values,
151
+ attention_mask=attention_mask,
152
+ token_type_ids=token_type_ids,
153
+ position_ids=position_ids,
154
+ head_mask=head_mask,
155
+ encoder_hidden_states=encoder_hidden_states,
156
+ encoder_attention_mask=encoder_attention_mask,
157
+ use_cache=use_cache,
158
+ output_attentions=output_attentions,
159
+ output_hidden_states=output_hidden_states,
160
+ return_dict=return_dict,
161
+ )
162
+ hidden_states = transformer_outputs[0]
163
+
164
+ # Set device for model parallelism
165
+ if self.model_parallel:
166
+ torch.cuda.set_device(self.transformer.first_device)
167
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
168
+
169
+ lm_logits = self.lm_head(hidden_states)
170
+
171
+ if not return_dict:
172
+ return (lm_logits,) + transformer_outputs[1:]
173
+
174
+ return CausalLMOutputWithCrossAttentions(
175
+ loss=None,
176
+ logits=lm_logits,
177
+ past_key_values=transformer_outputs.past_key_values,
178
+ hidden_states=transformer_outputs.hidden_states,
179
+ attentions=transformer_outputs.attentions,
180
+ cross_attentions=transformer_outputs.cross_attentions,
181
+ )
182
+
183
+ @staticmethod
184
+ def _reorder_cache(past, beam_idx):
185
+ """
186
+ This function is used to re-order the :obj:`past_key_values` cache if
187
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
188
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
189
+ """
190
+ return tuple(
191
+ tuple(
192
+ past_state.index_select(0, beam_idx.to(past_state.device))
193
+ for past_state in layer_past
194
+ )
195
+ for layer_past in past
196
+ )
197
+
198
+
199
+ class ConditioningEncoder(nn.Module):
200
+ def __init__(
201
+ self,
202
+ spec_dim,
203
+ embedding_dim,
204
+ attn_blocks=6,
205
+ num_attn_heads=4,
206
+ do_checkpointing=False,
207
+ mean=False,
208
+ ):
209
+ super().__init__()
210
+ attn = []
211
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
212
+ for a in range(attn_blocks):
213
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
214
+ self.attn = nn.Sequential(*attn)
215
+ self.dim = embedding_dim
216
+ self.do_checkpointing = do_checkpointing
217
+ self.mean = mean
218
+
219
+ def forward(self, x):
220
+ h = self.init(x)
221
+ h = self.attn(h)
222
+ if self.mean:
223
+ return h.mean(dim=2)
224
+ else:
225
+ return h[:, :, 0]
226
+
227
+
228
+ class LearnedPositionEmbeddings(nn.Module):
229
+ def __init__(self, seq_len, model_dim, init=0.02):
230
+ super().__init__()
231
+ self.emb = nn.Embedding(seq_len, model_dim)
232
+ # Initializing this way is standard for GPT-2
233
+ self.emb.weight.data.normal_(mean=0.0, std=init)
234
+
235
+ def forward(self, x):
236
+ sl = x.shape[1]
237
+ return self.emb(torch.arange(0, sl, device=x.device))
238
+
239
+ def get_fixed_embedding(self, ind, dev):
240
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
241
+
242
+
243
+ def build_hf_gpt_transformer(
244
+ layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
245
+ ):
246
+ """
247
+ GPT-2 implemented by the HuggingFace library.
248
+ """
249
+ from transformers import GPT2Config, GPT2Model
250
+
251
+ gpt_config = GPT2Config(
252
+ vocab_size=256, # Unused.
253
+ n_positions=max_mel_seq_len + max_text_seq_len,
254
+ n_ctx=max_mel_seq_len + max_text_seq_len,
255
+ n_embd=model_dim,
256
+ n_layer=layers,
257
+ n_head=heads,
258
+ gradient_checkpointing=checkpointing,
259
+ use_cache=not checkpointing,
260
+ )
261
+ gpt = GPT2Model(gpt_config)
262
+ # Override the built in positional embeddings
263
+ del gpt.wpe
264
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
265
+ # Built-in token embeddings are unused.
266
+ del gpt.wte
267
+ return (
268
+ gpt,
269
+ LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
270
+ LearnedPositionEmbeddings(max_text_seq_len, model_dim),
271
+ None,
272
+ None,
273
+ )
274
+
275
+
276
+ class MelEncoder(nn.Module):
277
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
278
+ super().__init__()
279
+ self.channels = channels
280
+ self.encoder = nn.Sequential(
281
+ nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
282
+ nn.Sequential(
283
+ *[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
284
+ ),
285
+ nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
286
+ nn.GroupNorm(channels // 16, channels // 2),
287
+ nn.ReLU(),
288
+ nn.Sequential(
289
+ *[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
290
+ ),
291
+ nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
292
+ nn.GroupNorm(channels // 8, channels),
293
+ nn.ReLU(),
294
+ nn.Sequential(
295
+ *[ResBlock(channels) for _ in range(resblocks_per_reduction)]
296
+ ),
297
+ )
298
+ self.reduction = 4
299
+
300
+ def forward(self, x):
301
+ for e in self.encoder:
302
+ x = e(x)
303
+ return x.permute(0, 2, 1)
304
+
305
+
306
+ class UnifiedVoice(nn.Module):
307
+ def __init__(
308
+ self,
309
+ layers=8,
310
+ model_dim=512,
311
+ heads=8,
312
+ max_text_tokens=120,
313
+ max_mel_tokens=250,
314
+ max_conditioning_inputs=1,
315
+ mel_length_compression=1024,
316
+ number_text_tokens=256,
317
+ start_text_token=None,
318
+ number_mel_codes=8194,
319
+ start_mel_token=8192,
320
+ stop_mel_token=8193,
321
+ train_solo_embeddings=False,
322
+ use_mel_codes_as_input=True,
323
+ checkpointing=True,
324
+ types=1,
325
+ ):
326
+ """
327
+ Args:
328
+ layers: Number of layers in transformer stack.
329
+ model_dim: Operating dimensions of the transformer
330
+ heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim//64
331
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
332
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
333
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
334
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
335
+ number_text_tokens:
336
+ start_text_token:
337
+ stop_text_token:
338
+ number_mel_codes:
339
+ start_mel_token:
340
+ stop_mel_token:
341
+ train_solo_embeddings:
342
+ use_mel_codes_as_input:
343
+ checkpointing:
344
+ """
345
+ super().__init__()
346
+
347
+ self.number_text_tokens = number_text_tokens
348
+ self.start_text_token = (
349
+ number_text_tokens * types if start_text_token is None else start_text_token
350
+ )
351
+ self.stop_text_token = 0
352
+ self.number_mel_codes = number_mel_codes
353
+ self.start_mel_token = start_mel_token
354
+ self.stop_mel_token = stop_mel_token
355
+ self.layers = layers
356
+ self.heads = heads
357
+ self.max_mel_tokens = max_mel_tokens
358
+ self.max_text_tokens = max_text_tokens
359
+ self.model_dim = model_dim
360
+ self.max_conditioning_inputs = max_conditioning_inputs
361
+ self.mel_length_compression = mel_length_compression
362
+ self.conditioning_encoder = ConditioningEncoder(
363
+ 80, model_dim, num_attn_heads=heads
364
+ )
365
+ self.text_embedding = nn.Embedding(
366
+ self.number_text_tokens * types + 1, model_dim
367
+ )
368
+ if use_mel_codes_as_input:
369
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
370
+ else:
371
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
372
+ (
373
+ self.gpt,
374
+ self.mel_pos_embedding,
375
+ self.text_pos_embedding,
376
+ self.mel_layer_pos_embedding,
377
+ self.text_layer_pos_embedding,
378
+ ) = build_hf_gpt_transformer(
379
+ layers,
380
+ model_dim,
381
+ heads,
382
+ self.max_mel_tokens + 2 + self.max_conditioning_inputs,
383
+ self.max_text_tokens + 2,
384
+ checkpointing,
385
+ )
386
+ if train_solo_embeddings:
387
+ self.mel_solo_embedding = nn.Parameter(
388
+ torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
389
+ )
390
+ self.text_solo_embedding = nn.Parameter(
391
+ torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
392
+ )
393
+ else:
394
+ self.mel_solo_embedding = 0
395
+ self.text_solo_embedding = 0
396
+
397
+ self.final_norm = nn.LayerNorm(model_dim)
398
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
399
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
400
+
401
+ # Initialize the embeddings per the GPT-2 scheme
402
+ embeddings = [self.text_embedding]
403
+ if use_mel_codes_as_input:
404
+ embeddings.append(self.mel_embedding)
405
+ for module in embeddings:
406
+ module.weight.data.normal_(mean=0.0, std=0.02)
407
+
408
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
409
+ inp = F.pad(input, (1, 0), value=start_token)
410
+ tar = F.pad(input, (0, 1), value=stop_token)
411
+ return inp, tar
412
+
413
+ def set_mel_padding(self, mel_input_tokens, wav_lengths):
414
+ """
415
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
416
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
417
+ preformatting to create a working TTS model.
418
+ """
419
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
420
+ mel_lengths = torch.div(
421
+ wav_lengths, self.mel_length_compression, rounding_mode="trunc"
422
+ )
423
+ for b in range(len(mel_lengths)):
424
+ actual_end = (
425
+ mel_lengths[b] + 1
426
+ ) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
427
+ if actual_end < mel_input_tokens.shape[-1]:
428
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
429
+ return mel_input_tokens
430
+
431
+ def get_logits(
432
+ self,
433
+ speech_conditioning_inputs,
434
+ first_inputs,
435
+ first_head,
436
+ second_inputs=None,
437
+ second_head=None,
438
+ get_attns=False,
439
+ return_latent=False,
440
+ ):
441
+ if second_inputs is not None:
442
+ emb = torch.cat(
443
+ [speech_conditioning_inputs, first_inputs, second_inputs], dim=1
444
+ )
445
+ else:
446
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
447
+
448
+ gpt_out = self.gpt(
449
+ inputs_embeds=emb, return_dict=True, output_attentions=get_attns
450
+ )
451
+ if get_attns:
452
+ return gpt_out.attentions
453
+
454
+ enc = gpt_out.last_hidden_state[
455
+ :, 1:
456
+ ] # The first logit is tied to the speech_conditioning_input
457
+ enc = self.final_norm(enc)
458
+
459
+ if return_latent:
460
+ return (
461
+ enc[
462
+ :,
463
+ speech_conditioning_inputs.shape[
464
+ 1
465
+ ] : speech_conditioning_inputs.shape[1]
466
+ + first_inputs.shape[1],
467
+ ],
468
+ enc[:, -second_inputs.shape[1] :],
469
+ )
470
+
471
+ first_logits = enc[:, : first_inputs.shape[1]]
472
+ first_logits = first_head(first_logits)
473
+ first_logits = first_logits.permute(0, 2, 1)
474
+ if second_inputs is not None:
475
+ second_logits = enc[:, -second_inputs.shape[1] :]
476
+ second_logits = second_head(second_logits)
477
+ second_logits = second_logits.permute(0, 2, 1)
478
+ return first_logits, second_logits
479
+ else:
480
+ return first_logits
481
+
482
+ def get_conditioning(self, speech_conditioning_input):
483
+ speech_conditioning_input = (
484
+ speech_conditioning_input.unsqueeze(1)
485
+ if len(speech_conditioning_input.shape) == 3
486
+ else speech_conditioning_input
487
+ )
488
+ conds = []
489
+ for j in range(speech_conditioning_input.shape[1]):
490
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
491
+ conds = torch.stack(conds, dim=1)
492
+ conds = conds.mean(dim=1)
493
+ return conds
494
+
495
+ def forward(
496
+ self,
497
+ speech_conditioning_latent,
498
+ text_inputs,
499
+ text_lengths,
500
+ mel_codes,
501
+ wav_lengths,
502
+ types=None,
503
+ text_first=True,
504
+ raw_mels=None,
505
+ return_attentions=False,
506
+ return_latent=False,
507
+ clip_inputs=True,
508
+ ):
509
+ """
510
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
511
+ (actuated by `text_first`).
512
+
513
+ speech_conditioning_input: MEL float tensor, (b,1024)
514
+ text_inputs: long tensor, (b,t)
515
+ text_lengths: long tensor, (b,)
516
+ mel_inputs: long tensor, (b,m)
517
+ wav_lengths: long tensor, (b,)
518
+ raw_mels: MEL float tensor (b,80,s)
519
+
520
+ If return_attentions is specified, only logits are returned.
521
+ If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
522
+ If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
523
+ """
524
+ # Types are expressed by expanding the text embedding space.
525
+ if types is not None:
526
+ text_inputs = text_inputs * (1 + types).unsqueeze(-1)
527
+
528
+ if clip_inputs:
529
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
530
+ # chopping the inputs by the maximum actual length.
531
+ max_text_len = text_lengths.max()
532
+ text_inputs = text_inputs[:, :max_text_len]
533
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
534
+ mel_codes = mel_codes[:, :max_mel_len]
535
+ if raw_mels is not None:
536
+ raw_mels = raw_mels[:, :, : max_mel_len * 4]
537
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
538
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
539
+ mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
540
+
541
+ conds = speech_conditioning_latent.unsqueeze(1)
542
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(
543
+ text_inputs, self.start_text_token, self.stop_text_token
544
+ )
545
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
546
+ text_inputs
547
+ )
548
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
549
+ mel_codes, self.start_mel_token, self.stop_mel_token
550
+ )
551
+ if raw_mels is not None:
552
+ mel_inp = F.pad(raw_mels, (0, 8))
553
+ else:
554
+ mel_inp = mel_codes
555
+ mel_emb = self.mel_embedding(mel_inp)
556
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
557
+
558
+ if text_first:
559
+ text_logits, mel_logits = self.get_logits(
560
+ conds,
561
+ text_emb,
562
+ self.text_head,
563
+ mel_emb,
564
+ self.mel_head,
565
+ get_attns=return_attentions,
566
+ return_latent=return_latent,
567
+ )
568
+ if return_latent:
569
+ return mel_logits[
570
+ :, :-2
571
+ ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
572
+ else:
573
+ mel_logits, text_logits = self.get_logits(
574
+ conds,
575
+ mel_emb,
576
+ self.mel_head,
577
+ text_emb,
578
+ self.text_head,
579
+ get_attns=return_attentions,
580
+ return_latent=return_latent,
581
+ )
582
+ if return_latent:
583
+ return text_logits[
584
+ :, :-2
585
+ ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
586
+
587
+ if return_attentions:
588
+ return mel_logits
589
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
590
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
591
+ return loss_text.mean(), loss_mel.mean(), mel_logits
592
+
593
+ def inference_speech(
594
+ self,
595
+ speech_conditioning_latent,
596
+ text_inputs,
597
+ input_tokens=None,
598
+ num_return_sequences=1,
599
+ max_generate_length=None,
600
+ typical_sampling=False,
601
+ typical_mass=0.9,
602
+ **hf_generate_kwargs
603
+ ):
604
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
605
+ if not hasattr(self, "inference_model"):
606
+ # TODO: Decouple gpt_config from this inference model.
607
+ gpt_config = GPT2Config(
608
+ vocab_size=self.max_mel_tokens,
609
+ n_positions=seq_length,
610
+ n_ctx=seq_length,
611
+ n_embd=self.model_dim,
612
+ n_layer=self.layers,
613
+ n_head=self.heads,
614
+ gradient_checkpointing=False,
615
+ use_cache=True,
616
+ )
617
+ self.inference_model = GPT2InferenceModel(
618
+ gpt_config,
619
+ self.gpt,
620
+ self.mel_pos_embedding,
621
+ self.mel_embedding,
622
+ self.final_norm,
623
+ self.mel_head,
624
+ )
625
+ self.gpt.wte = self.mel_embedding
626
+
627
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
628
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(
629
+ text_inputs, self.start_text_token, self.stop_text_token
630
+ )
631
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
632
+ text_inputs
633
+ )
634
+
635
+ conds = speech_conditioning_latent.unsqueeze(1)
636
+ emb = torch.cat([conds, text_emb], dim=1)
637
+ self.inference_model.store_mel_emb(emb)
638
+
639
+ fake_inputs = torch.full(
640
+ (
641
+ emb.shape[0],
642
+ conds.shape[1] + emb.shape[1],
643
+ ),
644
+ fill_value=1,
645
+ dtype=torch.long,
646
+ device=text_inputs.device,
647
+ )
648
+ fake_inputs[:, -1] = self.start_mel_token
649
+ trunc_index = fake_inputs.shape[1]
650
+ if input_tokens is None:
651
+ inputs = fake_inputs
652
+ else:
653
+ assert (
654
+ num_return_sequences % input_tokens.shape[0] == 0
655
+ ), "The number of return sequences must be divisible by the number of input sequences"
656
+ fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
657
+ input_tokens = input_tokens.repeat(
658
+ num_return_sequences // input_tokens.shape[0], 1
659
+ )
660
+ inputs = torch.cat([fake_inputs, input_tokens], dim=1)
661
+
662
+ logits_processor = (
663
+ LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
664
+ if typical_sampling
665
+ else LogitsProcessorList()
666
+ )
667
+ max_length = (
668
+ trunc_index + self.max_mel_tokens - 1
669
+ if max_generate_length is None
670
+ else trunc_index + max_generate_length
671
+ )
672
+ gen = self.inference_model.generate(
673
+ inputs,
674
+ bos_token_id=self.start_mel_token,
675
+ pad_token_id=self.stop_mel_token,
676
+ eos_token_id=self.stop_mel_token,
677
+ max_length=max_length,
678
+ logits_processor=logits_processor,
679
+ num_return_sequences=num_return_sequences,
680
+ **hf_generate_kwargs
681
+ )
682
+ return gen[:, trunc_index:]
683
+
684
+
685
+ if __name__ == "__main__":
686
+ gpt = UnifiedVoice(
687
+ model_dim=256,
688
+ heads=4,
689
+ train_solo_embeddings=True,
690
+ use_mel_codes_as_input=True,
691
+ max_conditioning_inputs=4,
692
+ )
693
+ l = gpt(
694
+ torch.randn(2, 3, 80, 800),
695
+ torch.randint(high=120, size=(2, 120)),
696
+ torch.tensor([32, 120]),
697
+ torch.randint(high=8192, size=(2, 250)),
698
+ torch.tensor([250 * 256, 195 * 256]),
699
+ )
700
+ gpt.text_forward(
701
+ torch.randn(2, 80, 800),
702
+ torch.randint(high=50, size=(2, 80)),
703
+ torch.tensor([32, 80]),
704
+ )
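To make the token bookkeeping of build_aligned_inputs_and_targets above concrete: it brackets a code sequence with the start and stop tokens so that model inputs and next-token targets stay aligned. A tiny worked example using the default mel start/stop ids (the code values 5, 6, 7 are arbitrary):

    import torch
    import torch.nn.functional as F

    codes = torch.tensor([[5, 6, 7]])
    inp = F.pad(codes, (1, 0), value=8192)  # start_mel_token prepended -> tensor([[8192, 5, 6, 7]])
    tar = F.pad(codes, (0, 1), value=8193)  # stop_mel_token appended   -> tensor([[5, 6, 7, 8193]])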
tortoise/models/classifier.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from tortoise.models.arch_util import (
5
+ Upsample,
6
+ Downsample,
7
+ normalization,
8
+ zero_module,
9
+ AttentionBlock,
10
+ )
11
+
12
+
13
+ class ResBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ channels,
17
+ dropout,
18
+ out_channels=None,
19
+ use_conv=False,
20
+ use_scale_shift_norm=False,
21
+ dims=2,
22
+ up=False,
23
+ down=False,
24
+ kernel_size=3,
25
+ do_checkpoint=True,
26
+ ):
27
+ super().__init__()
28
+ self.channels = channels
29
+ self.dropout = dropout
30
+ self.out_channels = out_channels or channels
31
+ self.use_conv = use_conv
32
+ self.use_scale_shift_norm = use_scale_shift_norm
33
+ self.do_checkpoint = do_checkpoint
34
+ padding = 1 if kernel_size == 3 else 2
35
+
36
+ self.in_layers = nn.Sequential(
37
+ normalization(channels),
38
+ nn.SiLU(),
39
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
40
+ )
41
+
42
+ self.updown = up or down
43
+
44
+ if up:
45
+ self.h_upd = Upsample(channels, False)
46
+ self.x_upd = Upsample(channels, False)
47
+ elif down:
48
+ self.h_upd = Downsample(channels, False)
49
+ self.x_upd = Downsample(channels, False)
50
+ else:
51
+ self.h_upd = self.x_upd = nn.Identity()
52
+
53
+ self.out_layers = nn.Sequential(
54
+ normalization(self.out_channels),
55
+ nn.SiLU(),
56
+ nn.Dropout(p=dropout),
57
+ zero_module(
58
+ nn.Conv1d(
59
+ self.out_channels, self.out_channels, kernel_size, padding=padding
60
+ )
61
+ ),
62
+ )
63
+
64
+ if self.out_channels == channels:
65
+ self.skip_connection = nn.Identity()
66
+ elif use_conv:
67
+ self.skip_connection = nn.Conv1d(
68
+ channels, self.out_channels, kernel_size, padding=padding
69
+ )
70
+ else:
71
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
72
+
73
+ def forward(self, x):
74
+ if self.updown:
75
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
76
+ h = in_rest(x)
77
+ h = self.h_upd(h)
78
+ x = self.x_upd(x)
79
+ h = in_conv(h)
80
+ else:
81
+ h = self.in_layers(x)
82
+ h = self.out_layers(h)
83
+ return self.skip_connection(x) + h
84
+
85
+
86
+ class AudioMiniEncoder(nn.Module):
87
+ def __init__(
88
+ self,
89
+ spec_dim,
90
+ embedding_dim,
91
+ base_channels=128,
92
+ depth=2,
93
+ resnet_blocks=2,
94
+ attn_blocks=4,
95
+ num_attn_heads=4,
96
+ dropout=0,
97
+ downsample_factor=2,
98
+ kernel_size=3,
99
+ ):
100
+ super().__init__()
101
+ self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
102
+ ch = base_channels
103
+ res = []
104
+ self.layers = depth
105
+ for l in range(depth):
106
+ for r in range(resnet_blocks):
107
+ res.append(
108
+ ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
109
+ )
110
+ res.append(
111
+ Downsample(
112
+ ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
113
+ )
114
+ )
115
+ ch *= 2
116
+ self.res = nn.Sequential(*res)
117
+ self.final = nn.Sequential(
118
+ normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
119
+ )
120
+ attn = []
121
+ for a in range(attn_blocks):
122
+ attn.append(
123
+ AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
124
+ )
125
+ self.attn = nn.Sequential(*attn)
126
+ self.dim = embedding_dim
127
+
128
+ def forward(self, x):
129
+ h = self.init(x)
130
+ h = self.res(h)
131
+ h = self.final(h)
132
+ for blk in self.attn:
133
+ h = blk(h)
134
+ return h[:, :, 0]
135
+
136
+
137
+ class AudioMiniEncoderWithClassifierHead(nn.Module):
138
+ def __init__(self, classes, distribute_zero_label=True, **kwargs):
139
+ super().__init__()
140
+ self.enc = AudioMiniEncoder(**kwargs)
141
+ self.head = nn.Linear(self.enc.dim, classes)
142
+ self.num_classes = classes
143
+ self.distribute_zero_label = distribute_zero_label
144
+
145
+ def forward(self, x, labels=None):
146
+ h = self.enc(x)
147
+ logits = self.head(h)
148
+ if labels is None:
149
+ return logits
150
+ else:
151
+ if self.distribute_zero_label:
152
+ oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
153
+ zeros_indices = (labels == 0).unsqueeze(-1)
154
+ # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
155
+ zero_extra_mass = torch.full_like(
156
+ oh_labels,
157
+ dtype=torch.float,
158
+ fill_value=0.2 / (self.num_classes - 1),
159
+ )
160
+ zero_extra_mass[:, 0] = -0.2
161
+ zero_extra_mass = zero_extra_mass * zeros_indices
162
+ oh_labels = oh_labels + zero_extra_mass
163
+ else:
164
+ oh_labels = labels
165
+ loss = nn.functional.cross_entropy(logits, oh_labels)
166
+ return loss
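A minimal smoke-test sketch of the classifier defined above (every dimension here is illustrative, not necessarily what the pretrained Tortoise classifier uses):

    import torch
    from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead

    clf = AudioMiniEncoderWithClassifierHead(
        classes=2, spec_dim=80, embedding_dim=128, base_channels=32,
        depth=2, resnet_blocks=1, attn_blocks=2, num_attn_heads=2,
    )
    logits = clf(torch.randn(1, 80, 400))  # no labels passed -> raw class logits, shape (1, 2)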
tortoise/models/clvp.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+
6
+ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
7
+ from tortoise.models.transformer import Transformer
8
+ from tortoise.models.xtransformers import Encoder
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def masked_mean(t, mask, dim=1):
16
+ t = t.masked_fill(~mask[:, :, None], 0.0)
17
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
18
+
19
+
20
+ class CLVP(nn.Module):
21
+ """
22
+ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
23
+ transcribed text.
24
+
25
+ Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ *,
31
+ dim_text=512,
32
+ dim_speech=512,
33
+ dim_latent=512,
34
+ num_text_tokens=256,
35
+ text_enc_depth=6,
36
+ text_seq_len=120,
37
+ text_heads=8,
38
+ num_speech_tokens=8192,
39
+ speech_enc_depth=6,
40
+ speech_heads=8,
41
+ speech_seq_len=250,
42
+ text_mask_percentage=0,
43
+ voice_mask_percentage=0,
44
+ wav_token_compression=1024,
45
+ use_xformers=False,
46
+ ):
47
+ super().__init__()
48
+ self.text_emb = nn.Embedding(num_text_tokens, dim_text)
49
+ self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
50
+
51
+ self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
52
+ self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
53
+
54
+ if use_xformers:
55
+ self.text_transformer = CheckpointedXTransformerEncoder(
56
+ needs_permute=False,
57
+ exit_permute=False,
58
+ max_seq_len=-1,
59
+ attn_layers=Encoder(
60
+ dim=dim_text,
61
+ depth=text_enc_depth,
62
+ heads=text_heads,
63
+ ff_dropout=0.1,
64
+ ff_mult=2,
65
+ attn_dropout=0.1,
66
+ use_rmsnorm=True,
67
+ ff_glu=True,
68
+ rotary_pos_emb=True,
69
+ ),
70
+ )
71
+ self.speech_transformer = CheckpointedXTransformerEncoder(
72
+ needs_permute=False,
73
+ exit_permute=False,
74
+ max_seq_len=-1,
75
+ attn_layers=Encoder(
76
+ dim=dim_speech,
77
+ depth=speech_enc_depth,
78
+ heads=speech_heads,
79
+ ff_dropout=0.1,
80
+ ff_mult=2,
81
+ attn_dropout=0.1,
82
+ use_rmsnorm=True,
83
+ ff_glu=True,
84
+ rotary_pos_emb=True,
85
+ ),
86
+ )
87
+ else:
88
+ self.text_transformer = Transformer(
89
+ causal=False,
90
+ seq_len=text_seq_len,
91
+ dim=dim_text,
92
+ depth=text_enc_depth,
93
+ heads=text_heads,
94
+ )
95
+ self.speech_transformer = Transformer(
96
+ causal=False,
97
+ seq_len=speech_seq_len,
98
+ dim=dim_speech,
99
+ depth=speech_enc_depth,
100
+ heads=speech_heads,
101
+ )
102
+
103
+ self.temperature = nn.Parameter(torch.tensor(1.0))
104
+ self.text_mask_percentage = text_mask_percentage
105
+ self.voice_mask_percentage = voice_mask_percentage
106
+ self.wav_token_compression = wav_token_compression
107
+ self.xformers = use_xformers
108
+ if not use_xformers:
109
+ self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
110
+ self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
111
+
112
+ def forward(self, text, speech_tokens, return_loss=False):
113
+ b, device = text.shape[0], text.device
114
+ if self.training:
115
+ text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
116
+ voice_mask = (
117
+ torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
118
+ )
119
+ else:
120
+ text_mask = torch.ones_like(text.float()).bool()
121
+ voice_mask = torch.ones_like(speech_tokens.float()).bool()
122
+
123
+ text_emb = self.text_emb(text)
124
+ speech_emb = self.speech_emb(speech_tokens)
125
+
126
+ if not self.xformers:
127
+ text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
128
+ speech_emb += self.speech_pos_emb(
129
+ torch.arange(speech_emb.shape[1], device=device)
130
+ )
131
+
132
+ enc_text = self.text_transformer(text_emb, mask=text_mask)
133
+ enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
134
+
135
+ text_latents = masked_mean(enc_text, text_mask, dim=1)
136
+ speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
137
+
138
+ text_latents = self.to_text_latent(text_latents)
139
+ speech_latents = self.to_speech_latent(speech_latents)
140
+
141
+ text_latents, speech_latents = map(
142
+ lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)
143
+ )
144
+
145
+ temp = self.temperature.exp()
146
+
147
+ if not return_loss:
148
+ sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
149
+ return sim
150
+
151
+ sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
152
+ labels = torch.arange(b, device=device)
153
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
154
+ return loss
155
+
156
+
157
+ if __name__ == "__main__":
158
+ clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
159
+ # forward() takes (text, speech_tokens, return_loss=...)
+ clip(
+ torch.randint(0, 256, (2, 120)),
+ torch.randint(0, 8192, (2, 250)),
+ return_loss=True,
+ )
+ nonloss = clip(
+ torch.randint(0, 256, (2, 120)),
+ torch.randint(0, 8192, (2, 250)),
+ return_loss=False,
+ )
173
+ print(nonloss.shape)
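For readers less familiar with CLIP-style objectives, the loss computed in CLVP.forward can be reproduced on toy latents in a few lines. A minimal sketch assuming only PyTorch; the batch size, width, and temperature are arbitrary stand-ins for the projected text/speech latents above:

    import torch
    import torch.nn.functional as F
    from torch import einsum

    b, d = 4, 512
    text_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
    speech_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
    temp = torch.tensor(1.0).exp()

    sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp  # matched pairs sit on the diagonal
    labels = torch.arange(b)
    loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
    print(loss.item())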
tortoise/models/cvvp.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+
6
+ from tortoise.models.arch_util import AttentionBlock
7
+ from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
8
+
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+
14
+ def masked_mean(t, mask):
15
+ t = t.masked_fill(~mask, 0.0)
16
+ return t.sum(dim=1) / mask.sum(dim=1)
17
+
18
+
19
+ class CollapsingTransformer(nn.Module):
20
+ def __init__(
21
+ self,
22
+ model_dim,
23
+ output_dims,
24
+ heads,
25
+ dropout,
26
+ depth,
27
+ mask_percentage=0,
28
+ **encoder_kwargs
29
+ ):
30
+ super().__init__()
31
+ self.transformer = ContinuousTransformerWrapper(
32
+ max_seq_len=-1,
33
+ use_pos_emb=False,
34
+ attn_layers=Encoder(
35
+ dim=model_dim,
36
+ depth=depth,
37
+ heads=heads,
38
+ ff_dropout=dropout,
39
+ ff_mult=1,
40
+ attn_dropout=dropout,
41
+ use_rmsnorm=True,
42
+ ff_glu=True,
43
+ rotary_pos_emb=True,
44
+ **encoder_kwargs,
45
+ ),
46
+ )
47
+ self.pre_combiner = nn.Sequential(
48
+ nn.Conv1d(model_dim, output_dims, 1),
49
+ AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
50
+ nn.Conv1d(output_dims, output_dims, 1),
51
+ )
52
+ self.mask_percentage = mask_percentage
53
+
54
+ def forward(self, x, **transformer_kwargs):
55
+ h = self.transformer(x, **transformer_kwargs)
56
+ h = h.permute(0, 2, 1)
57
+ h = self.pre_combiner(h).permute(0, 2, 1)
58
+ if self.training:
59
+ mask = torch.rand_like(h.float()) > self.mask_percentage
60
+ else:
61
+ mask = torch.ones_like(h.float()).bool()
62
+ return masked_mean(h, mask)
63
+
64
+
65
+ class ConvFormatEmbedding(nn.Module):
66
+ def __init__(self, *args, **kwargs):
67
+ super().__init__()
68
+ self.emb = nn.Embedding(*args, **kwargs)
69
+
70
+ def forward(self, x):
71
+ y = self.emb(x)
72
+ return y.permute(0, 2, 1)
73
+
74
+
75
+ class CVVP(nn.Module):
76
+ def __init__(
77
+ self,
78
+ model_dim=512,
79
+ transformer_heads=8,
80
+ dropout=0.1,
81
+ conditioning_enc_depth=8,
82
+ cond_mask_percentage=0,
83
+ mel_channels=80,
84
+ mel_codes=None,
85
+ speech_enc_depth=8,
86
+ speech_mask_percentage=0,
87
+ latent_multiplier=1,
88
+ ):
89
+ super().__init__()
90
+ latent_dim = latent_multiplier * model_dim
91
+ self.temperature = nn.Parameter(torch.tensor(1.0))
92
+
93
+ self.cond_emb = nn.Sequential(
94
+ nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
95
+ nn.Conv1d(model_dim // 2, model_dim, kernel_size=3, stride=2, padding=1),
96
+ )
97
+ self.conditioning_transformer = CollapsingTransformer(
98
+ model_dim,
99
+ model_dim,
100
+ transformer_heads,
101
+ dropout,
102
+ conditioning_enc_depth,
103
+ cond_mask_percentage,
104
+ )
105
+ self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
106
+
107
+ if mel_codes is None:
108
+ self.speech_emb = nn.Conv1d(
109
+ mel_channels, model_dim, kernel_size=5, padding=2
110
+ )
111
+ else:
112
+ self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
113
+ self.speech_transformer = CollapsingTransformer(
114
+ model_dim,
115
+ latent_dim,
116
+ transformer_heads,
117
+ dropout,
118
+ speech_enc_depth,
119
+ speech_mask_percentage,
120
+ )
121
+ self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
122
+
123
+ def get_grad_norm_parameter_groups(self):
124
+ return {
125
+ "conditioning": list(self.conditioning_transformer.parameters()),
126
+ "speech": list(self.speech_transformer.parameters()),
127
+ }
128
+
129
+ def forward(self, mel_cond, mel_input, return_loss=False):
130
+ cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
131
+ enc_cond = self.conditioning_transformer(cond_emb)
132
+ cond_latents = self.to_conditioning_latent(enc_cond)
133
+
134
+ speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
135
+ enc_speech = self.speech_transformer(speech_emb)
136
+ speech_latents = self.to_speech_latent(enc_speech)
137
+
138
+ cond_latents, speech_latents = map(
139
+ lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
140
+ )
141
+ temp = self.temperature.exp()
142
+
143
+ if not return_loss:
144
+ sim = einsum("n d, n d -> n", cond_latents, speech_latents) * temp
145
+ return sim
146
+
147
+ sim = einsum("i d, j d -> i j", cond_latents, speech_latents) * temp
148
+ labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
149
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
150
+
151
+ return loss
152
+
153
+
154
+ if __name__ == "__main__":
155
+ clvp = CVVP()
156
+ clvp(torch.randn(2, 80, 100), torch.randn(2, 80, 95), return_loss=True)
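At inference time CVVP is used as a scorer rather than a loss: one reference-voice mel clip against one candidate-generation mel clip, returning a per-pair similarity. A small sketch, assuming the module is importable under this commit's layout; the clip lengths are arbitrary:

    import torch
    from tortoise.models.cvvp import CVVP

    model = CVVP()
    model.eval()
    with torch.no_grad():
        mel_cond = torch.randn(1, 80, 100)   # reference-voice mel
        mel_input = torch.randn(1, 80, 95)   # candidate-generation mel
        score = model(mel_cond, mel_input, return_loss=False)
    print(score.shape)                       # torch.Size([1]): one similarity per pair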
tortoise/models/diffusion_decoder.py ADDED
@@ -0,0 +1,445 @@
1
+ import math
2
+ import random
3
+ from abc import abstractmethod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import autocast
9
+
10
+ from tortoise.models.arch_util import normalization, AttentionBlock
11
+
12
+
13
+ def is_latent(t):
14
+ return t.dtype == torch.float
15
+
16
+
17
+ def is_sequence(t):
18
+ return t.dtype == torch.long
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period)
34
+ * torch.arange(start=0, end=half, dtype=torch.float32)
35
+ / half
36
+ ).to(device=timesteps.device)
37
+ args = timesteps[:, None].float() * freqs[None]
38
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
39
+ if dim % 2:
40
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
41
+ return embedding
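A quick shape check of the embedding defined above, as a self-contained sketch (the import path assumes the file added in this commit; timestep values are arbitrary):

    import torch
    from tortoise.models.diffusion_decoder import timestep_embedding

    emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=512)
    print(emb.shape)  # torch.Size([3, 512]): cosines in the first 256 columns, sines in the last 256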
42
+
43
+
44
+ class TimestepBlock(nn.Module):
45
+ @abstractmethod
46
+ def forward(self, x, emb):
47
+ """
48
+ Apply the module to `x` given `emb` timestep embeddings.
49
+ """
50
+
51
+
52
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
53
+ def forward(self, x, emb):
54
+ for layer in self:
55
+ if isinstance(layer, TimestepBlock):
56
+ x = layer(x, emb)
57
+ else:
58
+ x = layer(x)
59
+ return x
60
+
61
+
62
+ class ResBlock(TimestepBlock):
63
+ def __init__(
64
+ self,
65
+ channels,
66
+ emb_channels,
67
+ dropout,
68
+ out_channels=None,
69
+ dims=2,
70
+ kernel_size=3,
71
+ efficient_config=True,
72
+ use_scale_shift_norm=False,
73
+ ):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.emb_channels = emb_channels
77
+ self.dropout = dropout
78
+ self.out_channels = out_channels or channels
79
+ self.use_scale_shift_norm = use_scale_shift_norm
80
+ padding = {1: 0, 3: 1, 5: 2}[kernel_size]
81
+ eff_kernel = 1 if efficient_config else 3
82
+ eff_padding = 0 if efficient_config else 1
83
+
84
+ self.in_layers = nn.Sequential(
85
+ normalization(channels),
86
+ nn.SiLU(),
87
+ nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
88
+ )
89
+
90
+ self.emb_layers = nn.Sequential(
91
+ nn.SiLU(),
92
+ nn.Linear(
93
+ emb_channels,
94
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
95
+ ),
96
+ )
97
+ self.out_layers = nn.Sequential(
98
+ normalization(self.out_channels),
99
+ nn.SiLU(),
100
+ nn.Dropout(p=dropout),
101
+ nn.Conv1d(
102
+ self.out_channels, self.out_channels, kernel_size, padding=padding
103
+ ),
104
+ )
105
+
106
+ if self.out_channels == channels:
107
+ self.skip_connection = nn.Identity()
108
+ else:
109
+ self.skip_connection = nn.Conv1d(
110
+ channels, self.out_channels, eff_kernel, padding=eff_padding
111
+ )
112
+
113
+ def forward(self, x, emb):
114
+ h = self.in_layers(x)
115
+ emb_out = self.emb_layers(emb).type(h.dtype)
116
+ while len(emb_out.shape) < len(h.shape):
117
+ emb_out = emb_out[..., None]
118
+ if self.use_scale_shift_norm:
119
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
120
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
121
+ h = out_norm(h) * (1 + scale) + shift
122
+ h = out_rest(h)
123
+ else:
124
+ h = h + emb_out
125
+ h = self.out_layers(h)
126
+ return self.skip_connection(x) + h
127
+
128
+
129
+ class DiffusionLayer(TimestepBlock):
130
+ def __init__(self, model_channels, dropout, num_heads):
131
+ super().__init__()
132
+ self.resblk = ResBlock(
133
+ model_channels,
134
+ model_channels,
135
+ dropout,
136
+ model_channels,
137
+ dims=1,
138
+ use_scale_shift_norm=True,
139
+ )
140
+ self.attn = AttentionBlock(
141
+ model_channels, num_heads, relative_pos_embeddings=True
142
+ )
143
+
144
+ def forward(self, x, time_emb):
145
+ y = self.resblk(x, time_emb)
146
+ return self.attn(y)
147
+
148
+
149
+ class DiffusionTts(nn.Module):
150
+ def __init__(
151
+ self,
152
+ model_channels=512,
153
+ num_layers=8,
154
+ in_channels=100,
155
+ in_latent_channels=512,
156
+ in_tokens=8193,
157
+ out_channels=200, # mean and variance
158
+ dropout=0,
159
+ use_fp16=False,
160
+ num_heads=16,
161
+ # Parameters for regularization.
162
+ layer_drop=0.1,
163
+ unconditioned_percentage=0.1, # This implements a mechanism similar to what is used in classifier-free training.
164
+ ):
165
+ super().__init__()
166
+
167
+ self.in_channels = in_channels
168
+ self.model_channels = model_channels
169
+ self.out_channels = out_channels
170
+ self.dropout = dropout
171
+ self.num_heads = num_heads
172
+ self.unconditioned_percentage = unconditioned_percentage
173
+ self.enable_fp16 = use_fp16
174
+ self.layer_drop = layer_drop
175
+
176
+ self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
177
+ self.time_embed = nn.Sequential(
178
+ nn.Linear(model_channels, model_channels),
179
+ nn.SiLU(),
180
+ nn.Linear(model_channels, model_channels),
181
+ )
182
+
183
+ # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
184
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
185
+ # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
186
+ # transformer network.
187
+ self.code_embedding = nn.Embedding(in_tokens, model_channels)
188
+ self.code_converter = nn.Sequential(
189
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
190
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
191
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
192
+ )
193
+ self.code_norm = normalization(model_channels)
194
+ self.latent_conditioner = nn.Sequential(
195
+ nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
196
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
197
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
198
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
199
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
200
+ )
201
+ self.contextual_embedder = nn.Sequential(
202
+ nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
203
+ nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
204
+ AttentionBlock(
205
+ model_channels * 2,
206
+ num_heads,
207
+ relative_pos_embeddings=True,
208
+ do_checkpoint=False,
209
+ ),
210
+ AttentionBlock(
211
+ model_channels * 2,
212
+ num_heads,
213
+ relative_pos_embeddings=True,
214
+ do_checkpoint=False,
215
+ ),
216
+ AttentionBlock(
217
+ model_channels * 2,
218
+ num_heads,
219
+ relative_pos_embeddings=True,
220
+ do_checkpoint=False,
221
+ ),
222
+ AttentionBlock(
223
+ model_channels * 2,
224
+ num_heads,
225
+ relative_pos_embeddings=True,
226
+ do_checkpoint=False,
227
+ ),
228
+ AttentionBlock(
229
+ model_channels * 2,
230
+ num_heads,
231
+ relative_pos_embeddings=True,
232
+ do_checkpoint=False,
233
+ ),
234
+ )
235
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
236
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
237
+ DiffusionLayer(model_channels, dropout, num_heads),
238
+ DiffusionLayer(model_channels, dropout, num_heads),
239
+ DiffusionLayer(model_channels, dropout, num_heads),
240
+ )
241
+
242
+ self.integrating_conv = nn.Conv1d(
243
+ model_channels * 2, model_channels, kernel_size=1
244
+ )
245
+ self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
246
+
247
+ self.layers = nn.ModuleList(
248
+ [
249
+ DiffusionLayer(model_channels, dropout, num_heads)
250
+ for _ in range(num_layers)
251
+ ]
252
+ + [
253
+ ResBlock(
254
+ model_channels,
255
+ model_channels,
256
+ dropout,
257
+ dims=1,
258
+ use_scale_shift_norm=True,
259
+ )
260
+ for _ in range(3)
261
+ ]
262
+ )
263
+
264
+ self.out = nn.Sequential(
265
+ normalization(model_channels),
266
+ nn.SiLU(),
267
+ nn.Conv1d(model_channels, out_channels, 3, padding=1),
268
+ )
269
+
270
+ def get_grad_norm_parameter_groups(self):
271
+ groups = {
272
+ "minicoder": list(self.contextual_embedder.parameters()),
273
+ "layers": list(self.layers.parameters()),
274
+ "code_converters": list(self.code_embedding.parameters())
275
+ + list(self.code_converter.parameters())
276
+ + list(self.latent_conditioner.parameters()),
278
+ "timestep_integrator": list(
279
+ self.conditioning_timestep_integrator.parameters()
280
+ )
281
+ + list(self.integrating_conv.parameters()),
282
+ "time_embed": list(self.time_embed.parameters()),
283
+ }
284
+ return groups
285
+
286
+ def get_conditioning(self, conditioning_input):
287
+ speech_conditioning_input = (
288
+ conditioning_input.unsqueeze(1)
289
+ if len(conditioning_input.shape) == 3
290
+ else conditioning_input
291
+ )
292
+ conds = []
293
+ for j in range(speech_conditioning_input.shape[1]):
294
+ conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
295
+ conds = torch.cat(conds, dim=-1)
296
+ conds = conds.mean(dim=-1)
297
+ return conds
298
+
299
+ def timestep_independent(
300
+ self,
301
+ aligned_conditioning,
302
+ conditioning_latent,
303
+ expected_seq_len,
304
+ return_code_pred,
305
+ ):
306
+ # Permute aligned_conditioning (when it is a latent) to BxCxS format
307
+ if is_latent(aligned_conditioning):
308
+ aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
309
+
310
+ cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
311
+ if is_latent(aligned_conditioning):
312
+ code_emb = self.latent_conditioner(aligned_conditioning)
313
+ else:
314
+ code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
315
+ code_emb = self.code_converter(code_emb)
316
+ code_emb = self.code_norm(code_emb) * (
317
+ 1 + cond_scale.unsqueeze(-1)
318
+ ) + cond_shift.unsqueeze(-1)
319
+
320
+ unconditioned_batches = torch.zeros(
321
+ (code_emb.shape[0], 1, 1), device=code_emb.device
322
+ )
323
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
324
+ if self.training and self.unconditioned_percentage > 0:
325
+ unconditioned_batches = (
326
+ torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device)
327
+ < self.unconditioned_percentage
328
+ )
329
+ code_emb = torch.where(
330
+ unconditioned_batches,
331
+ self.unconditioned_embedding.repeat(
332
+ aligned_conditioning.shape[0], 1, 1
333
+ ),
334
+ code_emb,
335
+ )
336
+ expanded_code_emb = F.interpolate(
337
+ code_emb, size=expected_seq_len, mode="nearest"
338
+ )
339
+
340
+ if not return_code_pred:
341
+ return expanded_code_emb
342
+ else:
343
+ mel_pred = self.mel_head(expanded_code_emb)
344
+ # Zero mel_pred for unconditioned batch elements (multiply by the logical_not of unconditioned_batches), which drops their gradient. We don't want that gradient training the code embedder, since it would unbalance the contributions that network receives from the MSE loss.
345
+ mel_pred = mel_pred * unconditioned_batches.logical_not()
346
+ return expanded_code_emb, mel_pred
347
+
348
+ def forward(
349
+ self,
350
+ x,
351
+ timesteps,
352
+ aligned_conditioning=None,
353
+ conditioning_latent=None,
354
+ precomputed_aligned_embeddings=None,
355
+ conditioning_free=False,
356
+ return_code_pred=False,
357
+ ):
358
+ """
359
+ Apply the model to an input batch.
360
+
361
+ :param x: an [N x C x ...] Tensor of inputs.
362
+ :param timesteps: a 1-D batch of timesteps.
363
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
364
+ :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
365
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
366
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
367
+ :return: an [N x C x ...] Tensor of outputs.
368
+ """
369
+ assert precomputed_aligned_embeddings is not None or (
370
+ aligned_conditioning is not None and conditioning_latent is not None
371
+ )
372
+ assert not (
373
+ return_code_pred and precomputed_aligned_embeddings is not None
374
+ ) # These two are mutually exclusive.
375
+
376
+ unused_params = []
377
+ if conditioning_free:
378
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
379
+ unused_params.extend(
380
+ list(self.code_converter.parameters())
381
+ + list(self.code_embedding.parameters())
382
+ )
383
+ unused_params.extend(list(self.latent_conditioner.parameters()))
384
+ else:
385
+ if precomputed_aligned_embeddings is not None:
386
+ code_emb = precomputed_aligned_embeddings
387
+ else:
388
+ code_emb, mel_pred = self.timestep_independent(
389
+ aligned_conditioning, conditioning_latent, x.shape[-1], True
390
+ )
391
+ if is_latent(aligned_conditioning):
392
+ unused_params.extend(
393
+ list(self.code_converter.parameters())
394
+ + list(self.code_embedding.parameters())
395
+ )
396
+ else:
397
+ unused_params.extend(list(self.latent_conditioner.parameters()))
398
+
399
+ unused_params.append(self.unconditioned_embedding)
400
+
401
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
402
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
403
+ x = self.inp_block(x)
404
+ x = torch.cat([x, code_emb], dim=1)
405
+ x = self.integrating_conv(x)
406
+ for i, lyr in enumerate(self.layers):
407
+ # Do layer drop where applicable. Do not drop first and last layers.
408
+ if (
409
+ self.training
410
+ and self.layer_drop > 0
411
+ and i != 0
412
+ and i != (len(self.layers) - 1)
413
+ and random.random() < self.layer_drop
414
+ ):
415
+ unused_params.extend(list(lyr.parameters()))
416
+ else:
417
+ # First and last blocks will have autocast disabled for improved precision.
418
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
419
+ x = lyr(x, time_emb)
420
+
421
+ x = x.float()
422
+ out = self.out(x)
423
+
424
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
425
+ extraneous_addition = 0
426
+ for p in unused_params:
427
+ extraneous_addition = extraneous_addition + p.mean()
428
+ out = out + extraneous_addition * 0
429
+
430
+ if return_code_pred:
431
+ return out, mel_pred
432
+ return out
433
+
434
+
435
+ if __name__ == "__main__":
436
+ clip = torch.randn(2, 100, 400)
437
+ aligned_latent = torch.randn(2, 388, 512)
438
+ aligned_sequence = torch.randint(0, 8192, (2, 100))
439
+ cond = torch.randn(2, 100, 400)
440
+ ts = torch.LongTensor([600, 600])
441
+ model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5)
442
+ # The raw conditioning mel must first be collapsed into a conditioning latent.
+ cond_latent = model.get_conditioning(cond)
+ # Test with latent aligned conditioning
+ # o = model(clip, ts, aligned_latent, cond_latent)
+ # Test with sequence aligned conditioning
+ o = model(clip, ts, aligned_sequence, cond_latent)
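The conditioning_free flag exists so that a sampler can evaluate the network twice per diffusion step, once with and once without conditioning, and blend the two predictions in the classifier-free-guidance style the comments above refer to. A minimal sketch of that blend; the helper name and guidance weight are illustrative, not part of this file, and a sampler working with learned variance would typically split the variance half of the 200 output channels off before blending:

    import torch

    def conditioning_free_blend(cond_out, uncond_out, guidance=2.0):
        # Push the prediction away from the unconditioned estimate,
        # amplifying whatever the conditioning contributed.
        return uncond_out + guidance * (cond_out - uncond_out)

    # Toy tensors shaped like the decoder output: [batch, out_channels, length].
    cond_out = torch.randn(2, 200, 400)
    uncond_out = torch.randn(2, 200, 400)
    print(conditioning_free_blend(cond_out, uncond_out).shape)  # torch.Size([2, 200, 400])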
tortoise/models/random_latent_generator.py ADDED
@@ -0,0 +1,56 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
9
+ if bias is not None:
10
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
11
+ return (
12
+ F.leaky_relu(
13
+ input + bias.view(1, bias.shape[0], *rest_dim),
14
+ negative_slope=negative_slope,
15
+ )
16
+ * scale
17
+ )
18
+ else:
19
+ return F.leaky_relu(input, negative_slope=negative_slope) * scale
20
+
21
+
22
+ class EqualLinear(nn.Module):
23
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1):
24
+ super().__init__()
25
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
26
+ if bias:
27
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
28
+ else:
29
+ self.bias = None
30
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
31
+ self.lr_mul = lr_mul
32
+
33
+ def forward(self, input):
34
+ out = F.linear(input, self.weight * self.scale)
35
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
36
+ return out
37
+
38
+
39
+ class RandomLatentConverter(nn.Module):
40
+ def __init__(self, channels):
41
+ super().__init__()
42
+ self.layers = nn.Sequential(
43
+ *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)],
44
+ nn.Linear(channels, channels)
45
+ )
46
+ self.channels = channels
47
+
48
+ def forward(self, ref):
49
+ r = torch.randn(ref.shape[0], self.channels, device=ref.device)
50
+ y = self.layers(r)
51
+ return y
52
+
53
+
54
+ if __name__ == "__main__":
55
+ model = RandomLatentConverter(512)
56
+ model(torch.randn(5, 512))
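RandomLatentConverter only reads the batch size and device of its `ref` argument; it is meant to fabricate conditioning latents for a "random voice" from pure noise. A short usage sketch mirroring the check above:

    import torch
    from tortoise.models.random_latent_generator import RandomLatentConverter

    gen = RandomLatentConverter(512)
    ref = torch.zeros(3, 512)      # only ref.shape[0] and ref.device are used
    latents = gen(ref)
    print(latents.shape)           # torch.Size([3, 512]); different values on every call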
tortoise/models/transformer.py ADDED
@@ -0,0 +1,241 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from rotary_embedding_torch import RotaryEmbedding, broadcat
7
+ from torch import nn
8
+
9
+
10
+ # helpers
11
+
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+
21
+ def cast_tuple(val, depth=1):
22
+ if isinstance(val, list):
23
+ val = tuple(val)
24
+ return val if isinstance(val, tuple) else (val,) * depth
25
+
26
+
27
+ def max_neg_value(t):
28
+ return -torch.finfo(t.dtype).max
29
+
30
+
31
+ def stable_softmax(t, dim=-1, alpha=32**2):
32
+ t = t / alpha
33
+ t = t - torch.amax(t, dim=dim, keepdim=True).detach()
34
+ return (t * alpha).softmax(dim=dim)
35
+
36
+
37
+ def route_args(router, args, depth):
38
+ routed_args = [(dict(), dict()) for _ in range(depth)]
39
+ matched_keys = [key for key in args.keys() if key in router]
40
+
41
+ for key in matched_keys:
42
+ val = args[key]
43
+ for depth, ((f_args, g_args), routes) in enumerate(
44
+ zip(routed_args, router[key])
45
+ ):
46
+ new_f_args, new_g_args = map(
47
+ lambda route: ({key: val} if route else {}), routes
48
+ )
49
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
50
+ return routed_args
51
+
52
+
53
+ # classes
54
+ class SequentialSequence(nn.Module):
55
+ def __init__(self, layers, args_route={}, layer_dropout=0.0):
56
+ super().__init__()
57
+ assert all(
58
+ len(route) == len(layers) for route in args_route.values()
59
+ ), "each argument route map must have the same depth as the number of sequential layers"
60
+ self.layers = layers
61
+ self.args_route = args_route
62
+ self.layer_dropout = layer_dropout
63
+
64
+ def forward(self, x, **kwargs):
65
+ args = route_args(self.args_route, kwargs, len(self.layers))
66
+ layers_and_args = list(zip(self.layers, args))
67
+
68
+ for (f, g), (f_args, g_args) in layers_and_args:
69
+ x = x + f(x, **f_args)
70
+ x = x + g(x, **g_args)
71
+ return x
72
+
73
+
74
+ class DivideMax(nn.Module):
75
+ def __init__(self, dim):
76
+ super().__init__()
77
+ self.dim = dim
78
+
79
+ def forward(self, x):
80
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
81
+ return x / maxes
82
+
83
+
84
+ # https://arxiv.org/abs/2103.17239
85
+ class LayerScale(nn.Module):
86
+ def __init__(self, dim, depth, fn):
87
+ super().__init__()
88
+ if depth <= 18:
89
+ init_eps = 0.1
90
+ elif depth > 18 and depth <= 24:
91
+ init_eps = 1e-5
92
+ else:
93
+ init_eps = 1e-6
94
+
95
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
96
+ self.scale = nn.Parameter(scale)
97
+ self.fn = fn
98
+
99
+ def forward(self, x, **kwargs):
100
+ return self.fn(x, **kwargs) * self.scale
101
+
102
+
103
+ # layer norm
104
+
105
+
106
+ class PreNorm(nn.Module):
107
+ def __init__(self, dim, fn, sandwich=False):
108
+ super().__init__()
109
+ self.norm = nn.LayerNorm(dim)
110
+ self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
111
+ self.fn = fn
112
+
113
+ def forward(self, x, **kwargs):
114
+ x = self.norm(x)
115
+ x = self.fn(x, **kwargs)
116
+ return self.norm_out(x)
117
+
118
+
119
+ # feed forward
120
+
121
+
122
+ class GEGLU(nn.Module):
123
+ def forward(self, x):
124
+ x, gates = x.chunk(2, dim=-1)
125
+ return x * F.gelu(gates)
126
+
127
+
128
+ class FeedForward(nn.Module):
129
+ def __init__(self, dim, dropout=0.0, mult=4.0):
130
+ super().__init__()
131
+ self.net = nn.Sequential(
132
+ nn.Linear(dim, int(dim * mult * 2)),
133
+ GEGLU(),
134
+ nn.Dropout(dropout),
135
+ nn.Linear(int(dim * mult), dim),
136
+ )
137
+
138
+ def forward(self, x):
139
+ return self.net(x)
140
+
141
+
142
+ # Attention
143
+
144
+
145
+ class Attention(nn.Module):
146
+ def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
147
+ super().__init__()
148
+ inner_dim = dim_head * heads
149
+ self.heads = heads
150
+ self.seq_len = seq_len
151
+ self.scale = dim_head**-0.5
152
+
153
+ self.causal = causal
154
+
155
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
156
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
157
+
158
+ def forward(self, x, mask=None):
159
+ b, n, _, h, device = *x.shape, self.heads, x.device
160
+ softmax = torch.softmax
161
+
162
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
163
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
164
+
165
+ q = q * self.scale
166
+
167
+ dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
168
+ mask_value = max_neg_value(dots)
169
+
170
+ if exists(mask):
171
+ mask = rearrange(mask, "b j -> b () () j")
172
+ dots.masked_fill_(~mask, mask_value)
173
+ del mask
174
+
175
+ if self.causal:
176
+ i, j = dots.shape[-2:]
177
+ mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
178
+ dots.masked_fill_(mask, mask_value)
179
+
180
+ attn = softmax(dots, dim=-1)
181
+
182
+ out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
183
+ out = rearrange(out, "b h n d -> b n (h d)")
184
+ out = self.to_out(out)
185
+ return out
186
+
187
+
188
+ # main transformer class
189
+ class Transformer(nn.Module):
190
+ def __init__(
191
+ self,
192
+ *,
193
+ dim,
194
+ depth,
195
+ seq_len,
196
+ causal=True,
197
+ heads=8,
198
+ dim_head=64,
199
+ ff_mult=4,
200
+ attn_dropout=0.0,
201
+ ff_dropout=0.0,
202
+ sparse_attn=False,
203
+ sandwich_norm=False,
204
+ ):
205
+ super().__init__()
206
+ layers = nn.ModuleList([])
207
+ sparse_layer = cast_tuple(sparse_attn, depth)
208
+
209
+ for ind, sparse_attn in zip(range(depth), sparse_layer):
210
+ attn = Attention(
211
+ dim,
212
+ causal=causal,
213
+ seq_len=seq_len,
214
+ heads=heads,
215
+ dim_head=dim_head,
216
+ dropout=attn_dropout,
217
+ )
218
+
219
+ ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
220
+
221
+ layers.append(
222
+ nn.ModuleList(
223
+ [
224
+ LayerScale(
225
+ dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)
226
+ ),
227
+ LayerScale(
228
+ dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
229
+ ),
230
+ ]
231
+ )
232
+ )
233
+
234
+ execute_type = SequentialSequence
235
+ route_attn = ((True, False),) * depth
236
+ attn_route_map = {"mask": route_attn}
237
+
238
+ self.layers = execute_type(layers, args_route=attn_route_map)
239
+
240
+ def forward(self, x, **kwargs):
241
+ return self.layers(x, **kwargs)
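This is the fallback encoder used by CLVP when use_xformers is off: embeddings in, same-shaped features out, with a boolean padding mask routed to every attention layer through args_route. A small usage sketch with arbitrary dimensions (einops and rotary-embedding-torch must be installed for the module to import):

    import torch
    from tortoise.models.transformer import Transformer

    enc = Transformer(dim=64, depth=2, seq_len=32, causal=False, heads=4)
    x = torch.randn(2, 32, 64)
    mask = torch.ones(2, 32, dtype=torch.bool)  # True = real token, False = padding
    y = enc(x, mask=mask)
    print(y.shape)                              # torch.Size([2, 32, 64])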
tortoise/models/vocoder.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+
8
+ class KernelPredictor(torch.nn.Module):
9
+ """Kernel predictor for the location-variable convolutions"""
10
+
11
+ def __init__(
12
+ self,
13
+ cond_channels,
14
+ conv_in_channels,
15
+ conv_out_channels,
16
+ conv_layers,
17
+ conv_kernel_size=3,
18
+ kpnet_hidden_channels=64,
19
+ kpnet_conv_size=3,
20
+ kpnet_dropout=0.0,
21
+ kpnet_nonlinear_activation="LeakyReLU",
22
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
23
+ ):
24
+ """
25
+ Args:
26
+ cond_channels (int): number of channel for the conditioning sequence,
27
+ conv_in_channels (int): number of channel for the input sequence,
28
+ conv_out_channels (int): number of channel for the output sequence,
29
+ conv_layers (int): number of layers
30
+ """
31
+ super().__init__()
32
+
33
+ self.conv_in_channels = conv_in_channels
34
+ self.conv_out_channels = conv_out_channels
35
+ self.conv_kernel_size = conv_kernel_size
36
+ self.conv_layers = conv_layers
37
+
38
+ kpnet_kernel_channels = (
39
+ conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
40
+ ) # l_w
41
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
42
+
43
+ self.input_conv = nn.Sequential(
44
+ nn.utils.weight_norm(
45
+ nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
46
+ ),
47
+ getattr(nn, kpnet_nonlinear_activation)(
48
+ **kpnet_nonlinear_activation_params
49
+ ),
50
+ )
51
+
52
+ self.residual_convs = nn.ModuleList()
53
+ padding = (kpnet_conv_size - 1) // 2
54
+ for _ in range(3):
55
+ self.residual_convs.append(
56
+ nn.Sequential(
57
+ nn.Dropout(kpnet_dropout),
58
+ nn.utils.weight_norm(
59
+ nn.Conv1d(
60
+ kpnet_hidden_channels,
61
+ kpnet_hidden_channels,
62
+ kpnet_conv_size,
63
+ padding=padding,
64
+ bias=True,
65
+ )
66
+ ),
67
+ getattr(nn, kpnet_nonlinear_activation)(
68
+ **kpnet_nonlinear_activation_params
69
+ ),
70
+ nn.utils.weight_norm(
71
+ nn.Conv1d(
72
+ kpnet_hidden_channels,
73
+ kpnet_hidden_channels,
74
+ kpnet_conv_size,
75
+ padding=padding,
76
+ bias=True,
77
+ )
78
+ ),
79
+ getattr(nn, kpnet_nonlinear_activation)(
80
+ **kpnet_nonlinear_activation_params
81
+ ),
82
+ )
83
+ )
84
+ self.kernel_conv = nn.utils.weight_norm(
85
+ nn.Conv1d(
86
+ kpnet_hidden_channels,
87
+ kpnet_kernel_channels,
88
+ kpnet_conv_size,
89
+ padding=padding,
90
+ bias=True,
91
+ )
92
+ )
93
+ self.bias_conv = nn.utils.weight_norm(
94
+ nn.Conv1d(
95
+ kpnet_hidden_channels,
96
+ kpnet_bias_channels,
97
+ kpnet_conv_size,
98
+ padding=padding,
99
+ bias=True,
100
+ )
101
+ )
102
+
103
+ def forward(self, c):
104
+ """
105
+ Args:
106
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
107
+ """
108
+ batch, _, cond_length = c.shape
109
+ c = self.input_conv(c)
110
+ for residual_conv in self.residual_convs:
111
+ residual_conv.to(c.device)
112
+ c = c + residual_conv(c)
113
+ k = self.kernel_conv(c)
114
+ b = self.bias_conv(c)
115
+ kernels = k.contiguous().view(
116
+ batch,
117
+ self.conv_layers,
118
+ self.conv_in_channels,
119
+ self.conv_out_channels,
120
+ self.conv_kernel_size,
121
+ cond_length,
122
+ )
123
+ bias = b.contiguous().view(
124
+ batch,
125
+ self.conv_layers,
126
+ self.conv_out_channels,
127
+ cond_length,
128
+ )
129
+
130
+ return kernels, bias
131
+
132
+ def remove_weight_norm(self):
133
+ nn.utils.remove_weight_norm(self.input_conv[0])
134
+ nn.utils.remove_weight_norm(self.kernel_conv)
135
+ nn.utils.remove_weight_norm(self.bias_conv)
136
+ for block in self.residual_convs:
137
+ nn.utils.remove_weight_norm(block[1])
138
+ nn.utils.remove_weight_norm(block[3])
139
+
140
+
141
+ class LVCBlock(torch.nn.Module):
142
+ """the location-variable convolutions"""
143
+
144
+ def __init__(
145
+ self,
146
+ in_channels,
147
+ cond_channels,
148
+ stride,
149
+ dilations=[1, 3, 9, 27],
150
+ lReLU_slope=0.2,
151
+ conv_kernel_size=3,
152
+ cond_hop_length=256,
153
+ kpnet_hidden_channels=64,
154
+ kpnet_conv_size=3,
155
+ kpnet_dropout=0.0,
156
+ ):
157
+ super().__init__()
158
+
159
+ self.cond_hop_length = cond_hop_length
160
+ self.conv_layers = len(dilations)
161
+ self.conv_kernel_size = conv_kernel_size
162
+
163
+ self.kernel_predictor = KernelPredictor(
164
+ cond_channels=cond_channels,
165
+ conv_in_channels=in_channels,
166
+ conv_out_channels=2 * in_channels,
167
+ conv_layers=len(dilations),
168
+ conv_kernel_size=conv_kernel_size,
169
+ kpnet_hidden_channels=kpnet_hidden_channels,
170
+ kpnet_conv_size=kpnet_conv_size,
171
+ kpnet_dropout=kpnet_dropout,
172
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
173
+ )
174
+
175
+ self.convt_pre = nn.Sequential(
176
+ nn.LeakyReLU(lReLU_slope),
177
+ nn.utils.weight_norm(
178
+ nn.ConvTranspose1d(
179
+ in_channels,
180
+ in_channels,
181
+ 2 * stride,
182
+ stride=stride,
183
+ padding=stride // 2 + stride % 2,
184
+ output_padding=stride % 2,
185
+ )
186
+ ),
187
+ )
188
+
189
+ self.conv_blocks = nn.ModuleList()
190
+ for dilation in dilations:
191
+ self.conv_blocks.append(
192
+ nn.Sequential(
193
+ nn.LeakyReLU(lReLU_slope),
194
+ nn.utils.weight_norm(
195
+ nn.Conv1d(
196
+ in_channels,
197
+ in_channels,
198
+ conv_kernel_size,
199
+ padding=dilation * (conv_kernel_size - 1) // 2,
200
+ dilation=dilation,
201
+ )
202
+ ),
203
+ nn.LeakyReLU(lReLU_slope),
204
+ )
205
+ )
206
+
207
+ def forward(self, x, c):
208
+ """forward propagation of the location-variable convolutions.
209
+ Args:
210
+ x (Tensor): the input sequence (batch, in_channels, in_length)
211
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
212
+
213
+ Returns:
214
+ Tensor: the output sequence (batch, in_channels, in_length)
215
+ """
216
+ _, in_channels, _ = x.shape # (B, c_g, L')
217
+
218
+ x = self.convt_pre(x) # (B, c_g, stride * L')
219
+ kernels, bias = self.kernel_predictor(c)
220
+
221
+ for i, conv in enumerate(self.conv_blocks):
222
+ output = conv(x) # (B, c_g, stride * L')
223
+
224
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
225
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
226
+
227
+ output = self.location_variable_convolution(
228
+ output, k, b, hop_size=self.cond_hop_length
229
+ ) # (B, 2 * c_g, stride * L'): LVC
230
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
231
+ output[:, in_channels:, :]
232
+ ) # (B, c_g, stride * L'): GAU
233
+
234
+ return x
235
+
236
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
237
+ """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
238
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
239
+ Args:
240
+ x (Tensor): the input sequence (batch, in_channels, in_length).
241
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
242
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
243
+ dilation (int): the dilation of convolution.
244
+ hop_size (int): the hop_size of the conditioning sequence.
245
+ Returns:
246
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
247
+ """
248
+ batch, _, in_length = x.shape
249
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
250
+ assert in_length == (
251
+ kernel_length * hop_size
252
+ ), "length of (x, kernel) is not matched"
253
+
254
+ padding = dilation * int((kernel_size - 1) / 2)
255
+ x = F.pad(
256
+ x, (padding, padding), "constant", 0
257
+ ) # (batch, in_channels, in_length + 2*padding)
258
+ x = x.unfold(
259
+ 2, hop_size + 2 * padding, hop_size
260
+ ) # (batch, in_channels, kernel_length, hop_size + 2*padding)
261
+
262
+ if hop_size < dilation:
263
+ x = F.pad(x, (0, dilation), "constant", 0)
264
+ x = x.unfold(
265
+ 3, dilation, dilation
266
+ ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
267
+ x = x[:, :, :, :, :hop_size]
268
+ x = x.transpose(
269
+ 3, 4
270
+ ) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
271
+ x = x.unfold(
272
+ 4, kernel_size, 1
273
+ ) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
274
+
275
+ o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
276
+ o = o.to(memory_format=torch.channels_last_3d)
277
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
278
+ o = o + bias
279
+ o = o.contiguous().view(batch, out_channels, -1)
280
+
281
+ return o
282
+
283
+ def remove_weight_norm(self):
284
+ self.kernel_predictor.remove_weight_norm()
285
+ nn.utils.remove_weight_norm(self.convt_pre[1])
286
+ for block in self.conv_blocks:
287
+ nn.utils.remove_weight_norm(block[1])
288
+
289
+
290
+ class UnivNetGenerator(nn.Module):
291
+ """UnivNet Generator"""
292
+
293
+ def __init__(
294
+ self,
295
+ noise_dim=64,
296
+ channel_size=32,
297
+ dilations=[1, 3, 9, 27],
298
+ strides=[8, 8, 4],
299
+ lReLU_slope=0.2,
300
+ kpnet_conv_size=3,
301
+ # Below are MEL configurations options that this generator requires.
302
+ hop_length=256,
303
+ n_mel_channels=100,
304
+ ):
305
+ super(UnivNetGenerator, self).__init__()
306
+ self.mel_channel = n_mel_channels
307
+ self.noise_dim = noise_dim
308
+ self.hop_length = hop_length
309
+ channel_size = channel_size
310
+ kpnet_conv_size = kpnet_conv_size
311
+
312
+ self.res_stack = nn.ModuleList()
313
+ hop_length = 1
314
+ for stride in strides:
315
+ hop_length = stride * hop_length
316
+ self.res_stack.append(
317
+ LVCBlock(
318
+ channel_size,
319
+ n_mel_channels,
320
+ stride=stride,
321
+ dilations=dilations,
322
+ lReLU_slope=lReLU_slope,
323
+ cond_hop_length=hop_length,
324
+ kpnet_conv_size=kpnet_conv_size,
325
+ )
326
+ )
327
+
328
+ self.conv_pre = nn.utils.weight_norm(
329
+ nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
330
+ )
331
+
332
+ self.conv_post = nn.Sequential(
333
+ nn.LeakyReLU(lReLU_slope),
334
+ nn.utils.weight_norm(
335
+ nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
336
+ ),
337
+ nn.Tanh(),
338
+ )
339
+
340
+ def forward(self, c, z):
341
+ """
342
+ Args:
343
+ c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
344
+ z (Tensor): the noise sequence (batch, noise_dim, in_length)
345
+
346
+ """
347
+ z = self.conv_pre(z) # (B, c_g, L)
348
+
349
+ for res_block in self.res_stack:
350
+ res_block.to(z.device)
351
+ z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
352
+
353
+ z = self.conv_post(z) # (B, 1, L * 256)
354
+
355
+ return z
356
+
357
+ def eval(self, inference=False):
358
+ super(UnivNetGenerator, self).eval()
359
+ # don't remove weight norm during validation in the training loop
360
+ if inference:
361
+ self.remove_weight_norm()
362
+
363
+ def remove_weight_norm(self):
364
+ nn.utils.remove_weight_norm(self.conv_pre)
365
+
366
+ for layer in self.conv_post:
367
+ if len(layer.state_dict()) != 0:
368
+ nn.utils.remove_weight_norm(layer)
369
+
370
+ for res_block in self.res_stack:
371
+ res_block.remove_weight_norm()
372
+
373
+ def inference(self, c, z=None):
374
+ # pad input mel with zeros to cut artifact
375
+ # see https://github.com/seungwonpark/melgan/issues/8
376
+ zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
377
+ mel = torch.cat((c, zero), dim=2)
378
+
379
+ if z is None:
380
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
381
+
382
+ audio = self.forward(mel, z)
383
+ audio = audio[:, :, : -(self.hop_length * 10)]
384
+ audio = audio.clamp(min=-1, max=1)
385
+ return audio
386
+
387
+
388
+ if __name__ == "__main__":
389
+ model = UnivNetGenerator()
390
+
391
+ c = torch.randn(3, 100, 10)
392
+ z = torch.randn(3, 64, 10)
393
+ print(c.shape)
394
+
395
+ y = model(c, z)
396
+ print(y.shape)
397
+ assert y.shape == torch.Size([3, 1, 2560])
398
+
399
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
400
+ print(pytorch_total_params)
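Putting the vocoder to use: at inference the generator maps a (batch, 100, frames) mel spectrogram to a waveform with hop_length (256) samples per frame, and inference() trims off the 10 silence frames it appends to suppress edge artifacts. A minimal sketch, assuming the module import path of this commit:

    import torch
    from tortoise.models.vocoder import UnivNetGenerator

    vocoder = UnivNetGenerator()
    vocoder.eval(inference=True)    # also strips weight norm, as defined above
    mel = torch.randn(1, 100, 40)   # 40 frames of a 100-bin mel spectrogram
    with torch.no_grad():
        audio = vocoder.inference(mel)
    print(audio.shape)              # torch.Size([1, 1, 10240]) == 40 frames * 256 samples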
tortoise/models/xtransformers.py ADDED
@@ -0,0 +1,1432 @@
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import nn, einsum
10
+
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
14
+
15
+ LayerIntermediates = namedtuple(
16
+ "Intermediates",
17
+ [
18
+ "hiddens",
19
+ "attn_intermediates",
20
+ "past_key_values",
21
+ ],
22
+ )
23
+
24
+
25
+ # helpers
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ if exists(val):
34
+ return val
35
+ return d() if isfunction(d) else d
36
+
37
+
38
+ def cast_tuple(val, depth):
39
+ return val if isinstance(val, tuple) else (val,) * depth
40
+
41
+
42
+ class always:
43
+ def __init__(self, val):
44
+ self.val = val
45
+
46
+ def __call__(self, *args, **kwargs):
47
+ return self.val
48
+
49
+
50
+ class not_equals:
51
+ def __init__(self, val):
52
+ self.val = val
53
+
54
+ def __call__(self, x, *args, **kwargs):
55
+ return x != self.val
56
+
57
+
58
+ class equals:
59
+ def __init__(self, val):
60
+ self.val = val
61
+
62
+ def __call__(self, x, *args, **kwargs):
63
+ return x == self.val
64
+
65
+
66
+ def max_neg_value(tensor):
67
+ return -torch.finfo(tensor.dtype).max
68
+
69
+
70
+ def l2norm(t):
71
+ return F.normalize(t, p=2, dim=-1)
72
+
73
+
74
+ # init helpers
75
+
76
+
77
+ def init_zero_(layer):
78
+ nn.init.constant_(layer.weight, 0.0)
79
+ if exists(layer.bias):
80
+ nn.init.constant_(layer.bias, 0.0)
81
+
82
+
83
+ # keyword argument helpers
84
+
85
+
86
+ def pick_and_pop(keys, d):
87
+ values = list(map(lambda key: d.pop(key), keys))
88
+ return dict(zip(keys, values))
89
+
90
+
91
+ def group_dict_by_key(cond, d):
92
+ return_val = [dict(), dict()]
93
+ for key in d.keys():
94
+ match = bool(cond(key))
95
+ ind = int(not match)
96
+ return_val[ind][key] = d[key]
97
+ return (*return_val,)
98
+
99
+
100
+ def string_begins_with(prefix, str):
101
+ return str.startswith(prefix)
102
+
103
+
104
+ def group_by_key_prefix(prefix, d):
105
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
106
+
107
+
108
+ def groupby_prefix_and_trim(prefix, d):
109
+ kwargs_with_prefix, kwargs = group_dict_by_key(
110
+ partial(string_begins_with, prefix), d
111
+ )
112
+ kwargs_without_prefix = dict(
113
+ map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
114
+ )
115
+ return kwargs_without_prefix, kwargs
116
+
117
+
118
+ # activations
119
+
120
+
121
+ class ReluSquared(nn.Module):
122
+ def forward(self, x):
123
+ return F.relu(x) ** 2
124
+
125
+
126
+ # positional embeddings
127
+
128
+
129
+ class AbsolutePositionalEmbedding(nn.Module):
130
+ def __init__(self, dim, max_seq_len):
131
+ super().__init__()
132
+ self.scale = dim**-0.5
133
+ self.emb = nn.Embedding(max_seq_len, dim)
134
+
135
+ def forward(self, x):
136
+ n = torch.arange(x.shape[1], device=x.device)
137
+ pos_emb = self.emb(n)
138
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
139
+ return pos_emb * self.scale
140
+
141
+
142
+ class FixedPositionalEmbedding(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
146
+ self.register_buffer("inv_freq", inv_freq)
147
+
148
+ def forward(self, x, seq_dim=1, offset=0):
149
+ t = (
150
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
151
+ + offset
152
+ )
153
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
154
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
155
+ return rearrange(emb, "n d -> () n d")
156
+
157
+
158
+ class RelativePositionBias(nn.Module):
159
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
160
+ super().__init__()
161
+ self.scale = scale
162
+ self.causal = causal
163
+ self.num_buckets = num_buckets
164
+ self.max_distance = max_distance
165
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
166
+
167
+ @staticmethod
168
+ def _relative_position_bucket(
169
+ relative_position, causal=True, num_buckets=32, max_distance=128
170
+ ):
171
+ ret = 0
172
+ n = -relative_position
173
+ if not causal:
174
+ num_buckets //= 2
175
+ ret += (n < 0).long() * num_buckets
176
+ n = torch.abs(n)
177
+ else:
178
+ n = torch.max(n, torch.zeros_like(n))
179
+
180
+ max_exact = num_buckets // 2
181
+ is_small = n < max_exact
182
+
183
+ val_if_large = (
184
+ max_exact
185
+ + (
186
+ torch.log(n.float() / max_exact)
187
+ / math.log(max_distance / max_exact)
188
+ * (num_buckets - max_exact)
189
+ ).long()
190
+ )
191
+ val_if_large = torch.min(
192
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
193
+ )
194
+
195
+ ret += torch.where(is_small, n, val_if_large)
196
+ return ret
197
+
198
+ def forward(self, qk_dots):
199
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
200
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
201
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
202
+ rel_pos = k_pos[None, :] - q_pos[:, None]
203
+ rp_bucket = self._relative_position_bucket(
204
+ rel_pos,
205
+ causal=self.causal,
206
+ num_buckets=self.num_buckets,
207
+ max_distance=self.max_distance,
208
+ )
209
+ values = self.relative_attention_bias(rp_bucket)
210
+ bias = rearrange(values, "i j h -> () h i j")
211
+ return qk_dots + (bias * self.scale)
212
+
213
+
214
+ class AlibiPositionalBias(nn.Module):
215
+ def __init__(self, heads, **kwargs):
216
+ super().__init__()
217
+ self.heads = heads
218
+ slopes = torch.Tensor(self._get_slopes(heads))
219
+ slopes = rearrange(slopes, "h -> () h () ()")
220
+ self.register_buffer("slopes", slopes, persistent=False)
221
+ self.register_buffer("bias", None, persistent=False)
222
+
223
+ @staticmethod
224
+ def _get_slopes(heads):
225
+ def get_slopes_power_of_2(n):
226
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
227
+ ratio = start
228
+ return [start * ratio**i for i in range(n)]
229
+
230
+ if math.log2(heads).is_integer():
231
+ return get_slopes_power_of_2(heads)
232
+
233
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
234
+ return (
235
+ get_slopes_power_of_2(closest_power_of_2)
236
+ + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
237
+ : heads - closest_power_of_2
238
+ ]
239
+ )
240
+
241
+ def forward(self, qk_dots):
242
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
243
+
244
+ if exists(self.bias) and self.bias.shape[-1] >= j:
245
+ return qk_dots + self.bias[..., :j]
246
+
247
+ bias = torch.arange(j, device=device)
248
+ bias = rearrange(bias, "j -> () () () j")
249
+ bias = bias * self.slopes
250
+
251
+ num_heads_unalibied = h - bias.shape[1]
252
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
253
+
254
+ self.register_buffer("bias", bias, persistent=False)
255
+ return qk_dots + self.bias
256
+
257
+
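For reference, the per-head ALiBi slopes produced by _get_slopes form a geometric sequence; a minimal check (same import assumption as above):

from tortoise.models.xtransformers import AlibiPositionalBias

print(AlibiPositionalBias._get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
# Each head applies a progressively weaker linear penalty to distant positions.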
258
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
259
+ def __init__(self, heads, bidirectional=False):
260
+ super().__init__(heads)
261
+ los_slopes = torch.log(self.slopes)
262
+ self.learned_logslopes = nn.Parameter(los_slopes)
263
+
264
+ self.bidirectional = bidirectional
265
+ if self.bidirectional:
266
+ self.learned_logslopes_future = nn.Parameter(los_slopes)
267
+
268
+ def forward(self, qk_dots):
269
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
270
+
271
+ def get_slopes(param):
272
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
273
+
274
+ if exists(self.bias) and self.bias.shape[-1] >= j:
275
+ bias = self.bias[..., :i, :j]
276
+ else:
277
+ i_arange = torch.arange(i, device=device)
278
+ j_arange = torch.arange(j, device=device)
279
+ bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(
280
+ i_arange, "i -> 1 1 i 1"
281
+ )
282
+ self.register_buffer("bias", bias, persistent=False)
283
+
284
+ if self.bidirectional:
285
+ past_slopes = get_slopes(self.learned_logslopes)
286
+ future_slopes = get_slopes(self.learned_logslopes_future)
287
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
288
+ else:
289
+ slopes = get_slopes(self.learned_logslopes)
290
+ bias = bias * slopes
291
+
292
+ return qk_dots + bias
293
+
294
+
295
+ class RotaryEmbedding(nn.Module):
296
+ def __init__(self, dim):
297
+ super().__init__()
298
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
299
+ self.register_buffer("inv_freq", inv_freq)
300
+
301
+ def forward(self, max_seq_len, device):
302
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
303
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
304
+ emb = torch.cat((freqs, freqs), dim=-1)
305
+ return rearrange(emb, "n d -> () () n d")
306
+
307
+
308
+ def rotate_half(x):
309
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
310
+ x1, x2 = x.unbind(dim=-2)
311
+ return torch.cat((-x2, x1), dim=-1)
312
+
313
+
314
+ def apply_rotary_pos_emb(t, freqs):
315
+ seq_len = t.shape[-2]
316
+ freqs = freqs[:, :, -seq_len:]
317
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
318
+
319
+
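A small usage sketch for the rotary helpers above (shapes and the rotary dimension are illustrative; assumes tortoise.models.xtransformers is importable):

import torch
from tortoise.models.xtransformers import RotaryEmbedding, apply_rotary_pos_emb

rot = RotaryEmbedding(dim=32)                    # rotary channels per head
q = torch.randn(1, 8, 16, 32)                    # (batch, heads, seq, rotary dim)
freqs = rot(max_seq_len=16, device=q.device)     # () () n d table of angles
q_rot = apply_rotary_pos_emb(q, freqs)           # same shape; positions are encoded as rotations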
320
+ # norms
321
+
322
+
323
+ class Scale(nn.Module):
324
+ def __init__(self, value, fn):
325
+ super().__init__()
326
+ self.value = value
327
+ self.fn = fn
328
+
329
+ def forward(self, x, **kwargs):
330
+ out = self.fn(x, **kwargs)
331
+ scale_fn = lambda t: t * self.value
332
+
333
+ if not isinstance(out, tuple):
334
+ return scale_fn(out)
335
+
336
+ return (scale_fn(out[0]), *out[1:])
337
+
338
+
339
+ class Rezero(nn.Module):
340
+ def __init__(self, fn):
341
+ super().__init__()
342
+ self.fn = fn
343
+ self.g = nn.Parameter(torch.zeros(1))
344
+
345
+ def forward(self, x, **kwargs):
346
+ out = self.fn(x, **kwargs)
347
+ rezero_fn = lambda t: t * self.g
348
+
349
+ if not isinstance(out, tuple):
350
+ return rezero_fn(out)
351
+
352
+ return (rezero_fn(out[0]), *out[1:])
353
+
354
+
355
+ class ScaleNorm(nn.Module):
356
+ def __init__(self, dim, eps=1e-5):
357
+ super().__init__()
358
+ self.scale = dim**-0.5
359
+ self.eps = eps
360
+ self.g = nn.Parameter(torch.ones(1))
361
+
362
+ def forward(self, x):
363
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
364
+ return x / norm.clamp(min=self.eps) * self.g
365
+
366
+
367
+ class RMSNorm(nn.Module):
368
+ def __init__(self, dim, eps=1e-8):
369
+ super().__init__()
370
+ self.scale = dim**-0.5
371
+ self.eps = eps
372
+ self.g = nn.Parameter(torch.ones(dim))
373
+
374
+ def forward(self, x):
375
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
376
+ return x / norm.clamp(min=self.eps) * self.g
377
+
378
+
379
+ class RMSScaleShiftNorm(nn.Module):
380
+ def __init__(self, dim, eps=1e-8):
381
+ super().__init__()
382
+ self.scale = dim**-0.5
383
+ self.eps = eps
384
+ self.g = nn.Parameter(torch.ones(dim))
385
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
386
+
387
+ def forward(self, x, norm_scale_shift_inp):
388
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
389
+ norm = x / norm.clamp(min=self.eps) * self.g
390
+
391
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
392
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
393
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
394
+ return h
395
+
396
+
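A brief note on the norms above: RMSNorm rescales each token to unit RMS (times the learned gain g) without the mean-centering that LayerNorm performs. A minimal sketch with illustrative sizes:

import torch
from tortoise.models.xtransformers import RMSNorm

norm = RMSNorm(dim=256)
x = torch.randn(2, 10, 256)
y = norm(x)
# torch.norm(x, dim=-1) * dim ** -0.5 is exactly the per-token RMS, so y has RMS ~= g per token.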
397
+ # residual and residual gates
398
+
399
+
400
+ class Residual(nn.Module):
401
+ def __init__(self, dim, scale_residual=False):
402
+ super().__init__()
403
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
404
+
405
+ def forward(self, x, residual):
406
+ if exists(self.residual_scale):
407
+ residual = residual * self.residual_scale
408
+
409
+ return x + residual
410
+
411
+
412
+ class GRUGating(nn.Module):
413
+ def __init__(self, dim, scale_residual=False):
414
+ super().__init__()
415
+ self.gru = nn.GRUCell(dim, dim)
416
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
417
+
418
+ def forward(self, x, residual):
419
+ if exists(self.residual_scale):
420
+ residual = residual * self.residual_scale
421
+
422
+ gated_output = self.gru(
423
+ rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
424
+ )
425
+
426
+ return gated_output.reshape_as(x)
427
+
428
+
429
+ # token shifting
430
+
431
+
432
+ def shift(t, amount, mask=None):
433
+ if amount == 0:
434
+ return t
435
+
436
+ if exists(mask):
437
+ t = t.masked_fill(~mask[..., None], 0.0)
438
+
439
+ return F.pad(t, (0, 0, amount, -amount), value=0.0)
440
+
441
+
442
+ class ShiftTokens(nn.Module):
443
+ def __init__(self, shifts, fn):
444
+ super().__init__()
445
+ self.fn = fn
446
+ self.shifts = tuple(shifts)
447
+
448
+ def forward(self, x, **kwargs):
449
+ mask = kwargs.get("mask", None)
450
+ shifts = self.shifts
451
+ segments = len(shifts)
452
+ feats_per_shift = x.shape[-1] // segments
453
+ splitted = x.split(feats_per_shift, dim=-1)
454
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
455
+ segments_to_shift = list(
456
+ map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))
457
+ )
458
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
459
+ return self.fn(x, **kwargs)
460
+
461
+
462
+ # feedforward
463
+
464
+
465
+ class GLU(nn.Module):
466
+ def __init__(self, dim_in, dim_out, activation):
467
+ super().__init__()
468
+ self.act = activation
469
+ self.proj = nn.Linear(dim_in, dim_out * 2)
470
+
471
+ def forward(self, x):
472
+ x, gate = self.proj(x).chunk(2, dim=-1)
473
+ return x * self.act(gate)
474
+
475
+
476
+ class FeedForward(nn.Module):
477
+ def __init__(
478
+ self,
479
+ dim,
480
+ dim_out=None,
481
+ mult=4,
482
+ glu=False,
483
+ relu_squared=False,
484
+ post_act_ln=False,
485
+ dropout=0.0,
486
+ zero_init_output=False,
487
+ ):
488
+ super().__init__()
489
+ inner_dim = int(dim * mult)
490
+ dim_out = default(dim_out, dim)
491
+ activation = ReluSquared() if relu_squared else nn.GELU()
492
+
493
+ project_in = (
494
+ nn.Sequential(nn.Linear(dim, inner_dim), activation)
495
+ if not glu
496
+ else GLU(dim, inner_dim, activation)
497
+ )
498
+
499
+ self.net = nn.Sequential(
500
+ project_in,
501
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
502
+ nn.Dropout(dropout),
503
+ nn.Linear(inner_dim, dim_out),
504
+ )
505
+
506
+ # init last linear layer to 0
507
+ if zero_init_output:
508
+ init_zero_(self.net[-1])
509
+
510
+ def forward(self, x):
511
+ return self.net(x)
512
+
513
+
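A usage sketch for the feedforward block above (sizes are illustrative):

import torch
from tortoise.models.xtransformers import FeedForward

ff = FeedForward(dim=256, mult=4, glu=True, dropout=0.1)
y = ff(torch.randn(2, 10, 256))   # (2, 10, 256); with glu=True the 1024-dim inner projection is gated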
514
+ # attention.
515
+
516
+
517
+ class Attention(nn.Module):
518
+ def __init__(
519
+ self,
520
+ dim,
521
+ dim_head=DEFAULT_DIM_HEAD,
522
+ heads=8,
523
+ causal=False,
524
+ talking_heads=False,
525
+ head_scale=False,
526
+ collab_heads=False,
527
+ collab_compression=0.3,
528
+ sparse_topk=None,
529
+ use_entmax15=False,
530
+ num_mem_kv=0,
531
+ dropout=0.0,
532
+ on_attn=False,
533
+ gate_values=False,
534
+ zero_init_output=False,
535
+ max_attend_past=None,
536
+ qk_norm=False,
537
+ scale_init_value=None,
538
+ rel_pos_bias=False,
539
+ rel_pos_num_buckets=32,
540
+ rel_pos_max_distance=128,
541
+ ):
542
+ super().__init__()
543
+ self.scale = dim_head**-0.5
544
+
545
+ self.heads = heads
546
+ self.causal = causal
547
+ self.max_attend_past = max_attend_past
548
+
549
+ qk_dim = v_dim = dim_head * heads
550
+
551
+ # collaborative heads
552
+ self.collab_heads = collab_heads
553
+ if self.collab_heads:
554
+ qk_dim = int(collab_compression * qk_dim)
555
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
556
+
557
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
558
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
559
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
560
+
561
+ self.dropout = nn.Dropout(dropout)
562
+
563
+ # add GLU gating for aggregated values, from alphafold2
564
+ self.to_v_gate = None
565
+ if gate_values:
566
+ self.to_v_gate = nn.Linear(dim, v_dim)
567
+ nn.init.constant_(self.to_v_gate.weight, 0)
568
+ nn.init.constant_(self.to_v_gate.bias, 1)
569
+
570
+ # cosine sim attention
571
+ self.qk_norm = qk_norm
572
+ if qk_norm:
573
+ scale_init_value = default(
574
+ scale_init_value, -3
575
+ ) # if not provided, initialize as though it were sequence length of 1024
576
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
577
+
578
+ # talking heads
579
+ self.talking_heads = talking_heads
580
+ if talking_heads:
581
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
582
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
583
+
584
+ # head scaling
585
+ self.head_scale = head_scale
586
+ if head_scale:
587
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
588
+
589
+ # explicit topk sparse attention
590
+ self.sparse_topk = sparse_topk
591
+
592
+ # entmax
593
+ self.attn_fn = F.softmax
594
+
595
+ # add memory key / values
596
+ self.num_mem_kv = num_mem_kv
597
+ if num_mem_kv > 0:
598
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
599
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
600
+
601
+ # attention on attention
602
+ self.attn_on_attn = on_attn
603
+ self.to_out = (
604
+ nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU())
605
+ if on_attn
606
+ else nn.Linear(v_dim, dim)
607
+ )
608
+
609
+ self.rel_pos_bias = rel_pos_bias
610
+ if rel_pos_bias:
611
+ assert (
612
+ rel_pos_num_buckets <= rel_pos_max_distance
613
+ ), "number of relative position buckets must be less than the relative position max distance"
614
+ self.rel_pos = RelativePositionBias(
615
+ scale=dim_head**0.5,
616
+ causal=causal,
617
+ heads=heads,
618
+ num_buckets=rel_pos_num_buckets,
619
+ max_distance=rel_pos_max_distance,
620
+ )
621
+
622
+ # init output projection 0
623
+ if zero_init_output:
624
+ init_zero_(self.to_out)
625
+
626
+ def forward(
627
+ self,
628
+ x,
629
+ context=None,
630
+ mask=None,
631
+ context_mask=None,
632
+ attn_mask=None,
633
+ sinusoidal_emb=None,
634
+ rotary_pos_emb=None,
635
+ prev_attn=None,
636
+ mem=None,
637
+ layer_past=None,
638
+ ):
639
+ (
640
+ b,
641
+ n,
642
+ _,
643
+ h,
644
+ talking_heads,
645
+ collab_heads,
646
+ head_scale,
647
+ scale,
648
+ device,
649
+ has_context,
650
+ ) = (
651
+ *x.shape,
652
+ self.heads,
653
+ self.talking_heads,
654
+ self.collab_heads,
655
+ self.head_scale,
656
+ self.scale,
657
+ x.device,
658
+ exists(context),
659
+ )
660
+ kv_input = default(context, x)
661
+
662
+ q_input = x
663
+ k_input = kv_input
664
+ v_input = kv_input
665
+
666
+ if exists(mem):
667
+ k_input = torch.cat((mem, k_input), dim=-2)
668
+ v_input = torch.cat((mem, v_input), dim=-2)
669
+
670
+ if exists(sinusoidal_emb):
671
+ # in shortformer, the query would start at a position offset depending on the past cached memory
672
+ offset = k_input.shape[-2] - q_input.shape[-2]
673
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
674
+ k_input = k_input + sinusoidal_emb(k_input)
675
+
676
+ q = self.to_q(q_input)
677
+ k = self.to_k(k_input)
678
+ v = self.to_v(v_input)
679
+
680
+ if not collab_heads:
681
+ q, k, v = map(
682
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)
683
+ )
684
+ else:
685
+ q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
686
+ k = rearrange(k, "b n d -> b () n d")
687
+ v = rearrange(v, "b n (h d) -> b h n d", h=h)
688
+
689
+ if layer_past is not None:
690
+ past_key, past_value = layer_past
691
+ k = torch.cat([past_key, k], dim=-2)
692
+ v = torch.cat([past_value, v], dim=-2)
693
+ k_cache = k
694
+ v_cache = v
695
+
696
+ if exists(rotary_pos_emb) and not has_context:
697
+ l = rotary_pos_emb.shape[-1]
698
+ (ql, qr), (kl, kr), (vl, vr) = map(
699
+ lambda t: (t[..., :l], t[..., l:]), (q, k, v)
700
+ )
701
+ ql, kl, vl = map(
702
+ lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)
703
+ )
704
+ q, k, v = map(
705
+ lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))
706
+ )
707
+
708
+ input_mask = None
709
+ if any(map(exists, (mask, context_mask))):
710
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
711
+ k_mask = q_mask if not exists(context) else context_mask
712
+ k_mask = default(
713
+ k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
714
+ )
715
+ q_mask = rearrange(q_mask, "b i -> b () i ()")
716
+ k_mask = rearrange(k_mask, "b j -> b () () j")
717
+ input_mask = q_mask * k_mask
718
+
719
+ if self.num_mem_kv > 0:
720
+ mem_k, mem_v = map(
721
+ lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
722
+ )
723
+ k = torch.cat((mem_k, k), dim=-2)
724
+ v = torch.cat((mem_v, v), dim=-2)
725
+ if exists(input_mask):
726
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
727
+
728
+ if collab_heads:
729
+ k = k.expand(-1, h, -1, -1)
730
+
731
+ if self.qk_norm:
732
+ q, k = map(l2norm, (q, k))
733
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
734
+
735
+ dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale
736
+ mask_value = max_neg_value(dots)
737
+
738
+ if exists(prev_attn):
739
+ dots = dots + prev_attn
740
+
741
+ pre_softmax_attn = dots.clone()
742
+
743
+ if talking_heads:
744
+ dots = einsum(
745
+ "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
746
+ ).contiguous()
747
+
748
+ if self.rel_pos_bias:
749
+ dots = self.rel_pos(dots)
750
+
751
+ if exists(input_mask):
752
+ dots.masked_fill_(~input_mask, mask_value)
753
+ del input_mask
754
+
755
+ if exists(attn_mask):
756
+ assert (
757
+ 2 <= attn_mask.ndim <= 4
758
+ ), "attention mask must have greater than 2 dimensions but less than or equal to 4"
759
+ if attn_mask.ndim == 2:
760
+ attn_mask = rearrange(attn_mask, "i j -> () () i j")
761
+ elif attn_mask.ndim == 3:
762
+ attn_mask = rearrange(attn_mask, "h i j -> () h i j")
763
+ dots.masked_fill_(~attn_mask, mask_value)
764
+
765
+ if exists(self.max_attend_past):
766
+ i, j = dots.shape[-2:]
767
+ range_q = torch.arange(j - i, j, device=device)
768
+ range_k = torch.arange(j, device=device)
769
+ dist = rearrange(range_q, "i -> () () i ()") - rearrange(
770
+ range_k, "j -> () () () j"
771
+ )
772
+ mask = dist > self.max_attend_past
773
+ dots.masked_fill_(mask, mask_value)
774
+ del mask
775
+
776
+ if self.causal:
777
+ i, j = dots.shape[-2:]
778
+ r = torch.arange(i, device=device)
779
+ mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
780
+ mask = F.pad(mask, (j - i, 0), value=False)
781
+ dots.masked_fill_(mask, mask_value)
782
+ del mask
783
+
784
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
785
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
786
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
787
+ mask = dots < vk
788
+ dots.masked_fill_(mask, mask_value)
789
+ del mask
790
+
791
+ attn = self.attn_fn(dots, dim=-1)
792
+ post_softmax_attn = attn.clone()
793
+
794
+ attn = self.dropout(attn)
795
+
796
+ if talking_heads:
797
+ attn = einsum(
798
+ "b h i j, h k -> b k i j", attn, self.post_softmax_proj
799
+ ).contiguous()
800
+
801
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
802
+
803
+ if head_scale:
804
+ out = out * self.head_scale_params
805
+
806
+ out = rearrange(out, "b h n d -> b n (h d)")
807
+
808
+ if exists(self.to_v_gate):
809
+ gates = self.to_v_gate(x)
810
+ out = out * gates.sigmoid()
811
+
812
+ intermediates = Intermediates(
813
+ pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
814
+ )
815
+
816
+ return self.to_out(out), intermediates, k_cache, v_cache
817
+
818
+
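Because the forward pass above returns a 4-tuple, callers must unpack the attention output, the intermediates and the key/value cache; a minimal sketch with default settings (assumes the module is importable):

import torch
from tortoise.models.xtransformers import Attention

attn = Attention(dim=512, heads=8, causal=True)
x = torch.randn(2, 10, 512)
out, intermediates, k_cache, v_cache = attn(x)
# out: (2, 10, 512); k_cache / v_cache: (2, 8, 10, 64), reusable via layer_past on the next step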
819
+ class AttentionLayers(nn.Module):
820
+ def __init__(
821
+ self,
822
+ dim,
823
+ depth,
824
+ heads=8,
825
+ causal=False,
826
+ cross_attend=False,
827
+ only_cross=False,
828
+ use_scalenorm=False,
829
+ use_rms_scaleshift_norm=False,
830
+ use_rmsnorm=False,
831
+ use_rezero=False,
832
+ alibi_pos_bias=False,
833
+ alibi_num_heads=None,
834
+ alibi_learned=False,
835
+ position_infused_attn=False,
836
+ rotary_pos_emb=False,
837
+ rotary_emb_dim=None,
838
+ custom_layers=None,
839
+ sandwich_coef=None,
840
+ par_ratio=None,
841
+ residual_attn=False,
842
+ cross_residual_attn=False,
843
+ macaron=False,
844
+ pre_norm=True,
845
+ gate_residual=False,
846
+ scale_residual=False,
847
+ shift_tokens=0,
848
+ sandwich_norm=False,
849
+ use_qk_norm_attn=False,
850
+ qk_norm_attn_seq_len=None,
851
+ zero_init_branch_output=False,
852
+ **kwargs,
853
+ ):
854
+ super().__init__()
855
+ ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
856
+ attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
857
+
858
+ dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
859
+
860
+ self.dim = dim
861
+ self.depth = depth
862
+ self.layers = nn.ModuleList([])
863
+ self.causal = causal
864
+
865
+ rel_pos_bias = "rel_pos_bias" in attn_kwargs
866
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
867
+ self.pia_pos_emb = (
868
+ FixedPositionalEmbedding(dim) if position_infused_attn else None
869
+ )
870
+
871
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
872
+ self.rotary_pos_emb = (
873
+ RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
874
+ )
875
+
876
+ assert not (
877
+ alibi_pos_bias and rel_pos_bias
878
+ ), "you can only choose Alibi positional bias or T5 relative positional bias, not both"
879
+
880
+ if alibi_pos_bias:
881
+ alibi_num_heads = default(alibi_num_heads, heads)
882
+ assert (
883
+ alibi_num_heads <= heads
884
+ ), "number of ALiBi heads must be less than the total number of heads"
885
+ alibi_pos_klass = (
886
+ LearnedAlibiPositionalBias
887
+ if alibi_learned or not causal
888
+ else AlibiPositionalBias
889
+ )
890
+ self.rel_pos = alibi_pos_klass(
891
+ heads=alibi_num_heads, bidirectional=not causal
892
+ )
893
+ else:
894
+ self.rel_pos = None
895
+
896
+ assert not (
897
+ not pre_norm and sandwich_norm
898
+ ), "sandwich norm cannot be used when not using prenorm"
899
+ self.pre_norm = pre_norm
900
+ self.sandwich_norm = sandwich_norm
901
+
902
+ self.residual_attn = residual_attn
903
+ self.cross_residual_attn = cross_residual_attn
904
+ self.cross_attend = cross_attend
905
+
906
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
907
+ norm_class = RMSNorm if use_rmsnorm else norm_class
908
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
909
+ norm_fn = partial(norm_class, dim)
910
+
911
+ norm_fn = nn.Identity if use_rezero else norm_fn
912
+ branch_fn = Rezero if use_rezero else None
913
+
914
+ if cross_attend and not only_cross:
915
+ default_block = ("a", "c", "f")
916
+ elif cross_attend and only_cross:
917
+ default_block = ("c", "f")
918
+ else:
919
+ default_block = ("a", "f")
920
+
921
+ if macaron:
922
+ default_block = ("f",) + default_block
923
+
924
+ # qk normalization
925
+
926
+ if use_qk_norm_attn:
927
+ attn_scale_init_value = (
928
+ -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len))
929
+ if exists(qk_norm_attn_seq_len)
930
+ else None
931
+ )
932
+ attn_kwargs = {
933
+ **attn_kwargs,
934
+ "qk_norm": True,
935
+ "scale_init_value": attn_scale_init_value,
936
+ }
937
+
938
+ # zero init
939
+
940
+ if zero_init_branch_output:
941
+ attn_kwargs = {**attn_kwargs, "zero_init_output": True}
942
+ ff_kwargs = {**ff_kwargs, "zero_init_output": True}
943
+
944
+ # calculate layer block order
945
+
946
+ if exists(custom_layers):
947
+ layer_types = custom_layers
948
+ elif exists(par_ratio):
949
+ par_depth = depth * len(default_block)
950
+ assert 1 < par_ratio <= par_depth, "par ratio out of range"
951
+ default_block = tuple(filter(not_equals("f"), default_block))
952
+ par_attn = par_depth // par_ratio
953
+ depth_cut = (
954
+ par_depth * 2 // 3
955
+ ) # 2 / 3 attention layer cutoff suggested by PAR paper
956
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
957
+ assert (
958
+ len(default_block) <= par_width
959
+ ), "default block is too large for par_ratio"
960
+ par_block = default_block + ("f",) * (par_width - len(default_block))
961
+ par_head = par_block * par_attn
962
+ layer_types = par_head + ("f",) * (par_depth - len(par_head))
963
+ elif exists(sandwich_coef):
964
+ assert (
965
+ sandwich_coef > 0 and sandwich_coef <= depth
966
+ ), "sandwich coefficient should be less than the depth"
967
+ layer_types = (
968
+ ("a",) * sandwich_coef
969
+ + default_block * (depth - sandwich_coef)
970
+ + ("f",) * sandwich_coef
971
+ )
972
+ else:
973
+ layer_types = default_block * depth
974
+
975
+ self.layer_types = layer_types
976
+ self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
977
+
978
+ # calculate token shifting
979
+
980
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
981
+
982
+ # iterate and construct layers
983
+
984
+ for ind, (layer_type, layer_shift_tokens) in enumerate(
985
+ zip(self.layer_types, shift_tokens)
986
+ ):
987
+ is_last_layer = ind == (len(self.layer_types) - 1)
988
+
989
+ if layer_type == "a":
990
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
991
+ elif layer_type == "c":
992
+ layer = Attention(dim, heads=heads, **attn_kwargs)
993
+ elif layer_type == "f":
994
+ layer = FeedForward(dim, **ff_kwargs)
995
+ layer = layer if not macaron else Scale(0.5, layer)
996
+ else:
997
+ raise Exception(f"invalid layer type {layer_type}")
998
+
999
+ if layer_shift_tokens > 0:
1000
+ shift_range_upper = layer_shift_tokens + 1
1001
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1002
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1003
+
1004
+ if exists(branch_fn):
1005
+ layer = branch_fn(layer)
1006
+
1007
+ residual_fn = GRUGating if gate_residual else Residual
1008
+ residual = residual_fn(dim, scale_residual=scale_residual)
1009
+
1010
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c")
1011
+
1012
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
1013
+ post_branch_norm = (
1014
+ norm_fn() if sandwich_norm or layer_uses_qk_norm else None
1015
+ )
1016
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
1017
+
1018
+ norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm])
1019
+
1020
+ self.layers.append(nn.ModuleList([norms, layer, residual]))
1021
+
1022
+ def forward(
1023
+ self,
1024
+ x,
1025
+ context=None,
1026
+ full_context=None, # for passing a list of hidden states from an encoder
1027
+ mask=None,
1028
+ context_mask=None,
1029
+ attn_mask=None,
1030
+ mems=None,
1031
+ return_hiddens=False,
1032
+ norm_scale_shift_inp=None,
1033
+ past_key_values=None,
1034
+ expected_seq_len=None,
1035
+ ):
1036
+
1037
+ assert not (
1038
+ self.cross_attend ^ (exists(context) or exists(full_context))
1039
+ ), "context must be passed in if cross_attend is set to True"
1040
+ assert (
1041
+ context is None or full_context is None
1042
+ ), "only one of full_context or context can be provided"
1043
+
1044
+ hiddens = []
1045
+ intermediates = []
1046
+ prev_attn = None
1047
+ prev_cross_attn = None
1048
+
1049
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1050
+ norm_args = {}
1051
+ if exists(norm_scale_shift_inp):
1052
+ norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp
1053
+
1054
+ rotary_pos_emb = None
1055
+ if exists(self.rotary_pos_emb):
1056
+ if not self.training and self.causal:
1057
+ assert (
1058
+ expected_seq_len is not None
1059
+ ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
1060
+ elif expected_seq_len is None:
1061
+ expected_seq_len = 0
1062
+ seq_len = x.shape[1]
1063
+ if past_key_values is not None:
1064
+ seq_len += past_key_values[0][0].shape[-2]
1065
+ max_rotary_emb_length = max(
1066
+ list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems))
1067
+ + [expected_seq_len]
1068
+ )
1069
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
1070
+
1071
+ present_key_values = []
1072
+ cross_attn_count = 0
1073
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
1074
+ zip(self.layer_types, self.layers)
1075
+ ):
1076
+ if layer_type == "a":
1077
+ layer_mem = mems.pop(0) if mems else None
1078
+
1079
+ residual = x
1080
+
1081
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
1082
+
1083
+ if exists(pre_branch_norm):
1084
+ x = pre_branch_norm(x, **norm_args)
1085
+
1086
+ if layer_type == "a" or layer_type == "c":
1087
+ if past_key_values is not None:
1088
+ layer_kv = past_key_values.pop(0)
1089
+ layer_past = tuple(s.to(x.device) for s in layer_kv)
1090
+ else:
1091
+ layer_past = None
1092
+
1093
+ if layer_type == "a":
1094
+ out, inter, k, v = block(
1095
+ x,
1096
+ None,
1097
+ mask,
1098
+ None,
1099
+ attn_mask,
1100
+ self.pia_pos_emb,
1101
+ rotary_pos_emb,
1102
+ prev_attn,
1103
+ layer_mem,
1104
+ layer_past,
1105
+ )
1106
+ elif layer_type == "c":
1107
+ if exists(full_context):
1108
+ out, inter, k, v = block(
1109
+ x,
1110
+ full_context[cross_attn_count],
1111
+ mask,
1112
+ context_mask,
1113
+ None,
1114
+ None,
1115
+ None,
1116
+ prev_attn,
1117
+ None,
1118
+ layer_past,
1119
+ )
1120
+ else:
1121
+ out, inter, k, v = block(
1122
+ x,
1123
+ context,
1124
+ mask,
1125
+ context_mask,
1126
+ None,
1127
+ None,
1128
+ None,
1129
+ prev_attn,
1130
+ None,
1131
+ layer_past,
1132
+ )
1133
+ elif layer_type == "f":
1134
+ out = block(x)
1135
+
1136
+ if (
1137
+ layer_type == "a"
1138
+ or layer_type == "c"
1139
+ and present_key_values is not None
1140
+ ):
1141
+ present_key_values.append((k.detach(), v.detach()))
1142
+
1143
+ if exists(post_branch_norm):
1144
+ out = post_branch_norm(out, **norm_args)
1145
+
1146
+ x = residual_fn(out, residual)
1147
+
1148
+ if layer_type in ("a", "c"):
1149
+ intermediates.append(inter)
1150
+
1151
+ if layer_type == "a" and self.residual_attn:
1152
+ prev_attn = inter.pre_softmax_attn
1153
+ elif layer_type == "c" and self.cross_residual_attn:
1154
+ prev_cross_attn = inter.pre_softmax_attn
1155
+
1156
+ if exists(post_main_norm):
1157
+ x = post_main_norm(x, **norm_args)
1158
+
1159
+ if layer_type == "c":
1160
+ cross_attn_count += 1
1161
+
1162
+ if layer_type == "f":
1163
+ hiddens.append(x)
1164
+
1165
+ if return_hiddens:
1166
+ intermediates = LayerIntermediates(
1167
+ hiddens=hiddens,
1168
+ attn_intermediates=intermediates,
1169
+ past_key_values=present_key_values,
1170
+ )
1171
+
1172
+ return x, intermediates
1173
+
1174
+ return x
1175
+
1176
+
1177
+ class Encoder(AttentionLayers):
1178
+ def __init__(self, **kwargs):
1179
+ assert "causal" not in kwargs, "cannot set causality on encoder"
1180
+ super().__init__(causal=False, **kwargs)
1181
+
1182
+
1183
+ class Decoder(AttentionLayers):
1184
+ def __init__(self, **kwargs):
1185
+ assert "causal" not in kwargs, "cannot set causality on decoder"
1186
+ super().__init__(causal=True, **kwargs)
1187
+
1188
+
1189
+ class CrossAttender(AttentionLayers):
1190
+ def __init__(self, **kwargs):
1191
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1192
+
1193
+
1194
+ class ViTransformerWrapper(nn.Module):
1195
+ def __init__(
1196
+ self,
1197
+ *,
1198
+ image_size,
1199
+ patch_size,
1200
+ attn_layers,
1201
+ num_classes=None,
1202
+ dropout=0.0,
1203
+ emb_dropout=0.0,
1204
+ ):
1205
+ super().__init__()
1206
+ assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder"
1207
+ assert (
1208
+ image_size % patch_size == 0
1209
+ ), "image dimensions must be divisible by the patch size"
1210
+ dim = attn_layers.dim
1211
+ num_patches = (image_size // patch_size) ** 2
1212
+ patch_dim = 3 * patch_size**2
1213
+
1214
+ self.patch_size = patch_size
1215
+
1216
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1217
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1218
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1219
+ self.dropout = nn.Dropout(emb_dropout)
1220
+
1221
+ self.attn_layers = attn_layers
1222
+ self.norm = nn.LayerNorm(dim)
1223
+ self.mlp_head = (
1224
+ FeedForward(dim, dim_out=num_classes, dropout=dropout)
1225
+ if exists(num_classes)
1226
+ else None
1227
+ )
1228
+
1229
+ def forward(self, img, return_embeddings=False):
1230
+ p = self.patch_size
1231
+
1232
+ x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
1233
+ x = self.patch_to_embedding(x)
1234
+ b, n, _ = x.shape
1235
+
1236
+ cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
1237
+ x = torch.cat((cls_tokens, x), dim=1)
1238
+ x = x + self.pos_embedding[:, : (n + 1)]
1239
+ x = self.dropout(x)
1240
+
1241
+ x = self.attn_layers(x)
1242
+ x = self.norm(x)
1243
+
1244
+ if not exists(self.mlp_head) or return_embeddings:
1245
+ return x
1246
+
1247
+ return self.mlp_head(x[:, 0])
1248
+
1249
+
1250
+ class TransformerWrapper(nn.Module):
1251
+ def __init__(
1252
+ self,
1253
+ *,
1254
+ num_tokens,
1255
+ max_seq_len,
1256
+ attn_layers,
1257
+ emb_dim=None,
1258
+ max_mem_len=0.0,
1259
+ shift_mem_down=0,
1260
+ emb_dropout=0.0,
1261
+ num_memory_tokens=None,
1262
+ tie_embedding=False,
1263
+ use_pos_emb=True,
1264
+ ):
1265
+ super().__init__()
1266
+ assert isinstance(
1267
+ attn_layers, AttentionLayers
1268
+ ), "attention layers must be one of Encoder or Decoder"
1269
+
1270
+ dim = attn_layers.dim
1271
+ emb_dim = default(emb_dim, dim)
1272
+
1273
+ self.max_seq_len = max_seq_len
1274
+ self.max_mem_len = max_mem_len
1275
+ self.shift_mem_down = shift_mem_down
1276
+
1277
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1278
+ self.pos_emb = (
1279
+ AbsolutePositionalEmbedding(emb_dim, max_seq_len)
1280
+ if (use_pos_emb and not attn_layers.has_pos_emb)
1281
+ else always(0)
1282
+ )
1283
+ self.emb_dropout = nn.Dropout(emb_dropout)
1284
+
1285
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1286
+ self.attn_layers = attn_layers
1287
+ self.norm = nn.LayerNorm(dim)
1288
+
1289
+ self.init_()
1290
+
1291
+ self.to_logits = (
1292
+ nn.Linear(dim, num_tokens)
1293
+ if not tie_embedding
1294
+ else lambda t: t @ self.token_emb.weight.t()
1295
+ )
1296
+
1297
+ # memory tokens (like [cls]) from Memory Transformers paper
1298
+ num_memory_tokens = default(num_memory_tokens, 0)
1299
+ self.num_memory_tokens = num_memory_tokens
1300
+ if num_memory_tokens > 0:
1301
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1302
+
1303
+ def init_(self):
1304
+ nn.init.kaiming_normal_(self.token_emb.weight)
1305
+
1306
+ def forward(
1307
+ self,
1308
+ x,
1309
+ return_embeddings=False,
1310
+ mask=None,
1311
+ return_hiddens=False,
1312
+ return_attn=False,
1313
+ mems=None,
1314
+ use_cache=False,
1315
+ **kwargs,
1316
+ ):
1317
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1318
+ x = self.token_emb(x)
1319
+ x = x + self.pos_emb(x)
1320
+ x = self.emb_dropout(x)
1321
+
1322
+ x = self.project_emb(x)
1323
+
1324
+ if num_mem > 0:
1325
+ mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
1326
+ x = torch.cat((mem, x), dim=1)
1327
+
1328
+ # auto-handle masking after appending memory tokens
1329
+ if exists(mask):
1330
+ mask = F.pad(mask, (num_mem, 0), value=True)
1331
+
1332
+ if self.shift_mem_down and exists(mems):
1333
+ mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
1334
+ mems = [*mems_r, *mems_l]
1335
+
1336
+ x, intermediates = self.attn_layers(
1337
+ x, mask=mask, mems=mems, return_hiddens=True, **kwargs
1338
+ )
1339
+ x = self.norm(x)
1340
+
1341
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1342
+
1343
+ out = self.to_logits(x) if not return_embeddings else x
1344
+
1345
+ if return_hiddens:
1346
+ hiddens = intermediates.hiddens
1347
+ return out, hiddens
1348
+
1349
+ res = [out]
1350
+ if return_attn:
1351
+ attn_maps = list(
1352
+ map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
1353
+ )
1354
+ res.append(attn_maps)
1355
+ if use_cache:
1356
+ res.append(intermediates.past_key_values)
1357
+
1358
+ if len(res) > 1:
1359
+ return tuple(res)
1360
+ return res[0]
1361
+
1362
+
1363
+ class ContinuousTransformerWrapper(nn.Module):
1364
+ def __init__(
1365
+ self,
1366
+ *,
1367
+ max_seq_len,
1368
+ attn_layers,
1369
+ dim_in=None,
1370
+ dim_out=None,
1371
+ emb_dim=None,
1372
+ emb_dropout=0.0,
1373
+ use_pos_emb=True,
1374
+ ):
1375
+ super().__init__()
1376
+ assert isinstance(
1377
+ attn_layers, AttentionLayers
1378
+ ), "attention layers must be one of Encoder or Decoder"
1379
+
1380
+ dim = attn_layers.dim
1381
+
1382
+ self.max_seq_len = max_seq_len
1383
+
1384
+ self.pos_emb = (
1385
+ AbsolutePositionalEmbedding(dim, max_seq_len)
1386
+ if (use_pos_emb and not attn_layers.has_pos_emb)
1387
+ else always(0)
1388
+ )
1389
+ self.emb_dropout = nn.Dropout(emb_dropout)
1390
+
1391
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1392
+
1393
+ self.attn_layers = attn_layers
1394
+ self.norm = nn.LayerNorm(dim)
1395
+
1396
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1397
+
1398
+ def forward(
1399
+ self,
1400
+ x,
1401
+ return_embeddings=False,
1402
+ mask=None,
1403
+ return_attn=False,
1404
+ mems=None,
1405
+ use_cache=False,
1406
+ **kwargs,
1407
+ ):
1408
+ b, n, _, device = *x.shape, x.device
1409
+
1410
+ x = self.project_in(x)
1411
+ x = x + self.pos_emb(x)
1412
+ x = self.emb_dropout(x)
1413
+
1414
+ x, intermediates = self.attn_layers(
1415
+ x, mask=mask, mems=mems, return_hiddens=True, **kwargs
1416
+ )
1417
+ x = self.norm(x)
1418
+
1419
+ out = self.project_out(x) if not return_embeddings else x
1420
+
1421
+ res = [out]
1422
+ if return_attn:
1423
+ attn_maps = list(
1424
+ map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
1425
+ )
1426
+ res.append(attn_maps)
1427
+ if use_cache:
1428
+ res.append(intermediates.past_key_values)
1429
+
1430
+ if len(res) > 1:
1431
+ return tuple(res)
1432
+ return res[0]
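To tie the pieces together, a small end-to-end sketch of the wrappers defined in this file (vocabulary size, depth and sequence length are arbitrary; assumes the repository root is on PYTHONPATH):

import torch
from tortoise.models.xtransformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=128,
    attn_layers=Decoder(dim=256, depth=2, heads=4),
)
tokens = torch.randint(0, 256, (1, 32))
logits = model(tokens)            # (1, 32, 256)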
tortoise/read.py ADDED
@@ -0,0 +1,157 @@

1
+ import argparse
2
+ import os
3
+ from time import time
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from api import TextToSpeech, MODELS_DIR
9
+ from utils.audio import load_audio, load_voices
10
+ from utils.text import split_and_recombine_text
11
+
12
+
13
+ if __name__ == "__main__":
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "--textfile",
17
+ type=str,
18
+ help="A file containing the text to read.",
19
+ default="tortoise/data/riding_hood.txt",
20
+ )
21
+ parser.add_argument(
22
+ "--voice",
23
+ type=str,
24
+ help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
25
+ "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
26
+ default="pat",
27
+ )
28
+ parser.add_argument(
29
+ "--output_path",
30
+ type=str,
31
+ help="Where to store outputs.",
32
+ default="results/longform/",
33
+ )
34
+ parser.add_argument(
35
+ "--preset", type=str, help="Which voice preset to use.", default="standard"
36
+ )
37
+ parser.add_argument(
38
+ "--regenerate",
39
+ type=str,
40
+ help="Comma-separated list of clip numbers to re-generate, or nothing.",
41
+ default=None,
42
+ )
43
+ parser.add_argument(
44
+ "--candidates",
45
+ type=int,
46
+ help="How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.",
47
+ default=1,
48
+ )
49
+ parser.add_argument(
50
+ "--model_dir",
51
+ type=str,
52
+ help="Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this "
53
+ "should only be specified if you have custom checkpoints.",
54
+ default=MODELS_DIR,
55
+ )
56
+ parser.add_argument(
57
+ "--seed",
58
+ type=int,
59
+ help="Random seed which can be used to reproduce results.",
60
+ default=None,
61
+ )
62
+ parser.add_argument(
63
+ "--produce_debug_state",
64
+ type=bool,
65
+ help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
66
+ default=True,
67
+ )
68
+
69
+ args = parser.parse_args()
70
+ tts = TextToSpeech(models_dir=args.model_dir)
71
+
72
+ outpath = args.output_path
73
+ selected_voices = args.voice.split(",")
74
+ regenerate = args.regenerate
75
+ if regenerate is not None:
76
+ regenerate = [int(e) for e in regenerate.split(",")]
77
+
78
+ # Process text
79
+ with open(args.textfile, "r", encoding="utf-8") as f:
80
+ text = " ".join([l for l in f.readlines()])
81
+ if "|" in text:
82
+ print(
83
+ "Found the '|' character in your text, which I will use as a cue for where to split it up. If this was not "
84
+ "your intent, please remove all '|' characters from the input."
85
+ )
86
+ texts = text.split("|")
87
+ else:
88
+ texts = split_and_recombine_text(text)
89
+
90
+ seed = int(time()) if args.seed is None else args.seed
91
+ for selected_voice in selected_voices:
92
+ voice_outpath = os.path.join(outpath, selected_voice)
93
+ os.makedirs(voice_outpath, exist_ok=True)
94
+
95
+ if "&" in selected_voice:
96
+ voice_sel = selected_voice.split("&")
97
+ else:
98
+ voice_sel = [selected_voice]
99
+
100
+ voice_samples, conditioning_latents = load_voices(voice_sel)
101
+ all_parts = []
102
+ for j, text in enumerate(texts):
103
+ if regenerate is not None and j not in regenerate:
104
+ all_parts.append(
105
+ load_audio(os.path.join(voice_outpath, f"{j}.wav"), 24000)
106
+ )
107
+ continue
108
+ gen = tts.tts_with_preset(
109
+ text,
110
+ voice_samples=voice_samples,
111
+ conditioning_latents=conditioning_latents,
112
+ preset=args.preset,
113
+ k=args.candidates,
114
+ use_deterministic_seed=seed,
115
+ )
116
+ if args.candidates == 1:
117
+ gen = gen.squeeze(0).cpu()
118
+ torchaudio.save(os.path.join(voice_outpath, f"{j}.wav"), gen, 24000)
119
+ else:
120
+ candidate_dir = os.path.join(voice_outpath, str(j))
121
+ os.makedirs(candidate_dir, exist_ok=True)
122
+ for k, g in enumerate(gen):
123
+ torchaudio.save(
124
+ os.path.join(candidate_dir, f"{k}.wav"),
125
+ g.squeeze(0).cpu(),
126
+ 24000,
127
+ )
128
+ gen = gen[0].squeeze(0).cpu()
129
+ all_parts.append(gen)
130
+
131
+ if args.candidates == 1:
132
+ full_audio = torch.cat(all_parts, dim=-1)
133
+ torchaudio.save(
134
+ os.path.join(voice_outpath, "combined.wav"), full_audio, 24000
135
+ )
136
+
137
+ if args.produce_debug_state:
138
+ os.makedirs("debug_states", exist_ok=True)
139
+ dbg_state = (seed, texts, voice_samples, conditioning_latents)
140
+ torch.save(dbg_state, f"debug_states/read_debug_{selected_voice}.pth")
141
+
142
+ # Combine each candidate's audio clips.
143
+ if args.candidates > 1:
144
+ audio_clips = []
145
+ for candidate in range(args.candidates):
146
+ for line in range(len(texts)):
147
+ wav_file = os.path.join(
148
+ voice_outpath, str(line), f"{candidate}.wav"
149
+ )
150
+ audio_clips.append(load_audio(wav_file, 24000))
151
+ audio_clips = torch.cat(audio_clips, dim=-1)
152
+ torchaudio.save(
153
+ os.path.join(voice_outpath, f"combined_{candidate:02d}.wav"),
154
+ audio_clips,
155
+ 24000,
156
+ )
157
+ audio_clips = []
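For reference, a typical way to run this script from the repository root (the voice name is a placeholder for any folder under tortoise/voices; assumes the package is installed or on PYTHONPATH):

python tortoise/read.py --textfile tortoise/data/riding_hood.txt --voice <voice_name> --preset standard --output_path results/longform/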
tortoise/utils/__init__.py ADDED
File without changes
tortoise/utils/audio.py ADDED
@@ -0,0 +1,216 @@
1
+ import os
2
+ from glob import glob
3
+
4
+ import librosa
5
+ import torch
6
+ import torchaudio
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+
10
+ from tortoise.utils.stft import STFT
11
+
12
+
13
+ BUILTIN_VOICES_DIR = os.path.join(
14
+ os.path.dirname(os.path.realpath(__file__)), "../voices"
15
+ )
16
+
17
+
18
+ def load_wav_to_torch(full_path):
19
+ sampling_rate, data = read(full_path)
20
+ if data.dtype == np.int32:
21
+ norm_fix = 2**31
22
+ elif data.dtype == np.int16:
23
+ norm_fix = 2**15
24
+ elif data.dtype == np.float16 or data.dtype == np.float32:
25
+ norm_fix = 1.0
26
+ else:
27
+ raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
28
+ return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
29
+
30
+
31
+ def load_audio(audiopath, sampling_rate):
32
+ if audiopath[-4:] == ".wav":
33
+ audio, lsr = load_wav_to_torch(audiopath)
34
+ elif audiopath[-4:] == ".mp3":
35
+ audio, lsr = librosa.load(audiopath, sr=sampling_rate)
36
+ audio = torch.FloatTensor(audio)
37
+ else:
38
+ assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
39
+
40
+ # Remove any channel data.
41
+ if len(audio.shape) > 1:
42
+ if audio.shape[0] < 5:
43
+ audio = audio[0]
44
+ else:
45
+ assert audio.shape[1] < 5
46
+ audio = audio[:, 0]
47
+
48
+ if lsr != sampling_rate:
49
+ audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
50
+
51
+ # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
52
+ # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
53
+ if torch.any(audio > 2) or not torch.any(audio < 0):
54
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
55
+ audio.clip_(-1, 1)
56
+
57
+ return audio.unsqueeze(0)
58
+
59
+
60
+ TACOTRON_MEL_MAX = 2.3143386840820312
61
+ TACOTRON_MEL_MIN = -11.512925148010254
62
+
63
+
64
+ def denormalize_tacotron_mel(norm_mel):
65
+ return ((norm_mel + 1) / 2) * (
66
+ TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
67
+ ) + TACOTRON_MEL_MIN
68
+
69
+
70
+ def normalize_tacotron_mel(mel):
71
+ return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
72
+
73
+
74
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
75
+ """
76
+ PARAMS
77
+ ------
78
+ C: compression factor
79
+ """
80
+ return torch.log(torch.clamp(x, min=clip_val) * C)
81
+
82
+
83
+ def dynamic_range_decompression(x, C=1):
84
+ """
85
+ PARAMS
86
+ ------
87
+ C: compression factor used to compress
88
+ """
89
+ return torch.exp(x) / C
90
+
91
+
92
+ def get_voices(extra_voice_dirs=[]):
93
+ dirs = [BUILTIN_VOICES_DIR] + extra_voice_dirs
94
+ voices = {}
95
+ for d in dirs:
96
+ subs = os.listdir(d)
97
+ for sub in subs:
98
+ subj = os.path.join(d, sub)
99
+ if os.path.isdir(subj):
100
+ voices[sub] = (
101
+ list(glob(f"{subj}/*.wav"))
102
+ + list(glob(f"{subj}/*.mp3"))
103
+ + list(glob(f"{subj}/*.pth"))
104
+ )
105
+ return voices
106
+
107
+
108
+ def load_voice(voice, extra_voice_dirs=[]):
109
+ if voice == "random":
110
+ return None, None
111
+
112
+ voices = get_voices(extra_voice_dirs)
113
+ paths = voices[voice]
114
+ if len(paths) == 1 and paths[0].endswith(".pth"):
115
+ return None, torch.load(paths[0])
116
+ else:
117
+ conds = []
118
+ for cond_path in paths:
119
+ c = load_audio(cond_path, 22050)
120
+ conds.append(c)
121
+ return conds, None
122
+
123
+
124
+ def load_voices(voices, extra_voice_dirs=[]):
125
+ latents = []
126
+ clips = []
127
+ for voice in voices:
128
+ if voice == "random":
129
+ if len(voices) > 1:
130
+ print(
131
+ "Cannot combine a random voice with a non-random voice. Just using a random voice."
132
+ )
133
+ return None, None
134
+ clip, latent = load_voice(voice, extra_voice_dirs)
135
+ if latent is None:
136
+ assert (
137
+ len(latents) == 0
138
+ ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
139
+ clips.extend(clip)
140
+ elif clip is None:
141
+ assert (
142
+ len(clips) == 0
143
+ ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
144
+ latents.append(latent)
145
+ if len(latents) == 0:
146
+ return clips, None
147
+ else:
148
+ latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0)
149
+ latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0)
150
+ latents = (latents_0, latents_1)
151
+ return None, latents
152
+
153
+
154
+ class TacotronSTFT(torch.nn.Module):
155
+ def __init__(
156
+ self,
157
+ filter_length=1024,
158
+ hop_length=256,
159
+ win_length=1024,
160
+ n_mel_channels=80,
161
+ sampling_rate=22050,
162
+ mel_fmin=0.0,
163
+ mel_fmax=8000.0,
164
+ ):
165
+ super(TacotronSTFT, self).__init__()
166
+ self.n_mel_channels = n_mel_channels
167
+ self.sampling_rate = sampling_rate
168
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
169
+ from librosa.filters import mel as librosa_mel_fn
170
+
171
+ mel_basis = librosa_mel_fn(
172
+ sr=sampling_rate,
173
+ n_fft=filter_length,
174
+ n_mels=n_mel_channels,
175
+ fmin=mel_fmin,
176
+ fmax=mel_fmax,
177
+ )
178
+ mel_basis = torch.from_numpy(mel_basis).float()
179
+ self.register_buffer("mel_basis", mel_basis)
180
+
181
+ def spectral_normalize(self, magnitudes):
182
+ output = dynamic_range_compression(magnitudes)
183
+ return output
184
+
185
+ def spectral_de_normalize(self, magnitudes):
186
+ output = dynamic_range_decompression(magnitudes)
187
+ return output
188
+
189
+ def mel_spectrogram(self, y):
190
+ """Computes mel-spectrograms from a batch of waves
191
+ PARAMS
192
+ ------
193
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
194
+
195
+ RETURNS
196
+ -------
197
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
198
+ """
199
+ assert torch.min(y.data) >= -10
200
+ assert torch.max(y.data) <= 10
201
+ y = torch.clip(y, min=-1, max=1)
202
+
203
+ magnitudes, phases = self.stft_fn.transform(y)
204
+ magnitudes = magnitudes.data
205
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
206
+ mel_output = self.spectral_normalize(mel_output)
207
+ return mel_output
208
+
209
+
210
+ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
211
+ stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
212
+ stft = stft.to(device)
213
+ mel = stft.mel_spectrogram(wav)
214
+ if do_normalization:
215
+ mel = normalize_tacotron_mel(mel)
216
+ return mel
tortoise/utils/diffusion.py ADDED
@@ -0,0 +1,1277 @@
1
+ """
2
+ This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself notes:
3
+
4
+ This code started out as a PyTorch port of Ho et al's diffusion models:
5
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
6
+
7
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
8
+ """
9
+
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch as th
16
+ from tqdm import tqdm
17
+
18
+
19
+ def normal_kl(mean1, logvar1, mean2, logvar2):
20
+ """
21
+ Compute the KL divergence between two gaussians.
22
+
23
+ Shapes are automatically broadcasted, so batches can be compared to
24
+ scalars, among other use cases.
25
+ """
26
+ tensor = None
27
+ for obj in (mean1, logvar1, mean2, logvar2):
28
+ if isinstance(obj, th.Tensor):
29
+ tensor = obj
30
+ break
31
+ assert tensor is not None, "at least one argument must be a Tensor"
32
+
33
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
34
+ # Tensors, but it does not work for th.exp().
35
+ logvar1, logvar2 = [
36
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
37
+ for x in (logvar1, logvar2)
38
+ ]
39
+
40
+ return 0.5 * (
41
+ -1.0
42
+ + logvar2
43
+ - logvar1
44
+ + th.exp(logvar1 - logvar2)
45
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
46
+ )
47
+
48
+
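A quick numerical check of normal_kl (assumes tortoise.utils.diffusion is importable):

import torch as th
from tortoise.utils.diffusion import normal_kl

z = th.tensor(0.0)
print(normal_kl(z, z, z, z))                   # tensor(0.)     -- identical Gaussians
print(normal_kl(z, z, th.tensor(1.0), z))      # tensor(0.5000) -- N(0, 1) vs N(1, 1)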
49
+ def approx_standard_normal_cdf(x):
50
+ """
51
+ A fast approximation of the cumulative distribution function of the
52
+ standard normal.
53
+ """
54
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
55
+
56
+
57
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
58
+ """
59
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
60
+ given image.
61
+
62
+ :param x: the target images. It is assumed that this was uint8 values,
63
+ rescaled to the range [-1, 1].
64
+ :param means: the Gaussian mean Tensor.
65
+ :param log_scales: the Gaussian log stddev Tensor.
66
+ :return: a tensor like x of log probabilities (in nats).
67
+ """
68
+ assert x.shape == means.shape == log_scales.shape
69
+ centered_x = x - means
70
+ inv_stdv = th.exp(-log_scales)
71
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
72
+ cdf_plus = approx_standard_normal_cdf(plus_in)
73
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
74
+ cdf_min = approx_standard_normal_cdf(min_in)
75
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
76
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
77
+ cdf_delta = cdf_plus - cdf_min
78
+ log_probs = th.where(
79
+ x < -0.999,
80
+ log_cdf_plus,
81
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
82
+ )
83
+ assert log_probs.shape == x.shape
84
+ return log_probs
85
+
86
+
87
+ def mean_flat(tensor):
88
+ """
89
+ Take the mean over all non-batch dimensions.
90
+ """
91
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
92
+
93
+
94
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
+ """
96
+ Get a pre-defined beta schedule for the given name.
97
+
98
+ The beta schedule library consists of beta schedules which remain similar
99
+ in the limit of num_diffusion_timesteps.
100
+ Beta schedules may be added, but should not be removed or changed once
101
+ they are committed to maintain backwards compatibility.
102
+ """
103
+ if schedule_name == "linear":
104
+ # Linear schedule from Ho et al, extended to work for any number of
105
+ # diffusion steps.
106
+ scale = 1000 / num_diffusion_timesteps
107
+ beta_start = scale * 0.0001
108
+ beta_end = scale * 0.02
109
+ return np.linspace(
110
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
111
+ )
112
+ elif schedule_name == "cosine":
113
+ return betas_for_alpha_bar(
114
+ num_diffusion_timesteps,
115
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
+ )
117
+ else:
118
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
+
120
+
121
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
+ """
123
+ Create a beta schedule that discretizes the given alpha_t_bar function,
124
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
125
+
126
+ :param num_diffusion_timesteps: the number of betas to produce.
127
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
128
+ produces the cumulative product of (1-beta) up to that
129
+ part of the diffusion process.
130
+ :param max_beta: the maximum beta to use; use values lower than 1 to
131
+ prevent singularities.
132
+ """
133
+ betas = []
134
+ for i in range(num_diffusion_timesteps):
135
+ t1 = i / num_diffusion_timesteps
136
+ t2 = (i + 1) / num_diffusion_timesteps
137
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
138
+ return np.array(betas)
139
+
140
+
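A small sketch of the schedule helpers above (the number of timesteps is illustrative):

import numpy as np
from tortoise.utils.diffusion import get_named_beta_schedule

betas = get_named_beta_schedule("cosine", 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
assert betas.shape == (1000,) and np.all(betas > 0) and np.all(betas <= 0.999)
assert np.all(np.diff(alphas_cumprod) < 0)     # signal level decays monotonically over the 1000 steps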
141
+ class ModelMeanType(enum.Enum):
142
+ """
143
+ Which type of output the model predicts.
144
+ """
145
+
146
+ PREVIOUS_X = "previous_x" # the model predicts x_{t-1}
147
+ START_X = "start_x" # the model predicts x_0
148
+ EPSILON = "epsilon" # the model predicts epsilon
149
+
150
+
151
+ class ModelVarType(enum.Enum):
152
+ """
153
+ What is used as the model's output variance.
154
+
155
+ The LEARNED_RANGE option has been added to allow the model to predict
156
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
157
+ """
158
+
159
+ LEARNED = "learned"
160
+ FIXED_SMALL = "fixed_small"
161
+ FIXED_LARGE = "fixed_large"
162
+ LEARNED_RANGE = "learned_range"
163
+
164
+
165
+ class LossType(enum.Enum):
166
+ MSE = "mse" # use raw MSE loss (and KL when learning variances)
167
+ RESCALED_MSE = (
168
+ "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances)
169
+ )
170
+ KL = "kl" # use the variational lower-bound
171
+ RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB
172
+
173
+ def is_vb(self):
174
+ return self == LossType.KL or self == LossType.RESCALED_KL
175
+
176
+
177
+ class GaussianDiffusion:
178
+ """
179
+ Utilities for training and sampling diffusion models.
180
+
181
+ Ported directly from here, and then adapted over time to further experimentation.
182
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
183
+
184
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
185
+ starting at T and going to 1.
186
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
187
+ :param model_var_type: a ModelVarType determining how variance is output.
188
+ :param loss_type: a LossType determining the loss function to use.
189
+ :param rescale_timesteps: if True, pass floating point timesteps into the
190
+ model so that they are always scaled like in the
191
+ original paper (0 to 1000).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ *,
197
+ betas,
198
+ model_mean_type,
199
+ model_var_type,
200
+ loss_type,
201
+ rescale_timesteps=False,
202
+ conditioning_free=False,
203
+ conditioning_free_k=1,
204
+ ramp_conditioning_free=True,
205
+ ):
206
+ self.model_mean_type = ModelMeanType(model_mean_type)
207
+ self.model_var_type = ModelVarType(model_var_type)
208
+ self.loss_type = LossType(loss_type)
209
+ self.rescale_timesteps = rescale_timesteps
210
+ self.conditioning_free = conditioning_free
211
+ self.conditioning_free_k = conditioning_free_k
212
+ self.ramp_conditioning_free = ramp_conditioning_free
213
+
214
+ # Use float64 for accuracy.
215
+ betas = np.array(betas, dtype=np.float64)
216
+ self.betas = betas
217
+ assert len(betas.shape) == 1, "betas must be 1-D"
218
+ assert (betas > 0).all() and (betas <= 1).all()
219
+
220
+ self.num_timesteps = int(betas.shape[0])
221
+
222
+ alphas = 1.0 - betas
223
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
224
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
225
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
226
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
227
+
228
+ # calculations for diffusion q(x_t | x_{t-1}) and others
229
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
230
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
231
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
232
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
233
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
234
+
235
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
236
+ self.posterior_variance = (
237
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
238
+ )
239
+ # log calculation clipped because the posterior variance is 0 at the
240
+ # beginning of the diffusion chain.
241
+ self.posterior_log_variance_clipped = np.log(
242
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
243
+ )
244
+ self.posterior_mean_coef1 = (
245
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
246
+ )
247
+ self.posterior_mean_coef2 = (
248
+ (1.0 - self.alphas_cumprod_prev)
249
+ * np.sqrt(alphas)
250
+ / (1.0 - self.alphas_cumprod)
251
+ )
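+ # Illustrative construction (a sketch, not the configuration used elsewhere in
+ # this repo): an epsilon-predicting diffusion with learned variance could be
+ #   GaussianDiffusion(betas=get_named_beta_schedule("linear", 1000),
+ #                     model_mean_type=ModelMeanType.EPSILON,
+ #                     model_var_type=ModelVarType.LEARNED_RANGE,
+ #                     loss_type=LossType.MSE)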
252
+
253
+ def q_mean_variance(self, x_start, t):
254
+ """
255
+ Get the distribution q(x_t | x_0).
256
+
257
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
258
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
259
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
260
+ """
261
+ mean = (
262
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
263
+ )
264
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
265
+ log_variance = _extract_into_tensor(
266
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
267
+ )
268
+ return mean, variance, log_variance
269
+
270
+ def q_sample(self, x_start, t, noise=None):
271
+ """
272
+ Diffuse the data for a given number of diffusion steps.
273
+
274
+ In other words, sample from q(x_t | x_0).
275
+
276
+ :param x_start: the initial data batch.
277
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
278
+ :param noise: if specified, the split-out normal noise.
279
+ :return: A noisy version of x_start.
280
+ """
281
+ if noise is None:
282
+ noise = th.randn_like(x_start)
283
+ assert noise.shape == x_start.shape
284
+ return (
285
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
286
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
287
+ * noise
288
+ )
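+ # i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, the
+ # closed-form forward-diffusion sample.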
289
+
290
+ def q_posterior_mean_variance(self, x_start, x_t, t):
291
+ """
292
+ Compute the mean and variance of the diffusion posterior:
293
+
294
+ q(x_{t-1} | x_t, x_0)
295
+
296
+ """
297
+ assert x_start.shape == x_t.shape
298
+ posterior_mean = (
299
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
300
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
301
+ )
302
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
303
+ posterior_log_variance_clipped = _extract_into_tensor(
304
+ self.posterior_log_variance_clipped, t, x_t.shape
305
+ )
306
+ assert (
307
+ posterior_mean.shape[0]
308
+ == posterior_variance.shape[0]
309
+ == posterior_log_variance_clipped.shape[0]
310
+ == x_start.shape[0]
311
+ )
312
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
313
+
314
+ def p_mean_variance(
315
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
316
+ ):
317
+ """
318
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
319
+ the initial x, x_0.
320
+
321
+ :param model: the model, which takes a signal and a batch of timesteps
322
+ as input.
323
+ :param x: the [N x C x ...] tensor at time t.
324
+ :param t: a 1-D Tensor of timesteps.
325
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
326
+ :param denoised_fn: if not None, a function which applies to the
327
+ x_start prediction before it is used to sample. Applies before
328
+ clip_denoised.
329
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
330
+ pass to the model. This can be used for conditioning.
331
+ :return: a dict with the following keys:
332
+ - 'mean': the model mean output.
333
+ - 'variance': the model variance output.
334
+ - 'log_variance': the log of 'variance'.
335
+ - 'pred_xstart': the prediction for x_0.
336
+ """
337
+ if model_kwargs is None:
338
+ model_kwargs = {}
339
+
340
+ B, C = x.shape[:2]
341
+ assert t.shape == (B,)
342
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
343
+ if self.conditioning_free:
344
+ model_output_no_conditioning = model(
345
+ x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs
346
+ )
347
+
348
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
349
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
350
+ model_output, model_var_values = th.split(model_output, C, dim=1)
351
+ if self.conditioning_free:
352
+ model_output_no_conditioning, _ = th.split(
353
+ model_output_no_conditioning, C, dim=1
354
+ )
355
+ if self.model_var_type == ModelVarType.LEARNED:
356
+ model_log_variance = model_var_values
357
+ model_variance = th.exp(model_log_variance)
358
+ else:
359
+ min_log = _extract_into_tensor(
360
+ self.posterior_log_variance_clipped, t, x.shape
361
+ )
362
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
363
+ # The model_var_values is [-1, 1] for [min_var, max_var].
364
+ frac = (model_var_values + 1) / 2
365
+ model_log_variance = frac * max_log + (1 - frac) * min_log
366
+ model_variance = th.exp(model_log_variance)
367
+ else:
368
+ model_variance, model_log_variance = {
369
+ # for fixedlarge, we set the initial (log-)variance like so
370
+ # to get a better decoder log likelihood.
371
+ ModelVarType.FIXED_LARGE: (
372
+ np.append(self.posterior_variance[1], self.betas[1:]),
373
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
374
+ ),
375
+ ModelVarType.FIXED_SMALL: (
376
+ self.posterior_variance,
377
+ self.posterior_log_variance_clipped,
378
+ ),
379
+ }[self.model_var_type]
380
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
381
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
382
+
383
+ if self.conditioning_free:
384
+ if self.ramp_conditioning_free:
385
+ assert t.shape[0] == 1 # This should only be used in inference.
386
+ cfk = self.conditioning_free_k * (
387
+ 1 - self._scale_timesteps(t)[0].item() / self.num_timesteps
388
+ )
389
+ else:
390
+ cfk = self.conditioning_free_k
391
+ model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
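+ # This blend is classifier-free guidance: the conditional prediction is pushed
+ # away from the unconditional one with weight cfk; with ramp_conditioning_free
+ # the weight grows as sampling approaches t = 0.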
392
+
393
+ def process_xstart(x):
394
+ if denoised_fn is not None:
395
+ x = denoised_fn(x)
396
+ if clip_denoised:
397
+ return x.clamp(-1, 1)
398
+ return x
399
+
400
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
401
+ pred_xstart = process_xstart(
402
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
403
+ )
404
+ model_mean = model_output
405
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
406
+ if self.model_mean_type == ModelMeanType.START_X:
407
+ pred_xstart = process_xstart(model_output)
408
+ else:
409
+ pred_xstart = process_xstart(
410
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
411
+ )
412
+ model_mean, _, _ = self.q_posterior_mean_variance(
413
+ x_start=pred_xstart, x_t=x, t=t
414
+ )
415
+ else:
416
+ raise NotImplementedError(self.model_mean_type)
417
+
418
+ assert (
419
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
420
+ )
421
+ return {
422
+ "mean": model_mean,
423
+ "variance": model_variance,
424
+ "log_variance": model_log_variance,
425
+ "pred_xstart": pred_xstart,
426
+ }
427
+
428
+ def _predict_xstart_from_eps(self, x_t, t, eps):
429
+ assert x_t.shape == eps.shape
430
+ return (
431
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
432
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
433
+ )
434
+
435
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
436
+ assert x_t.shape == xprev.shape
437
+ return ( # (xprev - coef2*x_t) / coef1
438
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
439
+ - _extract_into_tensor(
440
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
441
+ )
442
+ * x_t
443
+ )
444
+
445
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
446
+ return (
447
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
448
+ - pred_xstart
449
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
450
+
451
+ def _scale_timesteps(self, t):
452
+ if self.rescale_timesteps:
453
+ return t.float() * (1000.0 / self.num_timesteps)
454
+ return t
455
+
456
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
457
+ """
458
+ Compute the mean for the previous step, given a function cond_fn that
459
+ computes the gradient of a conditional log probability with respect to
460
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
461
+ condition on y.
462
+
463
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
464
+ """
465
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
466
+ new_mean = (
467
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
468
+ )
469
+ return new_mean
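+ # i.e. mu_guided = mu + variance * grad_x log p(y | x_t).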
470
+
471
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
472
+ """
473
+ Compute what the p_mean_variance output would have been, should the
474
+ model's score function be conditioned by cond_fn.
475
+
476
+ See condition_mean() for details on cond_fn.
477
+
478
+ Unlike condition_mean(), this instead uses the conditioning strategy
479
+ from Song et al (2020).
480
+ """
481
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
482
+
483
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
484
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
485
+ x, self._scale_timesteps(t), **model_kwargs
486
+ )
487
+
488
+ out = p_mean_var.copy()
489
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
490
+ out["mean"], _, _ = self.q_posterior_mean_variance(
491
+ x_start=out["pred_xstart"], x_t=x, t=t
492
+ )
493
+ return out
494
+
495
+ def p_sample(
496
+ self,
497
+ model,
498
+ x,
499
+ t,
500
+ clip_denoised=True,
501
+ denoised_fn=None,
502
+ cond_fn=None,
503
+ model_kwargs=None,
504
+ ):
505
+ """
506
+ Sample x_{t-1} from the model at the given timestep.
507
+
508
+ :param model: the model to sample from.
509
+ :param x: the current tensor at x_{t-1}.
510
+ :param t: the value of t, starting at 0 for the first diffusion step.
511
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
512
+ :param denoised_fn: if not None, a function which applies to the
513
+ x_start prediction before it is used to sample.
514
+ :param cond_fn: if not None, this is a gradient function that acts
515
+ similarly to the model.
516
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
517
+ pass to the model. This can be used for conditioning.
518
+ :return: a dict containing the following keys:
519
+ - 'sample': a random sample from the model.
520
+ - 'pred_xstart': a prediction of x_0.
521
+ """
522
+ out = self.p_mean_variance(
523
+ model,
524
+ x,
525
+ t,
526
+ clip_denoised=clip_denoised,
527
+ denoised_fn=denoised_fn,
528
+ model_kwargs=model_kwargs,
529
+ )
530
+ noise = th.randn_like(x)
531
+ nonzero_mask = (
532
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
533
+ ) # no noise when t == 0
534
+ if cond_fn is not None:
535
+ out["mean"] = self.condition_mean(
536
+ cond_fn, out, x, t, model_kwargs=model_kwargs
537
+ )
538
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
539
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
540
+
541
+ def p_sample_loop(
542
+ self,
543
+ model,
544
+ shape,
545
+ noise=None,
546
+ clip_denoised=True,
547
+ denoised_fn=None,
548
+ cond_fn=None,
549
+ model_kwargs=None,
550
+ device=None,
551
+ progress=False,
552
+ ):
553
+ """
554
+ Generate samples from the model.
555
+
556
+ :param model: the model module.
557
+ :param shape: the shape of the samples, (N, C, H, W).
558
+ :param noise: if specified, the noise from the encoder to sample.
559
+ Should be of the same shape as `shape`.
560
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
561
+ :param denoised_fn: if not None, a function which applies to the
562
+ x_start prediction before it is used to sample.
563
+ :param cond_fn: if not None, this is a gradient function that acts
564
+ similarly to the model.
565
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
566
+ pass to the model. This can be used for conditioning.
567
+ :param device: if specified, the device to create the samples on.
568
+ If not specified, use a model parameter's device.
569
+ :param progress: if True, show a tqdm progress bar.
570
+ :return: a non-differentiable batch of samples.
571
+ """
572
+ final = None
573
+ for sample in self.p_sample_loop_progressive(
574
+ model,
575
+ shape,
576
+ noise=noise,
577
+ clip_denoised=clip_denoised,
578
+ denoised_fn=denoised_fn,
579
+ cond_fn=cond_fn,
580
+ model_kwargs=model_kwargs,
581
+ device=device,
582
+ progress=progress,
583
+ ):
584
+ final = sample
585
+ return final["sample"]
586
+
587
+ def p_sample_loop_progressive(
588
+ self,
589
+ model,
590
+ shape,
591
+ noise=None,
592
+ clip_denoised=True,
593
+ denoised_fn=None,
594
+ cond_fn=None,
595
+ model_kwargs=None,
596
+ device=None,
597
+ progress=False,
598
+ ):
599
+ """
600
+ Generate samples from the model and yield intermediate samples from
601
+ each timestep of diffusion.
602
+
603
+ Arguments are the same as p_sample_loop().
604
+ Returns a generator over dicts, where each dict is the return value of
605
+ p_sample().
606
+ """
607
+ if device is None:
608
+ device = next(model.parameters()).device
609
+ assert isinstance(shape, (tuple, list))
610
+ if noise is not None:
611
+ img = noise
612
+ else:
613
+ img = th.randn(*shape, device=device)
614
+ indices = list(range(self.num_timesteps))[::-1]
615
+
616
+ for i in tqdm(indices, disable=not progress):
617
+ t = th.tensor([i] * shape[0], device=device)
618
+ with th.no_grad():
619
+ out = self.p_sample(
620
+ model,
621
+ img,
622
+ t,
623
+ clip_denoised=clip_denoised,
624
+ denoised_fn=denoised_fn,
625
+ cond_fn=cond_fn,
626
+ model_kwargs=model_kwargs,
627
+ )
628
+ yield out
629
+ img = out["sample"]
630
+
631
+ def ddim_sample(
632
+ self,
633
+ model,
634
+ x,
635
+ t,
636
+ clip_denoised=True,
637
+ denoised_fn=None,
638
+ cond_fn=None,
639
+ model_kwargs=None,
640
+ eta=0.0,
641
+ ):
642
+ """
643
+ Sample x_{t-1} from the model using DDIM.
644
+
645
+ Same usage as p_sample().
646
+ """
647
+ out = self.p_mean_variance(
648
+ model,
649
+ x,
650
+ t,
651
+ clip_denoised=clip_denoised,
652
+ denoised_fn=denoised_fn,
653
+ model_kwargs=model_kwargs,
654
+ )
655
+ if cond_fn is not None:
656
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
657
+
658
+ # Usually our model outputs epsilon, but we re-derive it
659
+ # in case we used x_start or x_prev prediction.
660
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
661
+
662
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
663
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
664
+ sigma = (
665
+ eta
666
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
667
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
668
+ )
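+ # Illustrative note: eta = 0 makes sigma vanish (deterministic DDIM), while
+ # eta = 1 recovers the DDPM posterior standard deviation.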
669
+ # Equation 12.
670
+ noise = th.randn_like(x)
671
+ mean_pred = (
672
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
673
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
674
+ )
675
+ nonzero_mask = (
676
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
677
+ ) # no noise when t == 0
678
+ sample = mean_pred + nonzero_mask * sigma * noise
679
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
680
+
681
+ def ddim_reverse_sample(
682
+ self,
683
+ model,
684
+ x,
685
+ t,
686
+ clip_denoised=True,
687
+ denoised_fn=None,
688
+ model_kwargs=None,
689
+ eta=0.0,
690
+ ):
691
+ """
692
+ Sample x_{t+1} from the model using DDIM reverse ODE.
693
+ """
694
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
695
+ out = self.p_mean_variance(
696
+ model,
697
+ x,
698
+ t,
699
+ clip_denoised=clip_denoised,
700
+ denoised_fn=denoised_fn,
701
+ model_kwargs=model_kwargs,
702
+ )
703
+ # Usually our model outputs epsilon, but we re-derive it
704
+ # in case we used x_start or x_prev prediction.
705
+ eps = (
706
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
707
+ - out["pred_xstart"]
708
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
709
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
710
+
711
+ # Equation 12. reversed
712
+ mean_pred = (
713
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
714
+ + th.sqrt(1 - alpha_bar_next) * eps
715
+ )
716
+
717
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
718
+
719
+ def ddim_sample_loop(
720
+ self,
721
+ model,
722
+ shape,
723
+ noise=None,
724
+ clip_denoised=True,
725
+ denoised_fn=None,
726
+ cond_fn=None,
727
+ model_kwargs=None,
728
+ device=None,
729
+ progress=False,
730
+ eta=0.0,
731
+ ):
732
+ """
733
+ Generate samples from the model using DDIM.
734
+
735
+ Same usage as p_sample_loop().
736
+ """
737
+ final = None
738
+ for sample in self.ddim_sample_loop_progressive(
739
+ model,
740
+ shape,
741
+ noise=noise,
742
+ clip_denoised=clip_denoised,
743
+ denoised_fn=denoised_fn,
744
+ cond_fn=cond_fn,
745
+ model_kwargs=model_kwargs,
746
+ device=device,
747
+ progress=progress,
748
+ eta=eta,
749
+ ):
750
+ final = sample
751
+ return final["sample"]
752
+
753
+ def ddim_sample_loop_progressive(
754
+ self,
755
+ model,
756
+ shape,
757
+ noise=None,
758
+ clip_denoised=True,
759
+ denoised_fn=None,
760
+ cond_fn=None,
761
+ model_kwargs=None,
762
+ device=None,
763
+ progress=False,
764
+ eta=0.0,
765
+ ):
766
+ """
767
+ Use DDIM to sample from the model and yield intermediate samples from
768
+ each timestep of DDIM.
769
+
770
+ Same usage as p_sample_loop_progressive().
771
+ """
772
+ if device is None:
773
+ device = next(model.parameters()).device
774
+ assert isinstance(shape, (tuple, list))
775
+ if noise is not None:
776
+ img = noise
777
+ else:
778
+ img = th.randn(*shape, device=device)
779
+ indices = list(range(self.num_timesteps))[::-1]
780
+
781
+ if progress:
782
+ # Lazy import so that we don't depend on tqdm.
783
+ from tqdm.auto import tqdm
784
+
785
+ indices = tqdm(indices, disable=not progress)
786
+
787
+ for i in indices:
788
+ t = th.tensor([i] * shape[0], device=device)
789
+ with th.no_grad():
790
+ out = self.ddim_sample(
791
+ model,
792
+ img,
793
+ t,
794
+ clip_denoised=clip_denoised,
795
+ denoised_fn=denoised_fn,
796
+ cond_fn=cond_fn,
797
+ model_kwargs=model_kwargs,
798
+ eta=eta,
799
+ )
800
+ yield out
801
+ img = out["sample"]
802
+
803
+ def _vb_terms_bpd(
804
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
805
+ ):
806
+ """
807
+ Get a term for the variational lower-bound.
808
+
809
+ The resulting units are bits (rather than nats, as one might expect).
810
+ This allows for comparison to other papers.
811
+
812
+ :return: a dict with the following keys:
813
+ - 'output': a shape [N] tensor of NLLs or KLs.
814
+ - 'pred_xstart': the x_0 predictions.
815
+ """
816
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
817
+ x_start=x_start, x_t=x_t, t=t
818
+ )
819
+ out = self.p_mean_variance(
820
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
821
+ )
822
+ kl = normal_kl(
823
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
824
+ )
825
+ kl = mean_flat(kl) / np.log(2.0)
826
+
827
+ decoder_nll = -discretized_gaussian_log_likelihood(
828
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
829
+ )
830
+ assert decoder_nll.shape == x_start.shape
831
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
832
+
833
+ # At the first timestep return the decoder NLL,
834
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
835
+ output = th.where((t == 0), decoder_nll, kl)
836
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
837
+
838
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
839
+ """
840
+ Compute training losses for a single timestep.
841
+
842
+ :param model: the model to evaluate loss on.
843
+ :param x_start: the [N x C x ...] tensor of inputs.
844
+ :param t: a batch of timestep indices.
845
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
846
+ pass to the model. This can be used for conditioning.
847
+ :param noise: if specified, the specific Gaussian noise to try to remove.
848
+ :return: a dict with the key "loss" containing a tensor of shape [N].
849
+ Some mean or variance settings may also have other keys.
850
+ """
851
+ if model_kwargs is None:
852
+ model_kwargs = {}
853
+ if noise is None:
854
+ noise = th.randn_like(x_start)
855
+ x_t = self.q_sample(x_start, t, noise=noise)
856
+
857
+ terms = {}
858
+
859
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
860
+ # TODO: support multiple model outputs for this mode.
861
+ terms["loss"] = self._vb_terms_bpd(
862
+ model=model,
863
+ x_start=x_start,
864
+ x_t=x_t,
865
+ t=t,
866
+ clip_denoised=False,
867
+ model_kwargs=model_kwargs,
868
+ )["output"]
869
+ if self.loss_type == LossType.RESCALED_KL:
870
+ terms["loss"] *= self.num_timesteps
871
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
872
+ model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
873
+ if isinstance(model_outputs, tuple):
874
+ model_output = model_outputs[0]
875
+ terms["extra_outputs"] = model_outputs[1:]
876
+ else:
877
+ model_output = model_outputs
878
+
879
+ if self.model_var_type in [
880
+ ModelVarType.LEARNED,
881
+ ModelVarType.LEARNED_RANGE,
882
+ ]:
883
+ B, C = x_t.shape[:2]
884
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
885
+ model_output, model_var_values = th.split(model_output, C, dim=1)
886
+ # Learn the variance using the variational bound, but don't let
887
+ # it affect our mean prediction.
888
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
889
+ terms["vb"] = self._vb_terms_bpd(
890
+ model=lambda *args, r=frozen_out: r,
891
+ x_start=x_start,
892
+ x_t=x_t,
893
+ t=t,
894
+ clip_denoised=False,
895
+ )["output"]
896
+ if self.loss_type == LossType.RESCALED_MSE:
897
+ # Divide by 1000 for equivalence with initial implementation.
898
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
899
+ terms["vb"] *= self.num_timesteps / 1000.0
900
+
901
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
902
+ target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[
903
+ 0
904
+ ]
905
+ x_start_pred = th.zeros_like(x_start) # Not supported.
906
+ elif self.model_mean_type == ModelMeanType.START_X:
907
+ target = x_start
908
+ x_start_pred = model_output
909
+ elif self.model_mean_type == ModelMeanType.EPSILON:
910
+ target = noise
911
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
912
+ else:
913
+ raise NotImplementedError(self.model_mean_type)
914
+ assert model_output.shape == target.shape == x_start.shape
915
+ terms["mse"] = mean_flat((target - model_output) ** 2)
916
+ terms["x_start_predicted"] = x_start_pred
917
+ if "vb" in terms:
918
+ terms["loss"] = terms["mse"] + terms["vb"]
919
+ else:
920
+ terms["loss"] = terms["mse"]
921
+ else:
922
+ raise NotImplementedError(self.loss_type)
923
+
924
+ return terms
925
+
926
+ def autoregressive_training_losses(
927
+ self,
928
+ model,
929
+ x_start,
930
+ t,
931
+ model_output_keys,
932
+ gd_out_key,
933
+ model_kwargs=None,
934
+ noise=None,
935
+ ):
936
+ """
937
+ Compute training losses for a single timestep.
938
+
939
+ :param model: the model to evaluate loss on.
940
+ :param x_start: the [N x C x ...] tensor of inputs.
941
+ :param t: a batch of timestep indices.
942
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
943
+ pass to the model. This can be used for conditioning.
944
+ :param noise: if specified, the specific Gaussian noise to try to remove.
945
+ :return: a dict with the key "loss" containing a tensor of shape [N].
946
+ Some mean or variance settings may also have other keys.
947
+ """
948
+ if model_kwargs is None:
949
+ model_kwargs = {}
950
+ if noise is None:
951
+ noise = th.randn_like(x_start)
952
+ x_t = self.q_sample(x_start, t, noise=noise)
953
+ terms = {}
954
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
955
+ assert False # not currently supported for this type of diffusion.
956
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
957
+ model_outputs = model(
958
+ x_t, x_start, self._scale_timesteps(t), **model_kwargs
959
+ )
960
+ terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
961
+ model_output = terms[gd_out_key]
962
+ if self.model_var_type in [
963
+ ModelVarType.LEARNED,
964
+ ModelVarType.LEARNED_RANGE,
965
+ ]:
966
+ B, C = x_t.shape[:2]
967
+ assert model_output.shape == (B, C, 2, *x_t.shape[2:])
968
+ model_output, model_var_values = (
969
+ model_output[:, :, 0],
970
+ model_output[:, :, 1],
971
+ )
972
+ # Learn the variance using the variational bound, but don't let
973
+ # it affect our mean prediction.
974
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
975
+ terms["vb"] = self._vb_terms_bpd(
976
+ model=lambda *args, r=frozen_out: r,
977
+ x_start=x_start,
978
+ x_t=x_t,
979
+ t=t,
980
+ clip_denoised=False,
981
+ )["output"]
982
+ if self.loss_type == LossType.RESCALED_MSE:
983
+ # Divide by 1000 for equivalence with initial implementation.
984
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
985
+ terms["vb"] *= self.num_timesteps / 1000.0
986
+
987
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
988
+ target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[
989
+ 0
990
+ ]
991
+ x_start_pred = th.zeros_like(x_start) # Not supported.
992
+ elif self.model_mean_type == ModelMeanType.START_X:
993
+ target = x_start
994
+ x_start_pred = model_output
995
+ elif self.model_mean_type == ModelMeanType.EPSILON:
996
+ target = noise
997
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
998
+ else:
999
+ raise NotImplementedError(self.model_mean_type)
1000
+ assert model_output.shape == target.shape == x_start.shape
1001
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1002
+ terms["x_start_predicted"] = x_start_pred
1003
+ if "vb" in terms:
1004
+ terms["loss"] = terms["mse"] + terms["vb"]
1005
+ else:
1006
+ terms["loss"] = terms["mse"]
1007
+ else:
1008
+ raise NotImplementedError(self.loss_type)
1009
+
1010
+ return terms
1011
+
1012
+ def _prior_bpd(self, x_start):
1013
+ """
1014
+ Get the prior KL term for the variational lower-bound, measured in
1015
+ bits-per-dim.
1016
+
1017
+ This term can't be optimized, as it only depends on the encoder.
1018
+
1019
+ :param x_start: the [N x C x ...] tensor of inputs.
1020
+ :return: a batch of [N] KL values (in bits), one per batch element.
1021
+ """
1022
+ batch_size = x_start.shape[0]
1023
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1024
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1025
+ kl_prior = normal_kl(
1026
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1027
+ )
1028
+ return mean_flat(kl_prior) / np.log(2.0)
1029
+
1030
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1031
+ """
1032
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1033
+ as well as other related quantities.
1034
+
1035
+ :param model: the model to evaluate loss on.
1036
+ :param x_start: the [N x C x ...] tensor of inputs.
1037
+ :param clip_denoised: if True, clip denoised samples.
1038
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1039
+ pass to the model. This can be used for conditioning.
1040
+
1041
+ :return: a dict containing the following keys:
1042
+ - total_bpd: the total variational lower-bound, per batch element.
1043
+ - prior_bpd: the prior term in the lower-bound.
1044
+ - vb: an [N x T] tensor of terms in the lower-bound.
1045
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1046
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1047
+ """
1048
+ device = x_start.device
1049
+ batch_size = x_start.shape[0]
1050
+
1051
+ vb = []
1052
+ xstart_mse = []
1053
+ mse = []
1054
+ for t in list(range(self.num_timesteps))[::-1]:
1055
+ t_batch = th.tensor([t] * batch_size, device=device)
1056
+ noise = th.randn_like(x_start)
1057
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1058
+ # Calculate VLB term at the current timestep
1059
+ with th.no_grad():
1060
+ out = self._vb_terms_bpd(
1061
+ model,
1062
+ x_start=x_start,
1063
+ x_t=x_t,
1064
+ t=t_batch,
1065
+ clip_denoised=clip_denoised,
1066
+ model_kwargs=model_kwargs,
1067
+ )
1068
+ vb.append(out["output"])
1069
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1070
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1071
+ mse.append(mean_flat((eps - noise) ** 2))
1072
+
1073
+ vb = th.stack(vb, dim=1)
1074
+ xstart_mse = th.stack(xstart_mse, dim=1)
1075
+ mse = th.stack(mse, dim=1)
1076
+
1077
+ prior_bpd = self._prior_bpd(x_start)
1078
+ total_bpd = vb.sum(dim=1) + prior_bpd
1079
+ return {
1080
+ "total_bpd": total_bpd,
1081
+ "prior_bpd": prior_bpd,
1082
+ "vb": vb,
1083
+ "xstart_mse": xstart_mse,
1084
+ "mse": mse,
1085
+ }
1086
+
1087
+
1088
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
1089
+ """
1090
+ Get a pre-defined beta schedule for the given name.
1091
+
1092
+ The beta schedule library consists of beta schedules which remain similar
1093
+ in the limit of num_diffusion_timesteps.
1094
+ Beta schedules may be added, but should not be removed or changed once
1095
+ they are committed to maintain backwards compatibility.
1096
+ """
1097
+ if schedule_name == "linear":
1098
+ # Linear schedule from Ho et al, extended to work for any number of
1099
+ # diffusion steps.
1100
+ scale = 1000 / num_diffusion_timesteps
1101
+ beta_start = scale * 0.0001
1102
+ beta_end = scale * 0.02
1103
+ return np.linspace(
1104
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
1105
+ )
1106
+ elif schedule_name == "cosine":
1107
+ return betas_for_alpha_bar(
1108
+ num_diffusion_timesteps,
1109
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
1110
+ )
1111
+ else:
1112
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1113
+
1114
+
1115
+ class SpacedDiffusion(GaussianDiffusion):
1116
+ """
1117
+ A diffusion process which can skip steps in a base diffusion process.
1118
+
1119
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
1120
+ original diffusion process to retain.
1121
+ :param kwargs: the kwargs to create the base diffusion process.
1122
+ """
1123
+
1124
+ def __init__(self, use_timesteps, **kwargs):
1125
+ self.use_timesteps = set(use_timesteps)
1126
+ self.timestep_map = []
1127
+ self.original_num_steps = len(kwargs["betas"])
1128
+
1129
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
1130
+ last_alpha_cumprod = 1.0
1131
+ new_betas = []
1132
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
1133
+ if i in self.use_timesteps:
1134
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
1135
+ last_alpha_cumprod = alpha_cumprod
1136
+ self.timestep_map.append(i)
1137
+ kwargs["betas"] = np.array(new_betas)
1138
+ super().__init__(**kwargs)
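+ # The loop above re-derives betas so that alpha_bar at each kept timestep
+ # matches the base process: new_beta_i = 1 - alpha_bar_i / alpha_bar_prev_kept.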
1139
+
1140
+ def p_mean_variance(
1141
+ self, model, *args, **kwargs
1142
+ ): # pylint: disable=signature-differs
1143
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
1144
+
1145
+ def training_losses(
1146
+ self, model, *args, **kwargs
1147
+ ): # pylint: disable=signature-differs
1148
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
1149
+
1150
+ def autoregressive_training_losses(
1151
+ self, model, *args, **kwargs
1152
+ ): # pylint: disable=signature-differs
1153
+ return super().autoregressive_training_losses(
1154
+ self._wrap_model(model, True), *args, **kwargs
1155
+ )
1156
+
1157
+ def condition_mean(self, cond_fn, *args, **kwargs):
1158
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
1159
+
1160
+ def condition_score(self, cond_fn, *args, **kwargs):
1161
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
1162
+
1163
+ def _wrap_model(self, model, autoregressive=False):
1164
+ if isinstance(model, _WrappedModel) or isinstance(
1165
+ model, _WrappedAutoregressiveModel
1166
+ ):
1167
+ return model
1168
+ mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
1169
+ return mod(
1170
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
1171
+ )
1172
+
1173
+ def _scale_timesteps(self, t):
1174
+ # Scaling is done by the wrapped model.
1175
+ return t
1176
+
1177
+
1178
+ def space_timesteps(num_timesteps, section_counts):
1179
+ """
1180
+ Create a list of timesteps to use from an original diffusion process,
1181
+ given the number of timesteps we want to take from equally-sized portions
1182
+ of the original process.
1183
+
1184
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
1185
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
1186
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
1187
+
1188
+ If the stride is a string starting with "ddim", then the fixed striding
1189
+ from the DDIM paper is used, and only one section is allowed.
1190
+
1191
+ :param num_timesteps: the number of diffusion steps in the original
1192
+ process to divide up.
1193
+ :param section_counts: either a list of numbers, or a string containing
1194
+ comma-separated numbers, indicating the step count
1195
+ per section. As a special case, use "ddimN" where N
1196
+ is a number of steps to use the striding from the
1197
+ DDIM paper.
1198
+ :return: a set of diffusion steps from the original process to use.
1199
+ """
1200
+ if isinstance(section_counts, str):
1201
+ if section_counts.startswith("ddim"):
1202
+ desired_count = int(section_counts[len("ddim") :])
1203
+ for i in range(1, num_timesteps):
1204
+ if len(range(0, num_timesteps, i)) == desired_count:
1205
+ return set(range(0, num_timesteps, i))
1206
+ raise ValueError(
1207
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
1208
+ )
1209
+ section_counts = [int(x) for x in section_counts.split(",")]
1210
+ size_per = num_timesteps // len(section_counts)
1211
+ extra = num_timesteps % len(section_counts)
1212
+ start_idx = 0
1213
+ all_steps = []
1214
+ for i, section_count in enumerate(section_counts):
1215
+ size = size_per + (1 if i < extra else 0)
1216
+ if size < section_count:
1217
+ raise ValueError(
1218
+ f"cannot divide section of {size} steps into {section_count}"
1219
+ )
1220
+ if section_count <= 1:
1221
+ frac_stride = 1
1222
+ else:
1223
+ frac_stride = (size - 1) / (section_count - 1)
1224
+ cur_idx = 0.0
1225
+ taken_steps = []
1226
+ for _ in range(section_count):
1227
+ taken_steps.append(start_idx + round(cur_idx))
1228
+ cur_idx += frac_stride
1229
+ all_steps += taken_steps
1230
+ start_idx += size
1231
+ return set(all_steps)
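+ # Illustrative examples: space_timesteps(100, "ddim25") returns
+ # {0, 4, 8, ..., 96}, and space_timesteps(300, [10, 15, 20]) strides the three
+ # 100-step sections down to 10, 15 and 20 steps respectively.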
1232
+
1233
+
1234
+ class _WrappedModel:
1235
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1236
+ self.model = model
1237
+ self.timestep_map = timestep_map
1238
+ self.rescale_timesteps = rescale_timesteps
1239
+ self.original_num_steps = original_num_steps
1240
+
1241
+ def __call__(self, x, ts, **kwargs):
1242
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1243
+ new_ts = map_tensor[ts]
1244
+ if self.rescale_timesteps:
1245
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1246
+ return self.model(x, new_ts, **kwargs)
1247
+
1248
+
1249
+ class _WrappedAutoregressiveModel:
1250
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1251
+ self.model = model
1252
+ self.timestep_map = timestep_map
1253
+ self.rescale_timesteps = rescale_timesteps
1254
+ self.original_num_steps = original_num_steps
1255
+
1256
+ def __call__(self, x, x0, ts, **kwargs):
1257
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1258
+ new_ts = map_tensor[ts]
1259
+ if self.rescale_timesteps:
1260
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1261
+ return self.model(x, x0, new_ts, **kwargs)
1262
+
1263
+
1264
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1265
+ """
1266
+ Extract values from a 1-D numpy array for a batch of indices.
1267
+
1268
+ :param arr: the 1-D numpy array.
1269
+ :param timesteps: a tensor of indices into the array to extract.
1270
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1271
+ dimension equal to the length of timesteps.
1272
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1273
+ """
1274
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1275
+ while len(res.shape) < len(broadcast_shape):
1276
+ res = res[..., None]
1277
+ return res.expand(broadcast_shape)
tortoise/utils/stft.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) 2017, Prem Seetharaman
5
+ All rights reserved.
6
+
7
+ * Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice, this
14
+ list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from this
19
+ software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import torch
34
+ import numpy as np
35
+ import torch.nn.functional as F
36
+ from torch.autograd import Variable
37
+ from scipy.signal import get_window
38
+ from librosa.util import pad_center, tiny
39
+ import librosa.util as librosa_util
40
+
41
+
42
+ def window_sumsquare(
43
+ window,
44
+ n_frames,
45
+ hop_length=200,
46
+ win_length=800,
47
+ n_fft=800,
48
+ dtype=np.float32,
49
+ norm=None,
50
+ ):
51
+ """
52
+ # from librosa 0.6
53
+ Compute the sum-square envelope of a window function at a given hop length.
54
+
55
+ This is used to estimate modulation effects induced by windowing
56
+ observations in short-time fourier transforms.
57
+
58
+ Parameters
59
+ ----------
60
+ window : string, tuple, number, callable, or list-like
61
+ Window specification, as in `get_window`
62
+
63
+ n_frames : int > 0
64
+ The number of analysis frames
65
+
66
+ hop_length : int > 0
67
+ The number of samples to advance between frames
68
+
69
+ win_length : [optional]
70
+ The length of the window function. By default, this matches `n_fft`.
71
+
72
+ n_fft : int > 0
73
+ The length of each analysis frame.
74
+
75
+ dtype : np.dtype
76
+ The data type of the output
77
+
78
+ Returns
79
+ -------
80
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
81
+ The sum-squared envelope of the window function
82
+ """
83
+ if win_length is None:
84
+ win_length = n_fft
85
+
86
+ n = n_fft + hop_length * (n_frames - 1)
87
+ x = np.zeros(n, dtype=dtype)
88
+
89
+ # Compute the squared window at the desired length
90
+ win_sq = get_window(window, win_length, fftbins=True)
91
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
92
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
93
+
94
+ # Fill the envelope
95
+ for i in range(n_frames):
96
+ sample = i * hop_length
97
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
98
+ return x
99
+
100
+
101
+ class STFT(torch.nn.Module):
102
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
103
+
104
+ def __init__(
105
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
106
+ ):
107
+ super(STFT, self).__init__()
108
+ self.filter_length = filter_length
109
+ self.hop_length = hop_length
110
+ self.win_length = win_length
111
+ self.window = window
112
+ self.forward_transform = None
113
+ scale = self.filter_length / self.hop_length
114
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
115
+
116
+ cutoff = int((self.filter_length / 2 + 1))
117
+ fourier_basis = np.vstack(
118
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
119
+ )
120
+
121
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
122
+ inverse_basis = torch.FloatTensor(
123
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
124
+ )
125
+
126
+ if window is not None:
127
+ assert filter_length >= win_length
128
+ # get window and zero center pad it to filter_length
129
+ fft_window = get_window(window, win_length, fftbins=True)
130
+ fft_window = pad_center(fft_window, size=filter_length)
131
+ fft_window = torch.from_numpy(fft_window).float()
132
+
133
+ # window the bases
134
+ forward_basis *= fft_window
135
+ inverse_basis *= fft_window
136
+
137
+ self.register_buffer("forward_basis", forward_basis.float())
138
+ self.register_buffer("inverse_basis", inverse_basis.float())
139
+
140
+ def transform(self, input_data):
141
+ num_batches = input_data.size(0)
142
+ num_samples = input_data.size(1)
143
+
144
+ self.num_samples = num_samples
145
+
146
+ # similar to librosa, reflect-pad the input
147
+ input_data = input_data.view(num_batches, 1, num_samples)
148
+ input_data = F.pad(
149
+ input_data.unsqueeze(1),
150
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
151
+ mode="reflect",
152
+ )
153
+ input_data = input_data.squeeze(1)
154
+
155
+ forward_transform = F.conv1d(
156
+ input_data,
157
+ Variable(self.forward_basis, requires_grad=False),
158
+ stride=self.hop_length,
159
+ padding=0,
160
+ )
161
+
162
+ cutoff = int((self.filter_length / 2) + 1)
163
+ real_part = forward_transform[:, :cutoff, :]
164
+ imag_part = forward_transform[:, cutoff:, :]
165
+
166
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
167
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
168
+
169
+ return magnitude, phase
170
+
171
+ def inverse(self, magnitude, phase):
172
+ recombine_magnitude_phase = torch.cat(
173
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
174
+ )
175
+
176
+ inverse_transform = F.conv_transpose1d(
177
+ recombine_magnitude_phase,
178
+ Variable(self.inverse_basis, requires_grad=False),
179
+ stride=self.hop_length,
180
+ padding=0,
181
+ )
182
+
183
+ if self.window is not None:
184
+ window_sum = window_sumsquare(
185
+ self.window,
186
+ magnitude.size(-1),
187
+ hop_length=self.hop_length,
188
+ win_length=self.win_length,
189
+ n_fft=self.filter_length,
190
+ dtype=np.float32,
191
+ )
192
+ # remove modulation effects
193
+ approx_nonzero_indices = torch.from_numpy(
194
+ np.where(window_sum > tiny(window_sum))[0]
195
+ )
196
+ window_sum = torch.autograd.Variable(
197
+ torch.from_numpy(window_sum), requires_grad=False
198
+ )
199
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
200
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
201
+ approx_nonzero_indices
202
+ ]
203
+
204
+ # scale by hop ratio
205
+ inverse_transform *= float(self.filter_length) / self.hop_length
206
+
207
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
208
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
209
+
210
+ return inverse_transform
211
+
212
+ def forward(self, input_data):
213
+ self.magnitude, self.phase = self.transform(input_data)
214
+ reconstruction = self.inverse(self.magnitude, self.phase)
215
+ return reconstruction
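+ # Illustrative usage (a sketch with assumed tensor shapes): for audio of shape
+ # (batch, num_samples),
+ #   stft = STFT(filter_length=800, hop_length=200, win_length=800)
+ #   magnitude, phase = stft.transform(audio)
+ #   audio_hat = stft.inverse(magnitude, phase)
+ # round-trips the signal up to edge effects from the reflection padding.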
tortoise/utils/text.py ADDED
@@ -0,0 +1,144 @@
1
+ import re
2
+
3
+
4
+ def split_and_recombine_text(text, desired_length=200, max_length=300):
5
+ """Split text it into chunks of a desired length trying to keep sentences intact."""
6
+ # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
7
+ text = re.sub(r"\n\n+", "\n", text)
8
+ text = re.sub(r"\s+", " ", text)
9
+ text = re.sub(r"[“”]", '"', text)
10
+
11
+ rv = []
12
+ in_quote = False
13
+ current = ""
14
+ split_pos = []
15
+ pos = -1
16
+ end_pos = len(text) - 1
17
+
18
+ def seek(delta):
19
+ nonlocal pos, in_quote, current
20
+ is_neg = delta < 0
21
+ for _ in range(abs(delta)):
22
+ if is_neg:
23
+ pos -= 1
24
+ current = current[:-1]
25
+ else:
26
+ pos += 1
27
+ current += text[pos]
28
+ if text[pos] == '"':
29
+ in_quote = not in_quote
30
+ return text[pos]
31
+
32
+ def peek(delta):
33
+ p = pos + delta
34
+ return text[p] if p < end_pos and p >= 0 else ""
35
+
36
+ def commit():
37
+ nonlocal rv, current, split_pos
38
+ rv.append(current)
39
+ current = ""
40
+ split_pos = []
41
+
42
+ while pos < end_pos:
43
+ c = seek(1)
44
+ # do we need to force a split?
45
+ if len(current) >= max_length:
46
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
47
+ # we have at least one sentence and we are over half the desired length, seek back to the last split
48
+ d = pos - split_pos[-1]
49
+ seek(-d)
50
+ else:
51
+ # no full sentences, seek back until we are not in the middle of a word and split there
52
+ while c not in "!?.\n " and pos > 0 and len(current) > desired_length:
53
+ c = seek(-1)
54
+ commit()
55
+ # check for sentence boundaries
56
+ elif not in_quote and (c in "!?\n" or (c == "." and peek(1) in "\n ")):
57
+ # seek forward if we have consecutive boundary markers but still within the max length
58
+ while (
59
+ pos < len(text) - 1 and len(current) < max_length and peek(1) in "!?."
60
+ ):
61
+ c = seek(1)
62
+ split_pos.append(pos)
63
+ if len(current) >= desired_length:
64
+ commit()
65
+ # treat end of quote as a boundary if its followed by a space or newline
66
+ elif in_quote and peek(1) == '"' and peek(2) in "\n ":
67
+ seek(2)
68
+ split_pos.append(pos)
69
+ rv.append(current)
70
+
71
+ # clean up, remove lines with only whitespace or punctuation
72
+ rv = [s.strip() for s in rv]
73
+ rv = [s for s in rv if len(s) > 0 and not re.match(r"^[\s\.,;:!?]*$", s)]
74
+
75
+ return rv
76
+
77
+
78
+ if __name__ == "__main__":
79
+ import os
80
+ import unittest
81
+
82
+ class Test(unittest.TestCase):
83
+ def test_split_and_recombine_text(self):
84
+ text = """
85
+ This is a sample sentence.
86
+ This is another sample sentence.
87
+ This is a longer sample sentence that should force a split inthemiddlebutinotinthislongword.
88
+ "Don't split my quote... please"
89
+ """
90
+ self.assertEqual(
91
+ split_and_recombine_text(text, desired_length=20, max_length=40),
92
+ [
93
+ "This is a sample sentence.",
94
+ "This is another sample sentence.",
95
+ "This is a longer sample sentence that",
96
+ "should force a split",
97
+ "inthemiddlebutinotinthislongword.",
98
+ '"Don\'t split my quote... please"',
99
+ ],
100
+ )
101
+
102
+ def test_split_and_recombine_text_2(self):
103
+ text = """
104
+ When you are really angry sometimes you use consecutive exclamation marks!!!!!! Is this a good thing to do?!?!?!
105
+ I don't know but we should handle this situation..........................
106
+ """
107
+ self.assertEqual(
108
+ split_and_recombine_text(text, desired_length=30, max_length=50),
109
+ [
110
+ "When you are really angry sometimes you use",
111
+ "consecutive exclamation marks!!!!!!",
112
+ "Is this a good thing to do?!?!?!",
113
+ "I don't know but we should handle this situation.",
114
+ ],
115
+ )
116
+
117
+ def test_split_and_recombine_text_3(self):
118
+ text_src = os.path.join(
119
+ os.path.dirname(__file__), "../data/riding_hood.txt"
120
+ )
121
+ with open(text_src, "r") as f:
122
+ text = f.read()
123
+ self.assertEqual(
124
+ split_and_recombine_text(text),
125
+ [
126
+ "Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her.",
127
+ 'It suited the girl so extremely well that everybody called her Little Red Riding Hood. One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."',
128
+ "Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village. As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest.",
129
+ 'He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother." "Does she live far off?" said the wolf "Oh I say,"',
130
+ 'answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village." "Well," said the wolf, "and I\'ll go and see her too. I\'ll go this way and go you that, and we shall see who will be there first."',
131
+ "The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers.",
132
+ 'It was not long before the wolf arrived at the old woman\'s house. He knocked at the door: tap, tap. "Who\'s there?" "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."',
133
+ 'The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."',
134
+ "The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten.",
135
+ "He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap. \"Who's there?\"",
136
+ 'Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."',
137
+ 'The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up." Little Red Riding Hood pulled the bobbin, and the door opened.',
138
+ 'The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me." Little Red Riding Hood took off her clothes and got into bed.',
139
+ 'She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!" "All the better to hug you with, my dear." "Grandmother, what big legs you have!" "All the better to run with, my child." "Grandmother, what big ears you have!"',
140
+ '"All the better to hear with, my child." "Grandmother, what big eyes you have!" "All the better to see with, my child." "Grandmother, what big teeth you have got!" "All the better to eat you up with." And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.',
141
+ ],
142
+ )
143
+
144
+ unittest.main()
tortoise/utils/tokenizer.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import re
3
+
4
+ import inflect
5
+ import torch
6
+ from tokenizers import Tokenizer
7
+
8
+
9
+ # Regular expression matching whitespace:
10
+ from unidecode import unidecode
11
+
12
+ _whitespace_re = re.compile(r"\s+")
13
+
14
+
15
+ # List of (regular expression, replacement) pairs for abbreviations:
16
+ _abbreviations = [
17
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
18
+ for x in [
19
+ ("mrs", "misess"),
20
+ ("mr", "mister"),
21
+ ("dr", "doctor"),
22
+ ("st", "saint"),
23
+ ("co", "company"),
24
+ ("jr", "junior"),
25
+ ("maj", "major"),
26
+ ("gen", "general"),
27
+ ("drs", "doctors"),
28
+ ("rev", "reverend"),
29
+ ("lt", "lieutenant"),
30
+ ("hon", "honorable"),
31
+ ("sgt", "sergeant"),
32
+ ("capt", "captain"),
33
+ ("esq", "esquire"),
34
+ ("ltd", "limited"),
35
+ ("col", "colonel"),
36
+ ("ft", "fort"),
37
+ ]
38
+ ]
39
+
40
+
41
+ def expand_abbreviations(text):
42
+ for regex, replacement in _abbreviations:
43
+ text = re.sub(regex, replacement, text)
44
+ return text
45
+
46
+
47
+ _inflect = inflect.engine()
48
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
49
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
50
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
51
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
52
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
53
+ _number_re = re.compile(r"[0-9]+")
54
+
55
+
56
+ def _remove_commas(m):
57
+ return m.group(1).replace(",", "")
58
+
59
+
60
+ def _expand_decimal_point(m):
61
+ return m.group(1).replace(".", " point ")
62
+
63
+
64
+ def _expand_dollars(m):
65
+ match = m.group(1)
66
+ parts = match.split(".")
67
+ if len(parts) > 2:
68
+ return match + " dollars" # Unexpected format
69
+ dollars = int(parts[0]) if parts[0] else 0
70
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
71
+ if dollars and cents:
72
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
73
+ cent_unit = "cent" if cents == 1 else "cents"
74
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
75
+ elif dollars:
76
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
77
+ return "%s %s" % (dollars, dollar_unit)
78
+ elif cents:
79
+ cent_unit = "cent" if cents == 1 else "cents"
80
+ return "%s %s" % (cents, cent_unit)
81
+ else:
82
+ return "zero dollars"
83
+
84
+
85
+ def _expand_ordinal(m):
86
+ return _inflect.number_to_words(m.group(0))
87
+
88
+
89
+ def _expand_number(m):
90
+ num = int(m.group(0))
91
+ if num > 1000 and num < 3000:
92
+ if num == 2000:
93
+ return "two thousand"
94
+ elif num > 2000 and num < 2010:
95
+ return "two thousand " + _inflect.number_to_words(num % 100)
96
+ elif num % 100 == 0:
97
+ return _inflect.number_to_words(num // 100) + " hundred"
98
+ else:
99
+ return _inflect.number_to_words(
100
+ num, andword="", zero="oh", group=2
101
+ ).replace(", ", " ")
102
+ else:
103
+ return _inflect.number_to_words(num, andword="")
104
+
105
+
106
+ def normalize_numbers(text):
107
+ text = re.sub(_comma_number_re, _remove_commas, text)
108
+ text = re.sub(_pounds_re, r"\1 pounds", text)
109
+ text = re.sub(_dollars_re, _expand_dollars, text)
110
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
111
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
112
+ text = re.sub(_number_re, _expand_number, text)
113
+ return text
114
+
115
+
116
+ def expand_numbers(text):
117
+ return normalize_numbers(text)
118
+
119
+
120
+ def lowercase(text):
121
+ return text.lower()
122
+
123
+
124
+ def collapse_whitespace(text):
125
+ return re.sub(_whitespace_re, " ", text)
126
+
127
+
128
+ def convert_to_ascii(text):
129
+ return unidecode(text)
130
+
131
+
132
+ def basic_cleaners(text):
133
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
134
+ text = lowercase(text)
135
+ text = collapse_whitespace(text)
136
+ return text
137
+
138
+
139
+ def transliteration_cleaners(text):
140
+ """Pipeline for non-English text that transliterates to ASCII."""
141
+ text = convert_to_ascii(text)
142
+ text = lowercase(text)
143
+ text = collapse_whitespace(text)
144
+ return text
145
+
146
+
147
+ def english_cleaners(text):
148
+ """Pipeline for English text, including number and abbreviation expansion."""
149
+ text = convert_to_ascii(text)
150
+ text = lowercase(text)
151
+ text = expand_numbers(text)
152
+ text = expand_abbreviations(text)
153
+ text = collapse_whitespace(text)
154
+ text = text.replace('"', "")
155
+ return text
156
+
157
+
158
+ def lev_distance(s1, s2):
159
+ if len(s1) > len(s2):
160
+ s1, s2 = s2, s1
161
+
162
+ distances = range(len(s1) + 1)
163
+ for i2, c2 in enumerate(s2):
164
+ distances_ = [i2 + 1]
165
+ for i1, c1 in enumerate(s1):
166
+ if c1 == c2:
167
+ distances_.append(distances[i1])
168
+ else:
169
+ distances_.append(
170
+ 1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
171
+ )
172
+ distances = distances_
173
+ return distances[-1]
174
+
175
+
176
+ DEFAULT_VOCAB_FILE = os.path.join(
177
+ os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json"
178
+ )
179
+
180
+
181
+ class VoiceBpeTokenizer:
182
+ def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
183
+ if vocab_file is not None:
184
+ self.tokenizer = Tokenizer.from_file(vocab_file)
185
+
186
+ def preprocess_text(self, txt):
187
+ txt = english_cleaners(txt)
188
+ return txt
189
+
190
+ def encode(self, txt):
191
+ txt = self.preprocess_text(txt)
192
+ txt = txt.replace(" ", "[SPACE]")
193
+ return self.tokenizer.encode(txt).ids
194
+
195
+ def decode(self, seq):
196
+ if isinstance(seq, torch.Tensor):
197
+ seq = seq.cpu().numpy()
198
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
199
+ txt = txt.replace("[SPACE]", " ")
200
+ txt = txt.replace("[STOP]", "")
201
+ txt = txt.replace("[UNK]", "")
202
+ return txt
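
A quick usage sketch for the tokenizer above (the sample sentence and printed results are illustrative, worked out by hand from the cleaner rules; the default vocab path ../data/tokenizer.json must exist next to the package data):

from tortoise.utils.tokenizer import VoiceBpeTokenizer, english_cleaners

# The cleaner pipeline expands abbreviations and spells out numbers:
# "Dr. Smith paid $1,500." should come out as "doctor smith paid fifteen hundred dollars."
print(english_cleaners("Dr. Smith paid $1,500."))

tok = VoiceBpeTokenizer()            # loads ../data/tokenizer.json by default
ids = tok.encode("this is a test")   # spaces become [SPACE] tokens before BPE encoding
print(tok.decode(ids))               # round-trips back to "this is a test"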
tortoise/utils/typical_sampling.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from transformers import LogitsWarper
3
+
4
+
5
+ class TypicalLogitsWarper(LogitsWarper):
6
+ def __init__(
7
+ self,
8
+ mass: float = 0.9,
9
+ filter_value: float = -float("Inf"),
10
+ min_tokens_to_keep: int = 1,
11
+ ):
12
+ self.filter_value = filter_value
13
+ self.mass = mass
14
+ self.min_tokens_to_keep = min_tokens_to_keep
15
+
16
+ def __call__(
17
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
18
+ ) -> torch.FloatTensor:
19
+ # calculate entropy
20
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
21
+ p = torch.exp(normalized)
22
+ ent = -(normalized * p).nansum(-1, keepdim=True)
23
+
24
+ # shift and sort
25
+ shifted_scores = torch.abs((-normalized) - ent)
26
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
27
+ sorted_logits = scores.gather(-1, sorted_indices)
28
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
29
+
30
+ # Remove tokens with cumulative mass above the threshold
31
+ last_ind = (cumulative_probs < self.mass).sum(dim=1)
32
+ last_ind[last_ind < 0] = 0
33
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
34
+ 1, last_ind.view(-1, 1)
35
+ )
36
+ if self.min_tokens_to_keep > 1:
37
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
38
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
39
+ indices_to_remove = sorted_indices_to_remove.scatter(
40
+ 1, sorted_indices, sorted_indices_to_remove
41
+ )
42
+
43
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
44
+ return scores
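
The warper above implements locally typical sampling: tokens whose surprisal is far from the distribution's entropy get masked out. A minimal, self-contained sketch of calling it directly on dummy logits (the tensor values are random, not from any model):

import torch
from tortoise.utils.typical_sampling import TypicalLogitsWarper

warper = TypicalLogitsWarper(mass=0.9)
scores = torch.randn(1, 50)                       # fake logits over a 50-token vocab
input_ids = torch.zeros(1, 4, dtype=torch.long)   # unused by this warper, but required by the interface
filtered = warper(input_ids, scores)              # atypical tokens are set to -inf
next_token = torch.multinomial(filtered.softmax(dim=-1), 1)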
tortoise/utils/wav2vec_alignment.py ADDED
@@ -0,0 +1,173 @@
1
+ import re
2
+
3
+ import torch
4
+ import torchaudio
5
+ from transformers import (
6
+ Wav2Vec2ForCTC,
7
+ Wav2Vec2FeatureExtractor,
8
+ Wav2Vec2CTCTokenizer,
9
+ Wav2Vec2Processor,
10
+ )
11
+
12
+ from tortoise.utils.audio import load_audio
13
+
14
+
15
+ def max_alignment(s1, s2, skip_character="~", record=None):
16
+ """
17
+ A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
18
+ used to replace that character.
19
+
20
+ Finally got to use my DP skills!
21
+ """
22
+ if record is None:
23
+ record = {}
24
+ assert (
25
+ skip_character not in s1
26
+ ), f"Found the skip character {skip_character} in the provided string, {s1}"
27
+ if len(s1) == 0:
28
+ return ""
29
+ if len(s2) == 0:
30
+ return skip_character * len(s1)
31
+ if s1 == s2:
32
+ return s1
33
+ if s1[0] == s2[0]:
34
+ return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
35
+
36
+ take_s1_key = (len(s1), len(s2) - 1)
37
+ if take_s1_key in record:
38
+ take_s1, take_s1_score = record[take_s1_key]
39
+ else:
40
+ take_s1 = max_alignment(s1, s2[1:], skip_character, record)
41
+ take_s1_score = len(take_s1.replace(skip_character, ""))
42
+ record[take_s1_key] = (take_s1, take_s1_score)
43
+
44
+ take_s2_key = (len(s1) - 1, len(s2))
45
+ if take_s2_key in record:
46
+ take_s2, take_s2_score = record[take_s2_key]
47
+ else:
48
+ take_s2 = max_alignment(s1[1:], s2, skip_character, record)
49
+ take_s2_score = len(take_s2.replace(skip_character, ""))
50
+ record[take_s2_key] = (take_s2, take_s2_score)
51
+
52
+ return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
53
+
54
+
55
+ class Wav2VecAlignment:
56
+ """
57
+ Uses wav2vec2 to perform audio<->text alignment.
58
+ """
59
+
60
+ def __init__(self, device="cuda"):
61
+ self.model = Wav2Vec2ForCTC.from_pretrained(
62
+ "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
63
+ ).cpu()
64
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
65
+ f"facebook/wav2vec2-large-960h"
66
+ )
67
+ self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
68
+ "jbetker/tacotron-symbols"
69
+ )
70
+ self.device = device
71
+
72
+ def align(self, audio, expected_text, audio_sample_rate=24000):
73
+ orig_len = audio.shape[-1]
74
+
75
+ with torch.no_grad():
76
+ self.model = self.model.to(self.device)
77
+ audio = audio.to(self.device)
78
+ audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
79
+ clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
80
+ logits = self.model(clip_norm).logits
81
+ self.model = self.model.cpu()
82
+
83
+ logits = logits[0]
84
+ pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
85
+
86
+ fixed_expectation = max_alignment(expected_text.lower(), pred_string)
87
+ w2v_compression = orig_len // logits.shape[0]
88
+ expected_tokens = self.tokenizer.encode(fixed_expectation)
89
+ expected_chars = list(fixed_expectation)
90
+ if len(expected_tokens) == 1:
91
+ return [0] # The alignment is simple; there is only one token.
92
+ expected_tokens.pop(0) # The first token is a given.
93
+ expected_chars.pop(0)
94
+
95
+ alignments = [0]
96
+
97
+ def pop_till_you_win():
98
+ if len(expected_tokens) == 0:
99
+ return None
100
+ popped = expected_tokens.pop(0)
101
+ popped_char = expected_chars.pop(0)
102
+ while popped_char == "~":
103
+ alignments.append(-1)
104
+ if len(expected_tokens) == 0:
105
+ return None
106
+ popped = expected_tokens.pop(0)
107
+ popped_char = expected_chars.pop(0)
108
+ return popped
109
+
110
+ next_expected_token = pop_till_you_win()
111
+ for i, logit in enumerate(logits):
112
+ top = logit.argmax()
113
+ if next_expected_token == top:
114
+ alignments.append(i * w2v_compression)
115
+ if len(expected_tokens) > 0:
116
+ next_expected_token = pop_till_you_win()
117
+ else:
118
+ break
119
+
120
+ pop_till_you_win()
121
+ if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
122
+ torch.save([audio, expected_text], "alignment_debug.pth")
123
+ assert False, (
124
+ "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to"
125
+ "your current working directory. Please report this along with the file so it can get fixed."
126
+ )
127
+
128
+ # Now fix up alignments. Anything with -1 should be interpolated.
129
+ alignments.append(
130
+ orig_len
131
+ ) # This'll get removed but makes the algorithm below more readable.
132
+ for i in range(len(alignments)):
133
+ if alignments[i] == -1:
134
+ for j in range(i + 1, len(alignments)):
135
+ if alignments[j] != -1:
136
+ next_found_token = j
137
+ break
138
+ for j in range(i, next_found_token):
139
+ gap = alignments[next_found_token] - alignments[i - 1]
140
+ alignments[j] = (j - i + 1) * gap // (
141
+ next_found_token - i + 1
142
+ ) + alignments[i - 1]
143
+
144
+ return alignments[:-1]
145
+
146
+ def redact(self, audio, expected_text, audio_sample_rate=24000):
147
+ if "[" not in expected_text:
148
+ return audio
149
+ splitted = expected_text.split("[")
150
+ fully_split = [splitted[0]]
151
+ for spl in splitted[1:]:
152
+ assert (
153
+ "]" in spl
154
+ ), 'Every "[" character must be paired with a "]" with no nesting.'
155
+ fully_split.extend(spl.split("]"))
156
+
157
+ # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
158
+ non_redacted_intervals = []
159
+ last_point = 0
160
+ for i in range(len(fully_split)):
161
+ if i % 2 == 0:
162
+ end_interval = max(0, last_point + len(fully_split[i]) - 1)
163
+ non_redacted_intervals.append((last_point, end_interval))
164
+ last_point += len(fully_split[i])
165
+
166
+ bare_text = "".join(fully_split)
167
+ alignments = self.align(audio, bare_text, audio_sample_rate)
168
+
169
+ output_audio = []
170
+ for nri in non_redacted_intervals:
171
+ start, stop = nri
172
+ output_audio.append(audio[:, alignments[start] : alignments[stop]])
173
+ return torch.cat(output_audio, dim=-1)
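
A hedged usage sketch for the aligner above (the file name is a placeholder; load_audio is the loader from tortoise.utils.audio, which returns a (1, N) float tensor resampled to the requested rate):

from tortoise.utils.audio import load_audio
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment

aligner = Wav2VecAlignment(device="cuda")
clip = load_audio("generated_clip.wav", 24000)   # placeholder path
# Text wrapped in [...] was spoken by the model but should be cut from the output;
# align() maps characters to sample offsets and redact() keeps only the unbracketed spans.
trimmed = aligner.redact(clip, "Now [take a breath and] read the next line.", 24000)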
tortoise/voices/angie/1.wav ADDED
Binary file (625 kB). View file
 
tortoise/voices/angie/2.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d17f37eb96823dcafd7ba1030b299120a61a93da995ddd3eba44869fc6232c0
3
+ size 1101490
tortoise/voices/angie/3.wav ADDED
Binary file (827 kB). View file
 
tortoise/voices/applejack/1.wav ADDED
Binary file (328 kB). View file
 
tortoise/voices/applejack/2.wav ADDED
Binary file (331 kB). View file
 
tortoise/voices/applejack/3.wav ADDED
Binary file (327 kB). View file
 
tortoise/voices/cond_latent_example/pat.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b9178bae37af1f308c8e6847ff820cad65748e2c63fcf08e83cab8e3434e2b5
3
+ size 13223
tortoise/voices/daniel/1.wav ADDED
Binary file (618 kB). View file
 
tortoise/voices/daniel/2.wav ADDED
Binary file (329 kB). View file
 
tortoise/voices/daniel/3.wav ADDED
Binary file (371 kB). View file
 
tortoise/voices/daniel/4.wav ADDED
Binary file (275 kB). View file
 
tortoise/voices/deniro/1.wav ADDED
Binary file (407 kB). View file
 
tortoise/voices/deniro/2.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0adee547ffefa9bb8602b6e28535538966f4de6b45e344284a2b2580bb1288b
3
+ size 1219858
tortoise/voices/deniro/3.wav ADDED
Binary file (793 kB). View file
 
tortoise/voices/deniro/4.wav ADDED
Binary file (942 kB). View file
 
tortoise/voices/deutsch/de_speaker_2.mp3 ADDED
Binary file (24.2 kB). View file
 
tortoise/voices/deutsch/de_speaker_3.mp3 ADDED
Binary file (30.7 kB). View file