Pranjal12345 committed on
Commit
97e4faf
1 Parent(s): acb3965

Upload 128 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. app.py +124 -0
  2. requirements.txt +25 -0
  3. tortoise/__init__.py +0 -0
  4. tortoise/api.py +418 -0
  5. tortoise/data/.gitattributes +1 -0
  6. tortoise/data/mel_norms.pth +3 -0
  7. tortoise/data/tokenizer.json +1 -0
  8. tortoise/do_tts.py +52 -0
  9. tortoise/eval.py +27 -0
  10. tortoise/get_conditioning_latents.py +30 -0
  11. tortoise/is_this_from_tortoise.py +14 -0
  12. tortoise/models/__init__.py +0 -0
  13. tortoise/models/arch_util.py +373 -0
  14. tortoise/models/autoregressive.py +582 -0
  15. tortoise/models/classifier.py +148 -0
  16. tortoise/models/clvp.py +155 -0
  17. tortoise/models/cvvp.py +142 -0
  18. tortoise/models/diffusion_decoder.py +336 -0
  19. tortoise/models/hifigan_decoder.py +302 -0
  20. tortoise/models/random_latent_generator.py +55 -0
  21. tortoise/models/stream_generator.py +1057 -0
  22. tortoise/models/transformer.py +219 -0
  23. tortoise/models/vocoder.py +327 -0
  24. tortoise/models/xtransformers.py +1248 -0
  25. tortoise/read.py +101 -0
  26. tortoise/utils/__init__.py +0 -0
  27. tortoise/utils/audio.py +189 -0
  28. tortoise/utils/diffusion.py +1250 -0
  29. tortoise/utils/stft.py +193 -0
  30. tortoise/utils/text.py +132 -0
  31. tortoise/utils/tokenizer.py +194 -0
  32. tortoise/utils/typical_sampling.py +33 -0
  33. tortoise/utils/wav2vec_alignment.py +150 -0
  34. tortoise/voices/angie/1.wav +0 -0
  35. tortoise/voices/angie/2.wav +0 -0
  36. tortoise/voices/angie/3.wav +0 -0
  37. tortoise/voices/applejack/1.wav +0 -0
  38. tortoise/voices/applejack/2.wav +0 -0
  39. tortoise/voices/applejack/3.wav +0 -0
  40. tortoise/voices/atkins/1.wav +0 -0
  41. tortoise/voices/atkins/2.wav +0 -0
  42. tortoise/voices/daniel/1.wav +0 -0
  43. tortoise/voices/daniel/2.wav +0 -0
  44. tortoise/voices/daniel/3.wav +0 -0
  45. tortoise/voices/daniel/4.wav +0 -0
  46. tortoise/voices/daws/1.mp3 +0 -0
  47. tortoise/voices/daws/2.mp3 +0 -0
  48. tortoise/voices/daws/3.mp3 +0 -0
  49. tortoise/voices/deniro/1.wav +0 -0
  50. tortoise/voices/deniro/2.wav +0 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ import torch
+ import gradio as gr
+ import torchaudio
+ import time
+ from datetime import datetime
+ from tortoise.api import TextToSpeech
+ from tortoise.utils.text import split_and_recombine_text
+ from tortoise.utils.audio import load_audio, load_voice, load_voices
+
+ VOICE_OPTIONS = [
+     "angie",
+     "deniro",
+     "freeman",
+     "random",  # special option for random voice
+ ]
+
+
+ def inference(
+     text,
+     script,
+     voice,
+     voice_b,
+     split_by_newline,
+ ):
+     if text is None or text.strip() == "":
+         with open(script.name) as f:
+             text = f.read()
+         if text.strip() == "":
+             raise gr.Error("Please provide either text or a script file with content.")
+
+     if split_by_newline == "Yes":
+         texts = list(filter(lambda x: x.strip() != "", text.split("\n")))
+     else:
+         texts = split_and_recombine_text(text)
+
+     voices = [voice]
+     if voice_b != "disabled":
+         voices.append(voice_b)
+
+     if len(voices) == 1:
+         voice_samples, conditioning_latents = load_voice(voice)
+     else:
+         voice_samples, conditioning_latents = load_voices(voices)
+
+     start_time = time.time()
+
+     # all_parts = []
+     for j, text in enumerate(texts):
+         for audio_frame in tts.tts_with_preset(
+             text,
+             voice_samples=voice_samples,
+             conditioning_latents=conditioning_latents,
+             preset="ultra_fast",
+             k=1
+         ):
+             # print("Time taken: ", time.time() - start_time)
+             # all_parts.append(audio_frame)
+             yield (24000, audio_frame.cpu().detach().numpy())
+
+     # wav = torch.cat(all_parts, dim=0).unsqueeze(0)
+     # print(wav.shape)
+     # torchaudio.save("output.wav", wav.cpu(), 24000)
+     # yield (None, gr.make_waveform(audio="output.wav",))
+
+
+ def main():
+     title = "Tortoise TTS 🐢"
+     description = """
+     A text-to-speech system used by many organizations working on speech synthesis.
+     <br/>
+     A model with strong multi-voice capabilities and highly realistic prosody and intonation.
+     <br/>
+     For faster inference, use the 'ultra_fast' preset, and duplicate this Space if you don't want to wait in a queue.
+     <br/>
+     """
+     text = gr.Textbox(
+         lines=4,
+         label="Text (Provide either text, or upload a newline separated text file below):",
+     )
+     script = gr.File(label="Upload a text file")
+
+     voice = gr.Dropdown(
+         VOICE_OPTIONS, value="jane_eyre", label="Select voice:", type="value"
+     )
+     voice_b = gr.Dropdown(
+         VOICE_OPTIONS,
+         value="disabled",
+         label="(Optional) Select second voice:",
+         type="value",
+     )
+     split_by_newline = gr.Radio(
+         ["Yes", "No"],
+         label="Split by newline (If [No], it will automatically try to find relevant splits):",
+         type="value",
+         value="No",
+     )
+
+     output_audio = gr.Audio(label="streaming audio:", streaming=True, autoplay=True)
+     # download_audio = gr.Audio(label="download audio:")
+     interface = gr.Interface(
+         fn=inference,
+         inputs=[
+             text,
+             script,
+             voice,
+             voice_b,
+             split_by_newline,
+         ],
+         title=title,
+         description=description,
+         outputs=[output_audio],
+     )
+     interface.queue().launch()
+
+
+ if __name__ == "__main__":
+     tts = TextToSpeech(kv_cache=True, use_deepspeed=True, half=True)
+
+     with open("Tortoise_TTS_Runs_Scripts.log", "a") as f:
+         f.write(
+             f"\n\n-------------------------Tortoise TTS Scripts Logs, {datetime.now()}-------------------------\n"
+         )
+
+     main()
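For illustration only (not part of this upload), a minimal sketch of how the streaming generator used by app.py can be consumed outside Gradio, mirroring the commented-out concatenation logic above; it assumes the tortoise package from this commit is importable and the model weights download successfully:

import torch
import torchaudio
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

tts = TextToSpeech(kv_cache=True, half=True)  # use_deepspeed left off in case DeepSpeed is not installed
voice_samples, conditioning_latents = load_voice("angie")

parts = []
for audio_frame in tts.tts_with_preset(
    "Hello from Tortoise.",
    voice_samples=voice_samples,
    conditioning_latents=conditioning_latents,
    preset="ultra_fast",
):
    parts.append(audio_frame.cpu())  # each yielded frame is a 1-D chunk of 24 kHz audio

wav = torch.cat(parts, dim=0).unsqueeze(0)
torchaudio.save("output.wav", wav, 24000)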
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ tqdm
+ rotary_embedding_torch
+ transformers==4.31.0
+ tokenizers
+ inflect
+ progressbar
+ einops==0.4.1
+ unidecode
+ scipy
+ librosa==0.9.1
+ ffmpeg
+ numpy
+ numba
+ torch==2.0.0
+ torchaudio==2.0.0
+ threadpoolctl
+ llvmlite
+ appdirs
+ nbconvert==5.3.1
+ tornado==4.2
+ pydantic==1.9.1
+ deepspeed==0.8.3
+ py-cpuinfo
+ hjson
+ psutil
tortoise/__init__.py ADDED
File without changes
tortoise/api.py ADDED
@@ -0,0 +1,418 @@
+ import os
+ import random
+ import uuid
+ from time import time
+ from urllib import request
+
+ import torch
+ import torch.nn.functional as F
+ import progressbar
+ import torchaudio
+ import numpy as np
+ from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead
+ from tortoise.models.diffusion_decoder import DiffusionTts
+ from tortoise.models.autoregressive import UnifiedVoice
+ from tqdm import tqdm
+ from tortoise.models.arch_util import TorchMelSpectrogram
+ from tortoise.models.clvp import CLVP
+ from tortoise.models.cvvp import CVVP
+ from tortoise.models.hifigan_decoder import HifiganGenerator
+ from tortoise.models.random_latent_generator import RandomLatentConverter
+ from tortoise.models.vocoder import UnivNetGenerator
+ from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
+ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
+ from tortoise.utils.tokenizer import VoiceBpeTokenizer
+ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
+ from contextlib import contextmanager
+ from tortoise.models.stream_generator import init_stream_support
+ from huggingface_hub import hf_hub_download
+
+ pbar = None
+ init_stream_support()
+ DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models')
+ MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR)
+
+ MODELS = {
+     'autoregressive.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/autoregressive.pth',
+     'classifier.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/classifier.pth',
+     'rlg_auto.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/rlg_auto.pth',
+     'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
+ }
+
+ def get_model_path(model_name, models_dir=MODELS_DIR):
+     """
+     Get the path to the given model, downloading it if it doesn't exist.
+     """
+     if model_name not in MODELS:
+         raise ValueError(f'Model {model_name} not found in available models.')
+     model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=MODELS_DIR)
+     return model_path
+
+
+ def pad_or_truncate(t, length):
+     """
+     Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
+     """
+     if t.shape[-1] == length:
+         return t
+     elif t.shape[-1] < length:
+         return F.pad(t, (0, length-t.shape[-1]))
+     else:
+         return t[..., :length]
+
+
+ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
+     """
+     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
+     """
+     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
+                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+                            conditioning_free=cond_free, conditioning_free_k=cond_free_k)
+
+
+ def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'):
+     """
+     Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
+     """
+     gap = clip.shape[-1] - cond_length
+     if gap < 0:
+         clip = F.pad(clip, pad=(0, abs(gap)))
+     elif gap > 0:
+         rand_start = random.randint(0, gap)
+         clip = clip[:, rand_start:rand_start + cond_length]
+     mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
+     return mel_clip.unsqueeze(0).to(device)
+
+
+ def fix_autoregressive_output(codes, stop_token, complain=True):
+     """
+     This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
+     trained on and what the autoregressive code generator creates (which has no padding or end).
+     This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
+     a different DVAE. This can be inferred by feeding an audio clip padded with lots of zeros on the end through the DVAE
+     and copying out the last few codes.
+
+     Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
+     """
+     # Strip off the autoregressive stop token and add padding.
+     stop_token_indices = (codes == stop_token).nonzero()
+     if len(stop_token_indices) == 0:
+         if complain:
+             print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
+                   "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
+                   "try breaking up your input text.")
+         return codes
+     else:
+         codes[stop_token_indices] = 83
+     stm = stop_token_indices.min().item()
+     codes[stm:] = 83
+     if stm - 3 < codes.shape[0]:
+         codes[-3] = 45
+         codes[-2] = 45
+         codes[-1] = 248
+
+     return codes
+
+
+ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True):
+     """
+     Uses the specified diffusion model to convert discrete codes into a spectrogram.
+     """
+     with torch.no_grad():
+         output_seq_len = latents.shape[1] * 4 * 24000 // 22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+         output_shape = (latents.shape[0], 100, output_seq_len)
+         precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
+
+         noise = torch.randn(output_shape, device=latents.device) * temperature
+         mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
+                                      model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
+                                      progress=verbose)
+         return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+
+
+ def classify_audio_clip(clip):
+     """
+     Returns whether or not Tortoise's classifier thinks the given clip came from Tortoise.
+     :param clip: torch tensor containing audio waveform data (get it from load_audio)
+     :return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
+     """
+     classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
+                                                     resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
+                                                     dropout=0, kernel_size=5, distribute_zero_label=False)
+     classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
+     clip = clip.cpu().unsqueeze(0)
+     results = F.softmax(classifier(clip), dim=-1)
+     return results[0][0]
+
+
+ def pick_best_batch_size_for_gpu():
+     """
+     Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
+     you a good shot.
+     """
+     if torch.cuda.is_available():
+         _, available = torch.cuda.mem_get_info()
+         availableGb = available / (1024 ** 3)
+         if availableGb > 14:
+             return 16
+         elif availableGb > 10:
+             return 8
+         elif availableGb > 7:
+             return 4
+     if torch.backends.mps.is_available():
+         import psutil
+         available = psutil.virtual_memory().total
+         availableGb = available / (1024 ** 3)
+         if availableGb > 14:
+             return 16
+         elif availableGb > 10:
+             return 8
+         elif availableGb > 7:
+             return 4
+     return 1
+
+ class TextToSpeech:
+     """
+     Main entry point into Tortoise.
+     """
+
+     def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
+                  enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None,
+                  tokenizer_vocab_file=None, tokenizer_basic=False):
+
+         """
+         Constructor
+         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
+                                           GPU OOM errors. Larger numbers generate slightly faster.
+         :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
+                            models, otherwise use the defaults.
+         :param enable_redaction: When true, text enclosed in brackets is automatically redacted from the spoken output
+                                  (but is still rendered by the model). This can be used for prompt engineering.
+                                  Default is true.
+         :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
+         """
+         self.models_dir = models_dir
+         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+         self.enable_redaction = enable_redaction
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         if torch.backends.mps.is_available():
+             self.device = torch.device('mps')
+         if self.enable_redaction:
+             self.aligner = Wav2VecAlignment()
+
+         self.tokenizer = VoiceBpeTokenizer(
+             vocab_file=tokenizer_vocab_file,
+             use_basic_cleaners=tokenizer_basic,
+         )
+         self.half = half
+         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
+             # Assume this is a traced directory.
+             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
+         else:
+             self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
+                                                model_dim=1024,
+                                                heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
+                                                train_solo_embeddings=False).to(self.device).eval()
+             self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
+             self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
+
+         self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
+                                              resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
+                                              upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
+                                              cond_channels=1024).to(self.device).eval()
+         hifi_model = torch.load(get_model_path('hifidecoder.pth'), map_location=torch.device(self.device))
+         self.hifi_decoder.load_state_dict(hifi_model, strict=False)
+         # Random latent generators (RLGs) are loaded lazily.
+         self.rlg_auto = None
+
+     def get_conditioning_latents(self, voice_samples, return_mels=False):
+         """
+         Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
+         These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
+         properties.
+         :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
+         """
+         with torch.no_grad():
+             voice_samples = [v.to(self.device) for v in voice_samples]
+
+             auto_conds = []
+             if not isinstance(voice_samples, list):
+                 voice_samples = [voice_samples]
+             for vs in voice_samples:
+                 auto_conds.append(format_conditioning(vs, device=self.device))
+             auto_conds = torch.stack(auto_conds, dim=1)
+             auto_latent = self.autoregressive.get_conditioning(auto_conds)
+
+             if return_mels:
+                 return auto_latent
+             else:
+                 return auto_latent
+
+     def get_random_conditioning_latents(self):
+         # Lazy-load the RLG models.
+         if self.rlg_auto is None:
+             self.rlg_auto = RandomLatentConverter(1024).eval()
+             self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu')))
+         with torch.no_grad():
+             return self.rlg_auto(torch.tensor([0.0]))
+
+     def tts_with_preset(self, text, preset='fast', **kwargs):
+         """
+         Calls TTS with one of a set of preset generation parameters. Options:
+             'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
+             'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
+             'standard': Very good quality. This is generally about as good as you are going to get.
+             'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+         """
+         # Use generally found best tuning knobs for generation.
+         settings = {'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
+                     'top_p': .8,
+                     'cond_free_k': 2.0, 'diffusion_temperature': 1.0}
+         # Presets are defined here.
+         presets = {
+             'ultra_fast': {'num_autoregressive_samples': 1, 'diffusion_iterations': 10},
+             'fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 50},
+             'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
+             'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400},
+         }
+         settings.update(presets[preset])
+         settings.update(kwargs)  # allow overriding of preset settings with kwargs
+         for audio_frame in self.tts(text, **settings):
+             yield audio_frame
+
+     def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
+         """Handle chunk formatting in streaming mode"""
+         wav_chunk = wav_gen[:-overlap_len]
+         if wav_gen_prev is not None:
+             wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
+         if wav_overlap is not None:
+             crossfade_wav = wav_chunk[:overlap_len]
+             crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
+             wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
+             wav_chunk[:overlap_len] += crossfade_wav
+         wav_overlap = wav_gen[-overlap_len:]
+         wav_gen_prev = wav_gen
+         return wav_chunk, wav_gen_prev, wav_overlap
+
+
+     def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
+             return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40,
+             # autoregressive generation parameters follow
+             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
+             # CVVP parameters follow
+             cvvp_amount=.0,
+             # diffusion generation parameters follow
+             diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
+             **hf_generate_kwargs):
+         """
+         Produces an audio clip of the given text being spoken with the given reference voice.
+         :param text: Text to be spoken.
+         :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data.
+         :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which
+                                      can be provided in lieu of voice_samples. This is ignored unless voice_samples=None.
+                                      Conditioning latents can be retrieved via get_conditioning_latents().
+         :param k: The number of returned clips. The most likely (as determined by Tortoise's CLVP model) clips are returned.
+         :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true.
+         ~~AUTOREGRESSIVE KNOBS~~
+         :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
+                                            As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
+         :param temperature: The softmax temperature of the autoregressive model.
+         :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings cause the model to produce more terse outputs.
+         :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence
+                                    of long silences or "uhhhhhhs", etc.
+         :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
+         :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
+         ~~DIFFUSION KNOBS~~
+         :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
+                                      the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
+                                      however.
+         :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
+                           each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output
+                           of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and
+                           dramatically improves realism.
+         :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
+                             As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
+                             Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k
+         :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
+                                       are the "mean" prediction of the diffusion network and will sound bland and smeared.
+         ~~OTHER STUFF~~
+         :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer.
+                                    Extra keyword args fed to this function get forwarded directly to that API. Documentation
+                                    here: https://huggingface.co/docs/transformers/internal/generation_utils
+         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
+                  Sample rate is 24kHz.
+         """
+         deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
+
+         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
+         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
+         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
+         if voice_samples is not None:
+             auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False)
+         else:
+             auto_conditioning = self.get_random_conditioning_latents()
+         auto_conditioning = auto_conditioning.to(self.device)
+
+         with torch.no_grad():
+             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
+             if verbose:
+                 print("Generating autoregressive samples..")
+             with torch.autocast(
+                     device_type="cuda", dtype=torch.float16, enabled=self.half
+             ):
+                 fake_inputs = self.autoregressive.compute_embeddings(
+                     auto_conditioning,
+                     text_tokens,
+                 )
+                 gpt_generator = self.autoregressive.get_generator(
+                     fake_inputs=fake_inputs,
+                     top_k=50,
+                     top_p=top_p,
+                     temperature=temperature,
+                     do_sample=True,
+                     num_beams=1,
+                     num_return_sequences=1,
+                     length_penalty=float(length_penalty),
+                     repetition_penalty=float(repetition_penalty),
+                     output_attentions=False,
+                     output_hidden_states=True,
+                     **hf_generate_kwargs,
+                 )
+             all_latents = []
+             codes_ = []
+             wav_gen_prev = None
+             wav_overlap = None
+             is_end = False
+             first_buffer = 40
+             while not is_end:
+                 try:
+                     with torch.autocast(
+                         device_type="cuda", dtype=torch.float16, enabled=self.half
+                     ):
+                         codes, latent = next(gpt_generator)
+                         all_latents += [latent]
+                         codes_ += [codes]
+                 except StopIteration:
+                     is_end = True
+
+                 if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
+                     first_buffer = 0
+                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                     wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning)
+                     wav_gen = wav_gen.squeeze()
+                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                     )
+                     codes_ = []
+                     yield wav_chunk
+
+     def deterministic_state(self, seed=None):
+         """
+         Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
+         reproduced.
+         """
+         seed = int(time()) if seed is None else seed
+         torch.manual_seed(seed)
+         random.seed(seed)
+         # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
+         # torch.use_deterministic_algorithms(True)
+
+         return seed
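As a sketch of the preset mechanism documented in tts_with_preset above (not part of the commit; it assumes the checkpoints download from the Manmay/tortoise-tts repository), keyword arguments override the preset defaults, so a caller can pin a preset and still adjust individual knobs:

from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

tts = TextToSpeech(kv_cache=True)
voice_samples, _ = load_voice("deniro")

# 'fast' sets num_autoregressive_samples=32 and diffusion_iterations=50;
# the explicit temperature below overrides the preset default of 0.8.
for chunk in tts.tts_with_preset(
    "Presets trade quality for speed.",
    preset="fast",
    voice_samples=voice_samples,
    temperature=0.6,
):
    print(chunk.shape)  # each chunk is a 24 kHz waveform tensor streamed as GPT codes accumulate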
tortoise/data/.gitattributes ADDED
@@ -0,0 +1 @@
+ *.pth filter=lfs diff=lfs merge=lfs -text
tortoise/data/mel_norms.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f69422a8a8f344c4fca2f0c6b8d41d2151d6615b7321e48e6bb15ae949b119c
+ size 1067
tortoise/data/tokenizer.json ADDED
@@ -0,0 +1 @@
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
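The committed tokenizer.json is a standard Hugging Face tokenizers BPE file (VoiceBpeTokenizer in tortoise/utils/tokenizer.py wraps it and applies its own text cleaning first). As a sketch outside this commit, it can also be loaded directly with the tokenizers library:

from tokenizers import Tokenizer

tok = Tokenizer.from_file("tortoise/data/tokenizer.json")
enc = tok.encode("hello world")        # lowercase text maps onto the small character/BPE vocab above
print(enc.ids, tok.decode(enc.ids))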
tortoise/do_tts.py ADDED
@@ -0,0 +1,52 @@
+ import argparse
+ import os
+
+ import torch
+ import torchaudio
+
+ from api import TextToSpeech, MODELS_DIR
+ from utils.audio import load_voices
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--text', type=str, help='Text to speak.', default="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.")
+     parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
+                                                   'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
+     parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='ultra_fast')
+     parser.add_argument('--use_deepspeed', type=str, help='Use DeepSpeed for inference, for a speed gain of roughly 2x.', default=True)
+     parser.add_argument('--kv_cache', type=bool, help='If you disable this, expect to wait a very long time for output.', default=True)
+     parser.add_argument('--half', type=bool, help='If True, use float16 (half-precision) inference; it is faster and uses less VRAM and RAM.', default=True)
+     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
+     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this '
+                                                       'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
+     parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
+     parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
+     parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
+     parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output. '
+                                                           'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
+     args = parser.parse_args()
+     if torch.backends.mps.is_available():
+         args.use_deepspeed = False
+     os.makedirs(args.output_path, exist_ok=True)
+     tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
+
+     selected_voices = args.voice.split(',')
+     for k, selected_voice in enumerate(selected_voices):
+         if '&' in selected_voice:
+             voice_sel = selected_voice.split('&')
+         else:
+             voice_sel = [selected_voice]
+         voice_samples, conditioning_latents = load_voices(voice_sel)
+
+         gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
+                                              preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
+         if isinstance(gen, list):
+             for j, g in enumerate(gen):
+                 torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
+         else:
+             torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
+
+         if args.produce_debug_state:
+             os.makedirs('debug_states', exist_ok=True)
+             torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
+
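For the voice-selection convention described in the --voice help text (a sketch, not part of the commit): '&' pools the reference clips of several voices into one conditioning set, while ',' runs separate generations per voice:

from tortoise.utils.audio import load_voices

# Equivalent to passing --voice "angie&deniro" to do_tts.py: the clips of both
# voices are combined into a single list of conditioning samples.
voice_samples, conditioning_latents = load_voices(["angie", "deniro"])
print(len(voice_samples), "reference clips pooled from two voices")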
tortoise/eval.py ADDED
@@ -0,0 +1,27 @@
+ import argparse
+ import os
+
+ import torchaudio
+
+ from api import TextToSpeech
+ from tortoise.utils.audio import load_audio
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--eval_path', type=str, help='Path to TSV test file', default="D:\\tmp\\tortoise-tts-eval\\test.tsv")
+     parser.add_argument('--output_path', type=str, help='Where to put results', default="D:\\tmp\\tortoise-tts-eval\\baseline")
+     parser.add_argument('--preset', type=str, help='Rendering preset.', default="standard")
+     args = parser.parse_args()
+     os.makedirs(args.output_path, exist_ok=True)
+
+     tts = TextToSpeech()
+
+     with open(args.eval_path, 'r', encoding='utf-8') as f:
+         lines = f.readlines()
+
+     for line in lines:
+         text, real = line.strip().split('\t')
+         conds = [load_audio(real, 22050)]
+         gen = tts.tts_with_preset(text, voice_samples=conds, conditioning_latents=None, preset=args.preset)
+         torchaudio.save(os.path.join(args.output_path, os.path.basename(real)), gen.squeeze(0).cpu(), 24000)
+
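eval.py expects every line of the TSV to be <text><TAB><path to reference clip>; a sketch (not part of the commit, with hypothetical file paths) of producing such a file:

rows = [
    ("The quick brown fox jumps over the lazy dog.", "refs/sample1.wav"),
    ("Tortoise clones the reference speaker.", "refs/sample2.wav"),
]
with open("test.tsv", "w", encoding="utf-8") as f:
    for text, wav_path in rows:
        f.write(f"{text}\t{wav_path}\n")  # eval.py splits each line on a single tab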
tortoise/get_conditioning_latents.py ADDED
@@ -0,0 +1,30 @@
+ import argparse
+ import os
+ import torch
+
+ from api import TextToSpeech
+ from tortoise.utils.audio import load_audio, get_voices
+
+ """
+ Dumps the conditioning latents for the specified voice to disk. These are expressive latents which can be used for
+ other ML models, or can be augmented manually and fed back into Tortoise to affect vocal qualities.
+ """
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--voice', type=str, help='Selects the voice to convert to conditioning latents', default='pat2')
+     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/conditioning_latents')
+     args = parser.parse_args()
+     os.makedirs(args.output_path, exist_ok=True)
+
+     tts = TextToSpeech()
+     voices = get_voices()
+     selected_voices = args.voice.split(',')
+     for voice in selected_voices:
+         cond_paths = voices[voice]
+         conds = []
+         for cond_path in cond_paths:
+             c = load_audio(cond_path, 22050)
+             conds.append(c)
+         conditioning_latents = tts.get_conditioning_latents(conds)
+         torch.save(conditioning_latents, os.path.join(args.output_path, f'{voice}.pth'))
+
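The script stores whatever TextToSpeech.get_conditioning_latents() returns for each voice; a sketch (not part of the commit) of loading a dump back for inspection. Note that in this upload's api.py the saved object is the autoregressive conditioning latent only, and tts() recomputes latents from voice_samples rather than consuming this file:

import torch

# Path built from the script's own defaults; adjust to wherever the dump was written.
latents = torch.load("../results/conditioning_latents/pat2.pth", map_location="cpu")
print(type(latents), getattr(latents, "shape", None))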
tortoise/is_this_from_tortoise.py ADDED
@@ -0,0 +1,14 @@
+ import argparse
+
+ from api import classify_audio_clip
+ from tortoise.utils.audio import load_audio
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="../examples/favorite_riding_hood.mp3")
+     args = parser.parse_args()
+
+     clip = load_audio(args.clip, 24000)
+     clip = clip[:, :220000]
+     prob = classify_audio_clip(clip)
+     print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")
tortoise/models/__init__.py ADDED
File without changes
tortoise/models/arch_util.py ADDED
@@ -0,0 +1,373 @@
+ import os
+ import functools
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio
+ from tortoise.models.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     groups = 32
+     if channels <= 16:
+         groups = 8
+     elif channels <= 64:
+         groups = 16
+     while channels % groups != 0:
+         groups = int(groups / 2)
+     assert groups > 2
+     return GroupNorm32(groups, channels)
+
+
+ class QKVAttentionLegacy(nn.Module):
+     """
+     A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+     """
+
+     def __init__(self, n_heads):
+         super().__init__()
+         self.n_heads = n_heads
+
+     def forward(self, qkv, mask=None, rel_pos=None):
+         """
+         Apply QKV attention.
+
+         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+         :return: an [N x (H * C) x T] tensor after attention.
+         """
+         bs, width, length = qkv.shape
+         assert width % (3 * self.n_heads) == 0
+         ch = width // (3 * self.n_heads)
+         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+         scale = 1 / math.sqrt(math.sqrt(ch))
+         weight = torch.einsum(
+             "bct,bcs->bts", q * scale, k * scale
+         )  # More stable with f16 than dividing afterwards
+         if rel_pos is not None:
+             weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         if mask is not None:
+             # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+             mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+             weight = weight * mask
+         a = torch.einsum("bts,bcs->bct", weight, v)
+
+         return a.reshape(bs, -1, length)
+
+
+ class AttentionBlock(nn.Module):
+     """
+     An attention block that allows spatial positions to attend to each other.
+
+     Originally ported from here, but adapted to the N-d case.
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+     """
+
+     def __init__(
+         self,
+         channels,
+         num_heads=1,
+         num_head_channels=-1,
+         do_checkpoint=True,
+         relative_pos_embeddings=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.do_checkpoint = do_checkpoint
+         if num_head_channels == -1:
+             self.num_heads = num_heads
+         else:
+             assert (
+                 channels % num_head_channels == 0
+             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+             self.num_heads = channels // num_head_channels
+         self.norm = normalization(channels)
+         self.qkv = nn.Conv1d(channels, channels * 3, 1)
+         # split heads before split qkv
+         self.attention = QKVAttentionLegacy(self.num_heads)
+
+         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+         if relative_pos_embeddings:
+             self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+         else:
+             self.relative_pos_embeddings = None
+
+     def forward(self, x, mask=None):
+         b, c, *spatial = x.shape
+         x = x.reshape(b, c, -1)
+         qkv = self.qkv(self.norm(x))
+         h = self.attention(qkv, mask, self.relative_pos_embeddings)
+         h = self.proj_out(h)
+         return (x + h).reshape(b, c, *spatial)
+
+
+ class Upsample(nn.Module):
+     """
+     An upsampling layer with an optional convolution.
+
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     """
+
+     def __init__(self, channels, use_conv, out_channels=None, factor=4):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.factor = factor
+         if use_conv:
+             ksize = 5
+             pad = 2
+             self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
+         if self.use_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     """
+     A downsampling layer with an optional convolution.
+
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     """
+
+     def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+
+         stride = factor
+         if use_conv:
+             self.op = nn.Conv1d(
+                 self.channels, self.out_channels, ksize, stride=stride, padding=pad
+             )
+         else:
+             assert self.channels == self.out_channels
+             self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         return self.op(x)
+
+
+ class ResBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         dropout,
+         out_channels=None,
+         use_conv=False,
+         use_scale_shift_norm=False,
+         up=False,
+         down=False,
+         kernel_size=3,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.dropout = dropout
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_scale_shift_norm = use_scale_shift_norm
+         padding = 1 if kernel_size == 3 else 2
+
+         self.in_layers = nn.Sequential(
+             normalization(channels),
+             nn.SiLU(),
+             nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
+         )
+
+         self.updown = up or down
+
+         if up:
+             self.h_upd = Upsample(channels, False)
+             self.x_upd = Upsample(channels, False)
+         elif down:
+             self.h_upd = Downsample(channels, False)
+             self.x_upd = Downsample(channels, False)
+         else:
+             self.h_upd = self.x_upd = nn.Identity()
+
+         self.out_layers = nn.Sequential(
+             normalization(self.out_channels),
+             nn.SiLU(),
+             nn.Dropout(p=dropout),
+             zero_module(
+                 nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
+             ),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         elif use_conv:
+             self.skip_connection = nn.Conv1d(
+                 channels, self.out_channels, kernel_size, padding=padding
+             )
+         else:
+             self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
+
+     def forward(self, x):
+         if self.updown:
+             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+             h = in_rest(x)
+             h = self.h_upd(h)
+             x = self.x_upd(x)
+             h = in_conv(h)
+         else:
+             h = self.in_layers(x)
+         h = self.out_layers(h)
+         return self.skip_connection(x) + h
+
+
+ class AudioMiniEncoder(nn.Module):
+     def __init__(self,
+                  spec_dim,
+                  embedding_dim,
+                  base_channels=128,
+                  depth=2,
+                  resnet_blocks=2,
+                  attn_blocks=4,
+                  num_attn_heads=4,
+                  dropout=0,
+                  downsample_factor=2,
+                  kernel_size=3):
+         super().__init__()
+         self.init = nn.Sequential(
+             nn.Conv1d(spec_dim, base_channels, 3, padding=1)
+         )
+         ch = base_channels
+         res = []
+         for l in range(depth):
+             for r in range(resnet_blocks):
+                 res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
+             res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
+             ch *= 2
+         self.res = nn.Sequential(*res)
+         self.final = nn.Sequential(
+             normalization(ch),
+             nn.SiLU(),
+             nn.Conv1d(ch, embedding_dim, 1)
+         )
+         attn = []
+         for a in range(attn_blocks):
+             attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
+         self.attn = nn.Sequential(*attn)
+         self.dim = embedding_dim
+
+     def forward(self, x):
+         h = self.init(x)
+         h = self.res(h)
+         h = self.final(h)
+         h = self.attn(h)
+         return h[:, :, 0]
+
+
+ DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
+
+
+ class TorchMelSpectrogram(nn.Module):
+     def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
+                  sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
+         super().__init__()
+         # These are the default tacotron values for the MEL spectrogram.
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.n_mel_channels = n_mel_channels
+         self.mel_fmin = mel_fmin
+         self.mel_fmax = mel_fmax
+         self.sampling_rate = sampling_rate
+         self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
+                                                              win_length=self.win_length, power=2, normalized=normalize,
+                                                              sample_rate=self.sampling_rate, f_min=self.mel_fmin,
+                                                              f_max=self.mel_fmax, n_mels=self.n_mel_channels,
+                                                              norm="slaney")
+         self.mel_norm_file = mel_norm_file
+         if self.mel_norm_file is not None:
+             self.mel_norms = torch.load(self.mel_norm_file)
+         else:
+             self.mel_norms = None
+
+     def forward(self, inp):
+         if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
+             inp = inp.squeeze(1)
+         assert len(inp.shape) == 2
+         if torch.backends.mps.is_available():
+             inp = inp.to('cpu')
+         self.mel_stft = self.mel_stft.to(inp.device)
+         mel = self.mel_stft(inp)
+         # Perform dynamic range compression
+         mel = torch.log(torch.clamp(mel, min=1e-5))
+         if self.mel_norms is not None:
+             self.mel_norms = self.mel_norms.to(mel.device)
+             mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
+         return mel
+
+
+ class CheckpointedLayer(nn.Module):
+     """
+     Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
+     checkpoint for all other args.
+     """
+     def __init__(self, wrap):
+         super().__init__()
+         self.wrap = wrap
+
+     def forward(self, x, *args, **kwargs):
+         for k, v in kwargs.items():
+             assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
+         partial = functools.partial(self.wrap, **kwargs)
+         return partial(x, *args)
+
+
+ class CheckpointedXTransformerEncoder(nn.Module):
+     """
+     Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
+     to channels-last that XTransformer expects.
+     """
+     def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
+         super().__init__()
+         self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
+         self.needs_permute = needs_permute
+         self.exit_permute = exit_permute
+
+         if not checkpoint:
+             return
+         for i in range(len(self.transformer.attn_layers.layers)):
+             n, b, r = self.transformer.attn_layers.layers[i]
+             self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+     def forward(self, x, **kwargs):
+         if self.needs_permute:
+             x = x.permute(0,2,1)
+         h = self.transformer(x, **kwargs)
+         if self.exit_permute:
+             h = h.permute(0,2,1)
+         return h
@@ -0,0 +1,582 @@
 
+ import functools
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
+ from tortoise.models.arch_util import AttentionBlock
+ from tortoise.utils.typical_sampling import TypicalLogitsWarper
+
+
+ def null_position_embeddings(range, dim):
+     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
+ class ResBlock(nn.Module):
+     """
+     Basic residual convolutional block that uses GroupNorm.
+     """
+     def __init__(self, chan):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+             nn.GroupNorm(chan//8, chan),
+             nn.ReLU(),
+             nn.Conv1d(chan, chan, kernel_size=3, padding=1),
+             nn.GroupNorm(chan//8, chan)
+         )
+
+     def forward(self, x):
+         return F.relu(self.net(x) + x)
+
+
+ class GPT2InferenceModel(GPT2PreTrainedModel):
+     def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
+         super().__init__(config)
+         self.transformer = gpt
+         self.text_pos_embedding = text_pos_emb
+         self.embeddings = embeddings
+         self.final_norm = norm
+         self.lm_head = nn.Sequential(norm, linear)
+         self.kv_cache = kv_cache
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.cached_mel_emb = None
+
+     def parallelize(self, device_map=None):
+         self.device_map = (
+             get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.transformer.h))
+         self.transformer.parallelize(self.device_map)
+         self.lm_head = self.lm_head.to(self.transformer.first_device)
+         self.model_parallel = True
+
+     def deparallelize(self):
+         self.transformer.deparallelize()
+         self.transformer = self.transformer.to("cpu")
+         self.lm_head = self.lm_head.to("cpu")
+         self.model_parallel = False
+         torch.cuda.empty_cache()
+         if torch.backends.mps.is_available():
+             torch.mps.empty_cache()
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def store_mel_emb(self, mel_emb):
+         self.cached_mel_emb = mel_emb
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+         token_type_ids = kwargs.get("token_type_ids", None)  # usually None
+         if not self.kv_cache:
+             past_key_values = None
+         # only last token for inputs_ids if past is defined in kwargs
+         if past_key_values:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+             if token_type_ids is not None:
+                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+         attention_mask = kwargs.get("attention_mask", None)
+         position_ids = kwargs.get("position_ids", None)
+
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+         else:
+             position_ids = None
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "use_cache": kwargs.get("use_cache"),
+             "position_ids": position_ids,
+             "attention_mask": attention_mask,
+             "token_type_ids": token_type_ids,
+         }
+
+     def forward(
+         self,
+         input_ids=None,
+         past_key_values=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         assert self.cached_mel_emb is not None
+         assert inputs_embeds is None  # Not supported by this inference model.
+         assert labels is None  # Training not supported by this inference model.
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # Create embedding
+         mel_len = self.cached_mel_emb.shape[1]
+         if input_ids.shape[1] != 1:
+             text_inputs = input_ids[:, mel_len:]
+             text_emb = self.embeddings(text_inputs)
+             text_emb = text_emb + self.text_pos_embedding(text_emb)
+             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
+                 mel_emb = self.cached_mel_emb.repeat_interleave(
+                     text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
+                 )
+             else:  # this outcome only occurs once per loop in most cases
+                 mel_emb = self.cached_mel_emb
+             emb = torch.cat([mel_emb, text_emb], dim=1)
+         else:
+             emb = self.embeddings(input_ids)
+             emb = emb + self.text_pos_embedding.get_fixed_embedding(
+                 attention_mask.shape[1] - mel_len, attention_mask.device
+             )
+         transformer_outputs = self.transformer(
+             inputs_embeds=emb,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
155
+ position_ids=position_ids,
156
+ head_mask=head_mask,
157
+ encoder_hidden_states=encoder_hidden_states,
158
+ encoder_attention_mask=encoder_attention_mask,
159
+ use_cache=use_cache,
160
+ output_attentions=output_attentions,
161
+ output_hidden_states=output_hidden_states,
162
+ return_dict=return_dict,
163
+ )
164
+ hidden_states = transformer_outputs[0]
165
+
166
+ # Set device for model parallelism
167
+ if self.model_parallel:
168
+ if torch.backends.mps.is_available():
169
+ self.to(self.transformer.first_device)
170
+ else:
171
+ torch.cuda.set_device(self.transformer.first_device)
172
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
173
+
174
+ lm_logits = self.lm_head(hidden_states)
175
+
176
+ if not return_dict:
177
+ return (lm_logits,) + transformer_outputs[1:]
178
+
179
+ return CausalLMOutputWithCrossAttentions(
180
+ loss=None,
181
+ logits=lm_logits,
182
+ past_key_values=transformer_outputs.past_key_values,
183
+ hidden_states=transformer_outputs.hidden_states,
184
+ attentions=transformer_outputs.attentions,
185
+ cross_attentions=transformer_outputs.cross_attentions,
186
+ )
187
+
188
+ @staticmethod
189
+ def _reorder_cache(past, beam_idx):
190
+ """
191
+ This function is used to re-order the :obj:`past_key_values` cache if
192
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
193
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
194
+ """
195
+ return tuple(
196
+ tuple(
197
+ past_state.index_select(0, beam_idx.to(past_state.device))
198
+ for past_state in layer_past
199
+ )
200
+ for layer_past in past
201
+ )
202
+
203
+
204
+ class ConditioningEncoder(nn.Module):
205
+ def __init__(self,
206
+ spec_dim,
207
+ embedding_dim,
208
+ attn_blocks=6,
209
+ num_attn_heads=4,
210
+ do_checkpointing=False,
211
+ mean=False):
212
+ super().__init__()
213
+ attn = []
214
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
215
+ for a in range(attn_blocks):
216
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
217
+ self.attn = nn.Sequential(*attn)
218
+ self.dim = embedding_dim
219
+ self.do_checkpointing = do_checkpointing
220
+ self.mean = mean
221
+
222
+ def forward(self, x):
223
+ h = self.init(x)
224
+ h = self.attn(h)
225
+ if self.mean:
226
+ return h.mean(dim=2)
227
+ else:
228
+ return h[:, :, 0]
229
+
230
+
231
+ class LearnedPositionEmbeddings(nn.Module):
232
+ def __init__(self, seq_len, model_dim, init=.02):
233
+ super().__init__()
234
+ self.emb = nn.Embedding(seq_len, model_dim)
235
+ # Initializing this way is standard for GPT-2
236
+ self.emb.weight.data.normal_(mean=0.0, std=init)
237
+
238
+ def forward(self, x):
239
+ sl = x.shape[1]
240
+ return self.emb(torch.arange(0, sl, device=x.device))
241
+
242
+ def get_fixed_embedding(self, ind, dev):
243
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
244
+
245
+
246
+ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
247
+ """
248
+ GPT-2 implemented by the HuggingFace library.
249
+ """
250
+ from transformers import GPT2Config, GPT2Model
251
+ gpt_config = GPT2Config(vocab_size=256, # Unused.
252
+ n_positions=max_mel_seq_len+max_text_seq_len,
253
+ n_ctx=max_mel_seq_len+max_text_seq_len,
254
+ n_embd=model_dim,
255
+ n_layer=layers,
256
+ n_head=heads,
257
+ gradient_checkpointing=checkpointing,
258
+ use_cache=not checkpointing)
259
+ gpt = GPT2Model(gpt_config)
260
+ # Override the built in positional embeddings
261
+ del gpt.wpe
262
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
263
+ # Built-in token embeddings are unused.
264
+ del gpt.wte
265
+ return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
266
+ None, None
267
+
268
+
269
+ class MelEncoder(nn.Module):
270
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
271
+ super().__init__()
272
+ self.channels = channels
273
+ self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
274
+ nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
275
+ nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
276
+ nn.GroupNorm(channels//16, channels//2),
277
+ nn.ReLU(),
278
+ nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
279
+ nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
280
+ nn.GroupNorm(channels//8, channels),
281
+ nn.ReLU(),
282
+ nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
283
+ )
284
+ self.reduction = 4
285
+
286
+
287
+ def forward(self, x):
288
+ for e in self.encoder:
289
+ x = e(x)
290
+ return x.permute(0,2,1)
291
+
292
+
293
+ class UnifiedVoice(nn.Module):
294
+ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
295
+ mel_length_compression=1024, number_text_tokens=256,
296
+ start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
297
+ stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
298
+ checkpointing=True, types=1):
299
+ """
300
+ Args:
301
+ layers: Number of layers in transformer stack.
302
+ model_dim: Operating dimensions of the transformer
303
+ heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
304
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
305
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
306
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
307
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
308
+ number_text_tokens:
309
+ start_text_token:
310
+ stop_text_token:
311
+ number_mel_codes:
312
+ start_mel_token:
313
+ stop_mel_token:
314
+ train_solo_embeddings:
315
+ use_mel_codes_as_input:
316
+ checkpointing:
317
+ """
318
+ super().__init__()
319
+
320
+ self.number_text_tokens = number_text_tokens
321
+ self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
322
+ self.stop_text_token = 0
323
+ self.number_mel_codes = number_mel_codes
324
+ self.start_mel_token = start_mel_token
325
+ self.stop_mel_token = stop_mel_token
326
+ self.layers = layers
327
+ self.heads = heads
328
+ self.max_mel_tokens = max_mel_tokens
329
+ self.max_text_tokens = max_text_tokens
330
+ self.model_dim = model_dim
331
+ self.max_conditioning_inputs = max_conditioning_inputs
332
+ self.mel_length_compression = mel_length_compression
333
+ self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
334
+ self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
335
+ if use_mel_codes_as_input:
336
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
337
+ else:
338
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
339
+ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
340
+ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
341
+ if train_solo_embeddings:
342
+ self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
343
+ self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
344
+ else:
345
+ self.mel_solo_embedding = 0
346
+ self.text_solo_embedding = 0
347
+
348
+ self.final_norm = nn.LayerNorm(model_dim)
349
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
350
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
351
+
352
+ # Initialize the embeddings per the GPT-2 scheme
353
+ embeddings = [self.text_embedding]
354
+ if use_mel_codes_as_input:
355
+ embeddings.append(self.mel_embedding)
356
+ for module in embeddings:
357
+ module.weight.data.normal_(mean=0.0, std=.02)
358
+ def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
359
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
360
+ gpt_config = GPT2Config(
361
+ vocab_size=self.max_mel_tokens,
362
+ n_positions=seq_length,
363
+ n_ctx=seq_length,
364
+ n_embd=self.model_dim,
365
+ n_layer=self.layers,
366
+ n_head=self.heads,
367
+ gradient_checkpointing=False,
368
+ use_cache=True,
369
+ )
370
+ self.inference_model = GPT2InferenceModel(
371
+ gpt_config,
372
+ self.gpt,
373
+ self.mel_pos_embedding,
374
+ self.mel_embedding,
375
+ self.final_norm,
376
+ self.mel_head,
377
+ kv_cache=kv_cache,
378
+ )
379
+ if use_deepspeed and half and torch.cuda.is_available():
380
+ import deepspeed
381
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
382
+ mp_size=1,
383
+ replace_with_kernel_inject=True,
384
+ dtype=torch.float16)
385
+ self.inference_model = self.ds_engine.module.eval()
386
+ elif use_deepspeed and torch.cuda.is_available():
387
+ import deepspeed
388
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
389
+ mp_size=1,
390
+ replace_with_kernel_inject=True,
391
+ dtype=torch.float32)
392
+ self.inference_model = self.ds_engine.module.eval()
393
+ else:
394
+ self.inference_model = self.inference_model.eval()
395
+
396
+ # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
397
+ self.gpt.wte = self.mel_embedding
398
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
399
+ inp = F.pad(input, (1,0), value=start_token)
400
+ tar = F.pad(input, (0,1), value=stop_token)
401
+ return inp, tar
402
+
403
+ def set_mel_padding(self, mel_input_tokens, wav_lengths):
404
+ """
405
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
406
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
407
+ preformatting to create a working TTS model.
408
+ """
409
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
410
+ mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
411
+ for b in range(len(mel_lengths)):
412
+ actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
413
+ if actual_end < mel_input_tokens.shape[-1]:
414
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
415
+ return mel_input_tokens
416
+
417
+ def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
418
+ if second_inputs is not None:
419
+ emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
420
+ else:
421
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
422
+
423
+ gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
424
+ if get_attns:
425
+ return gpt_out.attentions
426
+
427
+ enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
428
+ enc = self.final_norm(enc)
429
+
430
+ if return_latent:
431
+ return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
432
+
433
+ first_logits = enc[:, :first_inputs.shape[1]]
434
+ first_logits = first_head(first_logits)
435
+ first_logits = first_logits.permute(0,2,1)
436
+ if second_inputs is not None:
437
+ second_logits = enc[:, -second_inputs.shape[1]:]
438
+ second_logits = second_head(second_logits)
439
+ second_logits = second_logits.permute(0,2,1)
440
+ return first_logits, second_logits
441
+ else:
442
+ return first_logits
443
+
444
+ def get_conditioning(self, speech_conditioning_input):
445
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
446
+ speech_conditioning_input.shape) == 3 else speech_conditioning_input
447
+ conds = []
448
+ for j in range(speech_conditioning_input.shape[1]):
449
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
450
+ conds = torch.stack(conds, dim=1)
451
+ conds = conds.mean(dim=1)
452
+ return conds
453
+
454
+ def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
455
+ return_latent=False, clip_inputs=True):
456
+ """
457
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
458
+ (actuated by `text_first`).
459
+
460
+ speech_conditioning_input: MEL float tensor, (b,1024)
461
+ text_inputs: long tensor, (b,t)
462
+ text_lengths: long tensor, (b,)
463
+ mel_inputs: long tensor, (b,m)
464
+ wav_lengths: long tensor, (b,)
465
+ raw_mels: MEL float tensor (b,80,s)
466
+
467
+ If return_attentions is specified, only logits are returned.
468
+ If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
469
+ If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
470
+ """
471
+ # Types are expressed by expanding the text embedding space.
472
+ if types is not None:
473
+ text_inputs = text_inputs * (1+types).unsqueeze(-1)
474
+
475
+ if clip_inputs:
476
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
477
+ # chopping the inputs by the maximum actual length.
478
+ max_text_len = text_lengths.max()
479
+ text_inputs = text_inputs[:, :max_text_len]
480
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
481
+ mel_codes = mel_codes[:, :max_mel_len]
482
+ if raw_mels is not None:
483
+ raw_mels = raw_mels[:, :, :max_mel_len*4]
484
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
485
+ text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
486
+ mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
487
+
488
+ conds = speech_conditioning_latent.unsqueeze(1)
489
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
490
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
491
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
492
+ if raw_mels is not None:
493
+ mel_inp = F.pad(raw_mels, (0, 8))
494
+ else:
495
+ mel_inp = mel_codes
496
+ mel_emb = self.mel_embedding(mel_inp)
497
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
498
+
499
+ if text_first:
500
+ text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
501
+ if return_latent:
502
+ return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
503
+ else:
504
+ mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
505
+ if return_latent:
506
+ return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
507
+
508
+ if return_attentions:
509
+ return mel_logits
510
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
511
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
512
+ return loss_text.mean(), loss_mel.mean(), mel_logits
513
+ def compute_embeddings(
514
+ self,
515
+ cond_latents,
516
+ text_inputs,
517
+ ):
518
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
519
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
520
+ emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
521
+ conds = cond_latents.unsqueeze(1)
522
+ emb = torch.cat([conds, emb], dim=1)
523
+ self.inference_model.store_mel_emb(emb)
524
+ gpt_inputs = torch.full(
525
+ (
526
+ emb.shape[0],
527
+ emb.shape[1] + 1, # +1 for the start_mel_token
528
+ ),
529
+ fill_value=1,
530
+ dtype=torch.long,
531
+ device=text_inputs.device,
532
+ )
533
+ gpt_inputs[:, -1] = self.start_mel_token
534
+ return gpt_inputs
535
+ def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
536
+ max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
537
+
538
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
539
+ text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
540
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
541
+
542
+ conds = speech_conditioning_latent.unsqueeze(1)
543
+ emb = torch.cat([conds, text_emb], dim=1)
544
+ self.inference_model.store_mel_emb(emb)
545
+
546
+ fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
547
+ device=text_inputs.device)
548
+ fake_inputs[:, -1] = self.start_mel_token
549
+ trunc_index = fake_inputs.shape[1]
550
+ if input_tokens is None:
551
+ inputs = fake_inputs
552
+ else:
553
+ assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
554
+ fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
555
+ input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
556
+ inputs = torch.cat([fake_inputs, input_tokens], dim=1)
557
+
558
+ logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
559
+ max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
560
+ gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
561
+ max_length=max_length, logits_processor=logits_processor,
562
+ num_return_sequences=num_return_sequences, **hf_generate_kwargs)
563
+ return gen[:, trunc_index:]
564
+
565
+ def get_generator(self, fake_inputs, **hf_generate_kwargs):
566
+ return self.inference_model.generate_stream(
567
+ fake_inputs,
568
+ bos_token_id=self.start_mel_token,
569
+ pad_token_id=self.stop_mel_token,
570
+ eos_token_id=self.stop_mel_token,
571
+ max_length=500,
572
+ do_stream=True,
573
+ **hf_generate_kwargs,
574
+ )
575
+ if __name__ == '__main__':
576
+ gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
577
+ l = gpt(torch.randn(2, 3, 80, 800),
578
+ torch.randint(high=120, size=(2,120)),
579
+ torch.tensor([32, 120]),
580
+ torch.randint(high=8192, size=(2,250)),
581
+ torch.tensor([250*256,195*256]))
582
+ gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
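UnifiedVoice treats text and mel codes as one long token stream: build_aligned_inputs_and_targets prepends a start token to form the sequence the GPT consumes and appends a stop token to form the next-token prediction target. A small worked sketch of that alignment on a toy sequence, assuming only torch (the token values are illustrative):

import torch
import torch.nn.functional as F

start_token, stop_token = 255, 0        # illustrative values
tokens = torch.tensor([[17, 42, 99]])   # (b, t) toy token sequence

inp = F.pad(tokens, (1, 0), value=start_token)  # [255, 17, 42, 99] is fed to the transformer
tar = F.pad(tokens, (0, 1), value=stop_token)   # [17, 42, 99, 0] is the cross-entropy target

# Position i of `inp` is trained to predict position i of `tar`, so the model
# learns to emit the stop token once the real sequence is exhausted.
print(inp.tolist(), tar.tolist())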
tortoise/models/classifier.py ADDED
@@ -0,0 +1,148 @@
+ import torch
+ import torch.nn as nn
+
+ from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
+
+
+ class ResBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         dropout,
+         out_channels=None,
+         use_conv=False,
+         use_scale_shift_norm=False,
+         dims=2,
+         up=False,
+         down=False,
+         kernel_size=3,
+         do_checkpoint=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.dropout = dropout
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_scale_shift_norm = use_scale_shift_norm
+         self.do_checkpoint = do_checkpoint
+         padding = 1 if kernel_size == 3 else 2
+
+         self.in_layers = nn.Sequential(
+             normalization(channels),
+             nn.SiLU(),
+             nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
+         )
+
+         self.updown = up or down
+
+         if up:
+             self.h_upd = Upsample(channels, False, dims)
+             self.x_upd = Upsample(channels, False, dims)
+         elif down:
+             self.h_upd = Downsample(channels, False, dims)
+             self.x_upd = Downsample(channels, False, dims)
+         else:
+             self.h_upd = self.x_upd = nn.Identity()
+
+         self.out_layers = nn.Sequential(
+             normalization(self.out_channels),
+             nn.SiLU(),
+             nn.Dropout(p=dropout),
+             zero_module(
+                 nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
+             ),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         elif use_conv:
+             self.skip_connection = nn.Conv1d(
+                 channels, self.out_channels, kernel_size, padding=padding
+             )
+         else:
+             self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
+
+     def forward(self, x):
+         if self.updown:
+             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+             h = in_rest(x)
+             h = self.h_upd(h)
+             x = self.x_upd(x)
+             h = in_conv(h)
+         else:
+             h = self.in_layers(x)
+         h = self.out_layers(h)
+         return self.skip_connection(x) + h
+
+
+ class AudioMiniEncoder(nn.Module):
+     def __init__(self,
+                  spec_dim,
+                  embedding_dim,
+                  base_channels=128,
+                  depth=2,
+                  resnet_blocks=2,
+                  attn_blocks=4,
+                  num_attn_heads=4,
+                  dropout=0,
+                  downsample_factor=2,
+                  kernel_size=3):
+         super().__init__()
+         self.init = nn.Sequential(
+             nn.Conv1d(spec_dim, base_channels, 3, padding=1)
+         )
+         ch = base_channels
+         res = []
+         self.layers = depth
+         for l in range(depth):
+             for r in range(resnet_blocks):
+                 res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
+             res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
+             ch *= 2
+         self.res = nn.Sequential(*res)
+         self.final = nn.Sequential(
+             normalization(ch),
+             nn.SiLU(),
+             nn.Conv1d(ch, embedding_dim, 1)
+         )
+         attn = []
+         for a in range(attn_blocks):
+             attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
+         self.attn = nn.Sequential(*attn)
+         self.dim = embedding_dim
+
+     def forward(self, x):
+         h = self.init(x)
+         h = self.res(h)
+         h = self.final(h)
+         for blk in self.attn:
+             h = blk(h)
+         return h[:, :, 0]
+
+
+ class AudioMiniEncoderWithClassifierHead(nn.Module):
+     def __init__(self, classes, distribute_zero_label=True, **kwargs):
+         super().__init__()
+         self.enc = AudioMiniEncoder(**kwargs)
+         self.head = nn.Linear(self.enc.dim, classes)
+         self.num_classes = classes
+         self.distribute_zero_label = distribute_zero_label
+
+     def forward(self, x, labels=None):
+         h = self.enc(x)
+         logits = self.head(h)
+         if labels is None:
+             return logits
+         else:
+             if self.distribute_zero_label:
+                 oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
+                 zeros_indices = (labels == 0).unsqueeze(-1)
+                 # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
+                 zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
+                 zero_extra_mass[:, 0] = -.2
+                 zero_extra_mass = zero_extra_mass * zeros_indices
+                 oh_labels = oh_labels + zero_extra_mass
+             else:
+                 oh_labels = labels
+             loss = nn.functional.cross_entropy(logits, oh_labels)
+             return loss
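When labels are provided, the classifier head above softens label 0 by shifting 0.2 of its probability mass onto the remaining classes before computing cross entropy. A short sketch of what those soft targets look like for a 4-class problem, assuming only torch (shapes and values are illustrative; probability-target cross entropy needs a reasonably recent PyTorch):

import torch
import torch.nn.functional as F

num_classes = 4
labels = torch.tensor([0, 2])                  # the first element carries the noisy "zero" label

oh = F.one_hot(labels, num_classes=num_classes).float()
extra = torch.full_like(oh, .2 / (num_classes - 1))
extra[:, 0] = -.2
extra = extra * (labels == 0).unsqueeze(-1)    # only rows labelled 0 are softened
soft_targets = oh + extra                      # row 0 becomes [0.8, 0.067, 0.067, 0.067]

logits = torch.randn(2, num_classes)
loss = F.cross_entropy(logits, soft_targets)   # probability targets, not class indices
print(soft_targets, loss)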
tortoise/models/clvp.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import einsum
+
+ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
+ from tortoise.models.transformer import Transformer
+ from tortoise.models.xtransformers import Encoder
+
+
+ def exists(val):
+     return val is not None
+
+
+ def masked_mean(t, mask, dim = 1):
+     t = t.masked_fill(~mask[:, :, None], 0.)
+     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
+
+ class CLVP(nn.Module):
+     """
+     CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
+     transcribed text.
+
+     Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
+     """
+
+     def __init__(
+             self,
+             *,
+             dim_text=512,
+             dim_speech=512,
+             dim_latent=512,
+             num_text_tokens=256,
+             text_enc_depth=6,
+             text_seq_len=120,
+             text_heads=8,
+             num_speech_tokens=8192,
+             speech_enc_depth=6,
+             speech_heads=8,
+             speech_seq_len=250,
+             text_mask_percentage=0,
+             voice_mask_percentage=0,
+             wav_token_compression=1024,
+             use_xformers=False,
+     ):
+         super().__init__()
+         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
+         self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+
+         self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
+         self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+
+         if use_xformers:
+             self.text_transformer = CheckpointedXTransformerEncoder(
+                 needs_permute=False,
+                 exit_permute=False,
+                 max_seq_len=-1,
+                 attn_layers=Encoder(
+                     dim=dim_text,
+                     depth=text_enc_depth,
+                     heads=text_heads,
+                     ff_dropout=.1,
+                     ff_mult=2,
+                     attn_dropout=.1,
+                     use_rmsnorm=True,
+                     ff_glu=True,
+                     rotary_pos_emb=True,
+                 ))
+             self.speech_transformer = CheckpointedXTransformerEncoder(
+                 needs_permute=False,
+                 exit_permute=False,
+                 max_seq_len=-1,
+                 attn_layers=Encoder(
+                     dim=dim_speech,
+                     depth=speech_enc_depth,
+                     heads=speech_heads,
+                     ff_dropout=.1,
+                     ff_mult=2,
+                     attn_dropout=.1,
+                     use_rmsnorm=True,
+                     ff_glu=True,
+                     rotary_pos_emb=True,
+                 ))
+         else:
+             self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
+                                                 heads=text_heads)
+             self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
+                                                   depth=speech_enc_depth, heads=speech_heads)
+
+         self.temperature = nn.Parameter(torch.tensor(1.))
+         self.text_mask_percentage = text_mask_percentage
+         self.voice_mask_percentage = voice_mask_percentage
+         self.wav_token_compression = wav_token_compression
+         self.xformers = use_xformers
+         if not use_xformers:
+             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
+             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+
+     def forward(
+             self,
+             text,
+             speech_tokens,
+             return_loss=False
+     ):
+         b, device = text.shape[0], text.device
+         if self.training:
+             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
+             voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
+         else:
+             text_mask = torch.ones_like(text.float()).bool()
+             voice_mask = torch.ones_like(speech_tokens.float()).bool()
+
+         text_emb = self.text_emb(text)
+         speech_emb = self.speech_emb(speech_tokens)
+
+         if not self.xformers:
+             text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
+             speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
+
+         enc_text = self.text_transformer(text_emb, mask=text_mask)
+         enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
+
+         text_latents = masked_mean(enc_text, text_mask, dim=1)
+         speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
+
+         text_latents = self.to_text_latent(text_latents)
+         speech_latents = self.to_speech_latent(speech_latents)
+
+         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
+
+         temp = self.temperature.exp()
+
+         if not return_loss:
+             sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
+             return sim
+
+         sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
+         labels = torch.arange(b, device=device)
+         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+         return loss
+
+
+ if __name__ == '__main__':
+     clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
+     clip(torch.randint(0,256,(2,120)),
+          torch.tensor([50,100]),
+          torch.randint(0,8192,(2,250)),
+          torch.tensor([101,102]),
+          return_loss=True)
+     nonloss = clip(torch.randint(0,256,(2,120)),
+          torch.tensor([50,100]),
+          torch.randint(0,8192,(2,250)),
+          torch.tensor([101,102]),
+          return_loss=False)
+     print(nonloss.shape)
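Training CLVP comes down to the CLIP-style symmetric loss: both latents are L2-normalized, a temperature-scaled b-by-b similarity matrix is built, and cross entropy is applied along both axes with the matching pairs on the diagonal. A compact sketch of just that loss, assuming only torch (the latents are random placeholders):

import torch
import torch.nn.functional as F

b, d = 4, 512
text_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
speech_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
temperature = torch.tensor(1.).exp()

sim = text_latents @ speech_latents.t() * temperature  # (b, b): each text scored against every clip
labels = torch.arange(b)                                # matching pairs sit on the diagonal
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
print(loss)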
tortoise/models/cvvp.py ADDED
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import einsum
+
+ from tortoise.models.arch_util import AttentionBlock
+ from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
+
+
+ def exists(val):
+     return val is not None
+
+
+ def masked_mean(t, mask):
+     t = t.masked_fill(~mask, 0.)
+     return t.sum(dim=1) / mask.sum(dim=1)
+
+
+ class CollapsingTransformer(nn.Module):
+     def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
+         super().__init__()
+         self.transformer = ContinuousTransformerWrapper(
+             max_seq_len=-1,
+             use_pos_emb=False,
+             attn_layers=Encoder(
+                 dim=model_dim,
+                 depth=depth,
+                 heads=heads,
+                 ff_dropout=dropout,
+                 ff_mult=1,
+                 attn_dropout=dropout,
+                 use_rmsnorm=True,
+                 ff_glu=True,
+                 rotary_pos_emb=True,
+                 **encoder_kwargs,
+             ))
+         self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
+                                           AttentionBlock(
+                                               output_dims, num_heads=heads, do_checkpoint=False),
+                                           nn.Conv1d(output_dims, output_dims, 1))
+         self.mask_percentage = mask_percentage
+
+     def forward(self, x, **transformer_kwargs):
+         h = self.transformer(x, **transformer_kwargs)
+         h = h.permute(0, 2, 1)
+         h = self.pre_combiner(h).permute(0, 2, 1)
+         if self.training:
+             mask = torch.rand_like(h.float()) > self.mask_percentage
+         else:
+             mask = torch.ones_like(h.float()).bool()
+         return masked_mean(h, mask)
+
+
+ class ConvFormatEmbedding(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self.emb = nn.Embedding(*args, **kwargs)
+
+     def forward(self, x):
+         y = self.emb(x)
+         return y.permute(0, 2, 1)
+
+
+ class CVVP(nn.Module):
+     def __init__(
+             self,
+             model_dim=512,
+             transformer_heads=8,
+             dropout=.1,
+             conditioning_enc_depth=8,
+             cond_mask_percentage=0,
+             mel_channels=80,
+             mel_codes=None,
+             speech_enc_depth=8,
+             speech_mask_percentage=0,
+             latent_multiplier=1,
+     ):
+         super().__init__()
+         latent_dim = latent_multiplier*model_dim
+         self.temperature = nn.Parameter(torch.tensor(1.))
+
+         self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
+                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
+         self.conditioning_transformer = CollapsingTransformer(
+             model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
+         self.to_conditioning_latent = nn.Linear(
+             latent_dim, latent_dim, bias=False)
+
+         if mel_codes is None:
+             self.speech_emb = nn.Conv1d(
+                 mel_channels, model_dim, kernel_size=5, padding=2)
+         else:
+             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
+         self.speech_transformer = CollapsingTransformer(
+             model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
+         self.to_speech_latent = nn.Linear(
+             latent_dim, latent_dim, bias=False)
+
+     def get_grad_norm_parameter_groups(self):
+         return {
+             'conditioning': list(self.conditioning_transformer.parameters()),
+             'speech': list(self.speech_transformer.parameters()),
+         }
+
+     def forward(
+             self,
+             mel_cond,
+             mel_input,
+             return_loss=False
+     ):
+         cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
+         enc_cond = self.conditioning_transformer(cond_emb)
+         cond_latents = self.to_conditioning_latent(enc_cond)
+
+         speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
+         enc_speech = self.speech_transformer(speech_emb)
+         speech_latents = self.to_speech_latent(enc_speech)
+
+         cond_latents, speech_latents = map(lambda t: F.normalize(
+             t, p=2, dim=-1), (cond_latents, speech_latents))
+         temp = self.temperature.exp()
+
+         if not return_loss:
+             sim = einsum('n d, n d -> n', cond_latents,
+                          speech_latents) * temp
+             return sim
+
+         sim = einsum('i d, j d -> i j', cond_latents,
+                      speech_latents) * temp
+         labels = torch.arange(
+             cond_latents.shape[0], device=mel_input.device)
+         loss = (F.cross_entropy(sim, labels) +
+                 F.cross_entropy(sim.t(), labels)) / 2
+
+         return loss
+
+
+ if __name__ == '__main__':
+     clvp = CVVP()
+     clvp(torch.randn(2, 80, 100),
+          torch.randn(2, 80, 95),
+          return_loss=True)
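CollapsingTransformer reduces a variable-length sequence to a single vector with masked_mean, which averages only the positions its mask keeps. A tiny worked example of that reduction, using the same function as above:

import torch

def masked_mean(t, mask):
    t = t.masked_fill(~mask, 0.)
    return t.sum(dim=1) / mask.sum(dim=1)

t = torch.tensor([[[2.0], [4.0], [6.0]]])         # (b=1, s=3, c=1)
mask = torch.tensor([[[True], [True], [False]]])  # the last position is masked out
print(masked_mean(t, mask))                       # -> tensor([[3.]]), i.e. (2 + 4) / 2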
tortoise/models/diffusion_decoder.py ADDED
@@ -0,0 +1,336 @@
1
+ import math
2
+ import random
3
+ from abc import abstractmethod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import autocast
9
+
10
+ from tortoise.models.arch_util import normalization, AttentionBlock
11
+
12
+
13
+ def is_latent(t):
14
+ return t.dtype == torch.float
15
+
16
+
17
+ def is_sequence(t):
18
+ return t.dtype == torch.long
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34
+ ).to(device=timesteps.device)
35
+ args = timesteps[:, None].float() * freqs[None]
36
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37
+ if dim % 2:
38
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39
+ return embedding
40
+
41
+
42
+ class TimestepBlock(nn.Module):
43
+ @abstractmethod
44
+ def forward(self, x, emb):
45
+ """
46
+ Apply the module to `x` given `emb` timestep embeddings.
47
+ """
48
+
49
+
50
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
51
+ def forward(self, x, emb):
52
+ for layer in self:
53
+ if isinstance(layer, TimestepBlock):
54
+ x = layer(x, emb)
55
+ else:
56
+ x = layer(x)
57
+ return x
58
+
59
+
60
+ class ResBlock(TimestepBlock):
61
+ def __init__(
62
+ self,
63
+ channels,
64
+ emb_channels,
65
+ dropout,
66
+ out_channels=None,
67
+ dims=2,
68
+ kernel_size=3,
69
+ efficient_config=True,
70
+ use_scale_shift_norm=False,
71
+ ):
72
+ super().__init__()
73
+ self.channels = channels
74
+ self.emb_channels = emb_channels
75
+ self.dropout = dropout
76
+ self.out_channels = out_channels or channels
77
+ self.use_scale_shift_norm = use_scale_shift_norm
78
+ padding = {1: 0, 3: 1, 5: 2}[kernel_size]
79
+ eff_kernel = 1 if efficient_config else 3
80
+ eff_padding = 0 if efficient_config else 1
81
+
82
+ self.in_layers = nn.Sequential(
83
+ normalization(channels),
84
+ nn.SiLU(),
85
+ nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
86
+ )
87
+
88
+ self.emb_layers = nn.Sequential(
89
+ nn.SiLU(),
90
+ nn.Linear(
91
+ emb_channels,
92
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
93
+ ),
94
+ )
95
+ self.out_layers = nn.Sequential(
96
+ normalization(self.out_channels),
97
+ nn.SiLU(),
98
+ nn.Dropout(p=dropout),
99
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
100
+ )
101
+
102
+ if self.out_channels == channels:
103
+ self.skip_connection = nn.Identity()
104
+ else:
105
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
106
+
107
+ def forward(self, x, emb):
108
+ h = self.in_layers(x)
109
+ emb_out = self.emb_layers(emb).type(h.dtype)
110
+ while len(emb_out.shape) < len(h.shape):
111
+ emb_out = emb_out[..., None]
112
+ if self.use_scale_shift_norm:
113
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
114
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
115
+ h = out_norm(h) * (1 + scale) + shift
116
+ h = out_rest(h)
117
+ else:
118
+ h = h + emb_out
119
+ h = self.out_layers(h)
120
+ return self.skip_connection(x) + h
121
+
122
+
123
+ class DiffusionLayer(TimestepBlock):
124
+ def __init__(self, model_channels, dropout, num_heads):
125
+ super().__init__()
126
+ self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
127
+ self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
128
+
129
+ def forward(self, x, time_emb):
130
+ y = self.resblk(x, time_emb)
131
+ return self.attn(y)
132
+
133
+
134
+ class DiffusionTts(nn.Module):
135
+ def __init__(
136
+ self,
137
+ model_channels=512,
138
+ num_layers=8,
139
+ in_channels=100,
140
+ in_latent_channels=512,
141
+ in_tokens=8193,
142
+ out_channels=200, # mean and variance
143
+ dropout=0,
144
+ use_fp16=False,
145
+ num_heads=16,
146
+ # Parameters for regularization.
147
+ layer_drop=.1,
148
+ unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
149
+ ):
150
+ super().__init__()
151
+
152
+ self.in_channels = in_channels
153
+ self.model_channels = model_channels
154
+ self.out_channels = out_channels
155
+ self.dropout = dropout
156
+ self.num_heads = num_heads
157
+ self.unconditioned_percentage = unconditioned_percentage
158
+ self.enable_fp16 = use_fp16
159
+ self.layer_drop = layer_drop
160
+
161
+ self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
162
+ self.time_embed = nn.Sequential(
163
+ nn.Linear(model_channels, model_channels),
164
+ nn.SiLU(),
165
+ nn.Linear(model_channels, model_channels),
166
+ )
167
+
168
+ # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
169
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
170
+ # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
171
+ # transformer network.
172
+ self.code_embedding = nn.Embedding(in_tokens, model_channels)
173
+ self.code_converter = nn.Sequential(
174
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
175
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
176
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
177
+ )
178
+ self.code_norm = normalization(model_channels)
179
+ self.latent_conditioner = nn.Sequential(
180
+ nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
181
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
182
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
183
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
184
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
185
+ )
186
+ self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
187
+ nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
188
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
189
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
190
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
191
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
192
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
193
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
194
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
195
+ DiffusionLayer(model_channels, dropout, num_heads),
196
+ DiffusionLayer(model_channels, dropout, num_heads),
197
+ DiffusionLayer(model_channels, dropout, num_heads),
198
+ )
199
+
200
+ self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
201
+ self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
202
+
203
+ self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
204
+ [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
205
+
206
+ self.out = nn.Sequential(
207
+ normalization(model_channels),
208
+ nn.SiLU(),
209
+ nn.Conv1d(model_channels, out_channels, 3, padding=1),
210
+ )
211
+
212
+ def get_grad_norm_parameter_groups(self):
213
+ groups = {
214
+ 'minicoder': list(self.contextual_embedder.parameters()),
215
+ 'layers': list(self.layers.parameters()),
216
+ 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
217
+ 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
218
+ 'time_embed': list(self.time_embed.parameters()),
219
+ }
220
+ return groups
221
+
222
+ def get_conditioning(self, conditioning_input):
223
+ speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
224
+ conditioning_input.shape) == 3 else conditioning_input
225
+ conds = []
226
+ for j in range(speech_conditioning_input.shape[1]):
227
+ conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
228
+ conds = torch.cat(conds, dim=-1)
229
+ conds = conds.mean(dim=-1)
230
+ return conds
231
+
232
+ def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
233
+ # Shuffle aligned_latent to BxCxS format
234
+ if is_latent(aligned_conditioning):
235
+ aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
236
+
237
+ cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
238
+ if is_latent(aligned_conditioning):
239
+ code_emb = self.latent_conditioner(aligned_conditioning)
240
+ else:
241
+ code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
242
+ code_emb = self.code_converter(code_emb)
243
+ code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
244
+
245
+ unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
246
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
247
+ if self.training and self.unconditioned_percentage > 0:
248
+ unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
249
+ device=code_emb.device) < self.unconditioned_percentage
250
+ code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
251
+ code_emb)
252
+ expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
253
+
254
+ if not return_code_pred:
255
+ return expanded_code_emb
256
+ else:
257
+ mel_pred = self.mel_head(expanded_code_emb)
258
+ # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
259
+ mel_pred = mel_pred * unconditioned_batches.logical_not()
260
+ return expanded_code_emb, mel_pred
261
+
262
+ def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
263
+ """
264
+ Apply the model to an input batch.
265
+
266
+ :param x: an [N x C x ...] Tensor of inputs.
267
+ :param timesteps: a 1-D batch of timesteps.
268
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
269
+ :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
270
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
271
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
272
+ :return: an [N x C x ...] Tensor of outputs.
273
+ """
274
+ assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_latent is not None)
275
+ assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
276
+
277
+ unused_params = []
278
+ if conditioning_free:
279
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
280
+ unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
281
+ unused_params.extend(list(self.latent_conditioner.parameters()))
282
+ else:
283
+ if precomputed_aligned_embeddings is not None:
284
+ code_emb = precomputed_aligned_embeddings
285
+ else:
286
+ code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1], True)
287
+ if is_latent(aligned_conditioning):
288
+ unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
289
+ else:
290
+ unused_params.extend(list(self.latent_conditioner.parameters()))
291
+
292
+ unused_params.append(self.unconditioned_embedding)
293
+
294
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
295
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
296
+ x = self.inp_block(x)
297
+ x = torch.cat([x, code_emb], dim=1)
298
+ x = self.integrating_conv(x)
299
+ for i, lyr in enumerate(self.layers):
300
+ # Do layer drop where applicable. Do not drop first and last layers.
301
+ if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
302
+ unused_params.extend(list(lyr.parameters()))
303
+ else:
304
+ # First and last blocks will have autocast disabled for improved precision.
305
+ if not torch.backends.mps.is_available():
306
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
307
+ x = lyr(x, time_emb)
308
+ else:
309
+ x = lyr(x, time_emb)
310
+
311
+ x = x.float()
312
+ out = self.out(x)
313
+
314
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
315
+ extraneous_addition = 0
316
+ for p in unused_params:
317
+ extraneous_addition = extraneous_addition + p.mean()
318
+ out = out + extraneous_addition * 0
319
+
320
+ if return_code_pred:
321
+ return out, mel_pred
322
+ return out
323
+
324
+
325
+ if __name__ == '__main__':
326
+ clip = torch.randn(2, 100, 400)
327
+ aligned_latent = torch.randn(2,388,512)
328
+ aligned_sequence = torch.randint(0,8192,(2,100))
329
+ cond = torch.randn(2, 100, 400)
330
+ ts = torch.LongTensor([600, 600])
331
+ model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
332
+ # Test with latent aligned conditioning
333
+ #o = model(clip, ts, aligned_latent, cond)
334
+ # Test with sequence aligned conditioning
335
+ o = model(clip, ts, aligned_sequence, cond)
336
+
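Every DiffusionLayer above is conditioned on a sinusoidal timestep embedding: timestep_embedding multiplies the timestep by geometrically spaced frequencies and concatenates their cosines and sines. A quick sketch that restates the construction for an even embedding dimension, assuming only torch and math:

import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    # Half the channels are cosines, half sines, with frequencies spaced
    # geometrically from 1 down to 1/max_period (same construction as above).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0, 600]), dim=8)
print(emb.shape)  # torch.Size([2, 8])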
tortoise/models/hifigan_decoder.py ADDED
@@ -0,0 +1,302 @@
1
+ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn import functional as F
6
+ from torch.nn.utils import remove_weight_norm, weight_norm
7
+
8
+ LRELU_SLOPE = 0.1
9
+
10
+
11
+ def get_padding(k, d):
12
+ return int((k * d - d) / 2)
13
+
14
+
15
+ class ResBlock1(torch.nn.Module):
16
+ """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
17
+
18
+ Network::
19
+
20
+ x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
21
+ |--------------------------------------------------------------------------------------------------|
22
+
23
+
24
+ Args:
25
+ channels (int): number of hidden channels for the convolutional layers.
26
+ kernel_size (int): size of the convolution filter in each layer.
27
+ dilations (list): list of dilation value for each conv layer in a block.
28
+ """
29
+
30
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
31
+ super().__init__()
32
+ self.convs1 = nn.ModuleList(
33
+ [
34
+ weight_norm(
35
+ Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ 1,
40
+ dilation=dilation[0],
41
+ padding=get_padding(kernel_size, dilation[0]),
42
+ )
43
+ ),
44
+ weight_norm(
45
+ Conv1d(
46
+ channels,
47
+ channels,
48
+ kernel_size,
49
+ 1,
50
+ dilation=dilation[1],
51
+ padding=get_padding(kernel_size, dilation[1]),
52
+ )
53
+ ),
54
+ weight_norm(
55
+ Conv1d(
56
+ channels,
57
+ channels,
58
+ kernel_size,
59
+ 1,
60
+ dilation=dilation[2],
61
+ padding=get_padding(kernel_size, dilation[2]),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+
67
+ self.convs2 = nn.ModuleList(
68
+ [
69
+ weight_norm(
70
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
71
+ ),
72
+ weight_norm(
73
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
74
+ ),
75
+ weight_norm(
76
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
77
+ ),
78
+ ]
79
+ )
80
+
81
+ def forward(self, x):
82
+ """
83
+ Args:
84
+ x (Tensor): input tensor.
85
+ Returns:
86
+ Tensor: output tensor.
87
+ Shapes:
88
+ x: [B, C, T]
89
+ """
90
+ for c1, c2 in zip(self.convs1, self.convs2):
91
+ xt = F.leaky_relu(x, LRELU_SLOPE)
92
+ xt = c1(xt)
93
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
94
+ xt = c2(xt)
95
+ x = xt + x
96
+ return x
97
+
98
+ def remove_weight_norm(self):
99
+ for l in self.convs1:
100
+ remove_weight_norm(l)
101
+ for l in self.convs2:
102
+ remove_weight_norm(l)
103
+
104
+
105
+ class ResBlock2(torch.nn.Module):
106
+ """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
107
+
108
+ Network::
109
+
110
+ x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
111
+ |---------------------------------------------------|
112
+
113
+
114
+ Args:
115
+ channels (int): number of hidden channels for the convolutional layers.
116
+ kernel_size (int): size of the convolution filter in each layer.
117
+ dilations (list): list of dilation value for each conv layer in a block.
118
+ """
119
+
120
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
121
+ super().__init__()
122
+ self.convs = nn.ModuleList(
123
+ [
124
+ weight_norm(
125
+ Conv1d(
126
+ channels,
127
+ channels,
128
+ kernel_size,
129
+ 1,
130
+ dilation=dilation[0],
131
+ padding=get_padding(kernel_size, dilation[0]),
132
+ )
133
+ ),
134
+ weight_norm(
135
+ Conv1d(
136
+ channels,
137
+ channels,
138
+ kernel_size,
139
+ 1,
140
+ dilation=dilation[1],
141
+ padding=get_padding(kernel_size, dilation[1]),
142
+ )
143
+ ),
144
+ ]
145
+ )
146
+
147
+ def forward(self, x):
148
+ for c in self.convs:
149
+ xt = F.leaky_relu(x, LRELU_SLOPE)
150
+ xt = c(xt)
151
+ x = xt + x
152
+ return x
153
+
154
+ def remove_weight_norm(self):
155
+ for l in self.convs:
156
+ remove_weight_norm(l)
157
+
158
+
159
+ class HifiganGenerator(torch.nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ resblock_type,
165
+ resblock_dilation_sizes,
166
+ resblock_kernel_sizes,
167
+ upsample_kernel_sizes,
168
+ upsample_initial_channel,
169
+ upsample_factors,
170
+ inference_padding=5,
171
+ cond_channels=0,
172
+ conv_pre_weight_norm=True,
173
+ conv_post_weight_norm=True,
174
+ conv_post_bias=True,
175
+ ):
176
+ r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
177
+
178
+ Network:
179
+ x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
180
+ .. -> zI ---|
181
+ resblockN_kNx1 -> zN ---'
182
+
183
+ Args:
184
+ in_channels (int): number of input tensor channels.
185
+ out_channels (int): number of output tensor channels.
186
+ resblock_type (str): type of the `ResBlock`. '1' or '2'.
187
+ resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
188
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
189
+ upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
190
+ upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
191
+ for each consecutive upsampling layer.
192
+ upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
193
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
194
+ """
195
+ super().__init__()
196
+ self.inference_padding = inference_padding
197
+ self.num_kernels = len(resblock_kernel_sizes)
198
+ self.num_upsamples = len(upsample_factors)
199
+ # initial upsampling layers
200
+ self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
201
+ resblock = ResBlock1 if resblock_type == "1" else ResBlock2
202
+ # upsampling layers
203
+ self.ups = nn.ModuleList()
204
+ for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
205
+ self.ups.append(
206
+ weight_norm(
207
+ ConvTranspose1d(
208
+ upsample_initial_channel // (2**i),
209
+ upsample_initial_channel // (2 ** (i + 1)),
210
+ k,
211
+ u,
212
+ padding=(k - u) // 2,
213
+ )
214
+ )
215
+ )
216
+ # MRF blocks
217
+ self.resblocks = nn.ModuleList()
218
+ for i in range(len(self.ups)):
219
+ ch = upsample_initial_channel // (2 ** (i + 1))
220
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
221
+ self.resblocks.append(resblock(ch, k, d))
222
+ # post convolution layer
223
+ self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
224
+ if cond_channels > 0:
225
+ self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
226
+
227
+ if not conv_pre_weight_norm:
228
+ remove_weight_norm(self.conv_pre)
229
+
230
+ if not conv_post_weight_norm:
231
+ remove_weight_norm(self.conv_post)
232
+
233
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
234
+ if torch.backends.mps.is_available():
235
+ self.device = torch.device('mps')
236
+ def forward(self, x, g=None):
237
+ """
238
+ Args:
239
+ x (Tensor): feature input tensor.
240
+ g (Tensor): global conditioning input tensor.
241
+
242
+ Returns:
243
+ Tensor: output waveform.
244
+
245
+ Shapes:
246
+ x: [B, C, T]
247
+ Tensor: [B, 1, T]
248
+ """
249
+ o = self.conv_pre(x)
250
+ if hasattr(self, "cond_layer"):
251
+ o = o + self.cond_layer(g)
252
+ for i in range(self.num_upsamples):
253
+ o = F.leaky_relu(o, LRELU_SLOPE)
254
+ o = self.ups[i](o)
255
+ z_sum = None
256
+ for j in range(self.num_kernels):
257
+ if z_sum is None:
258
+ z_sum = self.resblocks[i * self.num_kernels + j](o)
259
+ else:
260
+ z_sum += self.resblocks[i * self.num_kernels + j](o)
261
+ o = z_sum / self.num_kernels
262
+ o = F.leaky_relu(o)
263
+ o = self.conv_post(o)
264
+ o = torch.tanh(o)
265
+ return o
266
+
267
+ @torch.no_grad()
268
+ def inference(self, c, g=None):
269
+ """
270
+ Args:
271
+ x (Tensor): conditioning input tensor.
272
+
273
+ Returns:
274
+ Tensor: output waveform.
275
+
276
+ Shapes:
277
+ x: [B, C, T]
278
+ Tensor: [B, 1, T]
279
+ """
280
+ # c = c.to(self.conv_pre.weight.device)
281
+ # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
282
+ up_1 = torch.nn.functional.interpolate(
283
+ c.transpose(1,2),
284
+ scale_factor=[1024 / 256],
285
+ mode="linear",
286
+ )
287
+ up_2 = torch.nn.functional.interpolate(
288
+ up_1,
289
+ scale_factor=[24000 / 22050],
290
+ mode="linear",
291
+ )
292
+ g = g.unsqueeze(0)
293
+ return self.forward(up_2.to(self.device), g.transpose(1,2))
294
+
295
+ def remove_weight_norm(self):
296
+ print("Removing weight norm...")
297
+ for l in self.ups:
298
+ remove_weight_norm(l)
299
+ for l in self.resblocks:
300
+ l.remove_weight_norm()
301
+ remove_weight_norm(self.conv_pre)
302
+ remove_weight_norm(self.conv_post)
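
A minimal shape-check sketch of the generator above. The hyperparameters follow the standard HiFi-GAN V1 recipe and the 80-bin mel input width is an assumption for illustration; they are not necessarily the values used by the shipped Tortoise checkpoint.

import torch
from tortoise.models.hifigan_decoder import HifiganGenerator

# illustrative HiFi-GAN V1-style configuration (assumed, not the checkpoint's)
gen = HifiganGenerator(
    in_channels=80,                # e.g. 80-bin mel-style features (assumption)
    out_channels=1,
    resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],
)
mel = torch.randn(1, 80, 100)      # [B, C, T]
wav = gen(mel)                     # [B, 1, T * 8*8*2*2] = [1, 1, 25600]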
tortoise/models/random_latent_generator.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
9
+ if bias is not None:
10
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
11
+ return (
12
+ F.leaky_relu(
13
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
14
+ )
15
+ * scale
16
+ )
17
+ else:
18
+ return F.leaky_relu(input, negative_slope=0.2) * scale
19
+
20
+
21
+ class EqualLinear(nn.Module):
22
+ def __init__(
23
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
24
+ ):
25
+ super().__init__()
26
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
27
+ if bias:
28
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
29
+ else:
30
+ self.bias = None
31
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
32
+ self.lr_mul = lr_mul
33
+
34
+ def forward(self, input):
35
+ out = F.linear(input, self.weight * self.scale)
36
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
37
+ return out
38
+
39
+
40
+ class RandomLatentConverter(nn.Module):
41
+ def __init__(self, channels):
42
+ super().__init__()
43
+ self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
44
+ nn.Linear(channels, channels))
45
+ self.channels = channels
46
+
47
+ def forward(self, ref):
48
+ r = torch.randn(ref.shape[0], self.channels, device=ref.device)
49
+ y = self.layers(r)
50
+ return y
51
+
52
+
53
+ if __name__ == '__main__':
54
+ model = RandomLatentConverter(512)
55
+ model(torch.randn(5,512))
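
A short usage sketch: the converter maps Gaussian noise to a latent of the same width, and the reference tensor only supplies the batch size and device (its values are ignored). The channel width below is an assumption for illustration.

import torch
from tortoise.models.random_latent_generator import RandomLatentConverter

rlg = RandomLatentConverter(1024)  # width chosen for illustration only
ref = torch.zeros(2, 1024)         # any [B, C] tensor; only shape[0] and device are read
latents = rlg(ref)                 # [2, 1024] random "voice" latents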
tortoise/models/stream_generator.py ADDED
@@ -0,0 +1,1057 @@
1
+ # Adapted from: https://github.com/LowinLi/transformers-stream-generator
2
+
3
+ from transformers import (
4
+ GenerationConfig,
5
+ GenerationMixin,
6
+ LogitsProcessorList,
7
+ StoppingCriteriaList,
8
+ DisjunctiveConstraint,
9
+ BeamSearchScorer,
10
+ PhrasalConstraint,
11
+ ConstrainedBeamSearchScorer,
12
+ PreTrainedModel,
13
+ )
14
+ import numpy as np
15
+ import random
16
+ import warnings
17
+ import inspect
18
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger
19
+ import torch
20
+ from typing import Callable, List, Optional, Union
21
+ from torch import nn
22
+ import torch.distributed as dist
23
+ import copy
24
+
25
+
26
+ def setup_seed(seed):
27
+ if seed == -1:
28
+ return
29
+ torch.manual_seed(seed)
30
+ if torch.cuda.is_available():
31
+ torch.cuda.manual_seed_all(seed)
32
+ np.random.seed(seed)
33
+ random.seed(seed)
34
+ torch.backends.cudnn.deterministic = True
35
+
36
+
37
+ class StreamGenerationConfig(GenerationConfig):
38
+ def __init__(self, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.do_stream = kwargs.pop("do_stream", False)
41
+
42
+
43
+ class NewGenerationMixin(GenerationMixin):
44
+ @torch.no_grad()
45
+ def generate(
46
+ self,
47
+ inputs: Optional[torch.Tensor] = None,
48
+ generation_config: Optional[StreamGenerationConfig] = None,
49
+ logits_processor: Optional[LogitsProcessorList] = None,
50
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
51
+ prefix_allowed_tokens_fn: Optional[
52
+ Callable[[int, torch.Tensor], List[int]]
53
+ ] = None,
54
+ synced_gpus: Optional[bool] = False,
55
+ seed=0,
56
+ **kwargs,
57
+ ) -> Union[GenerateOutput, torch.LongTensor]:
58
+ r"""
59
+
60
+ Generates sequences of token ids for models with a language modeling head.
61
+
62
+ <Tip warning={true}>
63
+
64
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
65
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
66
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
67
+
68
+ For an overview of generation strategies and code examples, check out the [following
69
+ guide](./generation_strategies).
70
+
71
+ </Tip>
72
+
73
+ Parameters:
74
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
75
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
76
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
77
+ should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
78
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
79
+ generation_config (`~generation.GenerationConfig`, *optional*):
80
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
81
+ passed to generate matching the attributes of `generation_config` will override them. If
82
+ `generation_config` is not provided, the default will be used, which had the following loading
83
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
84
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
85
+ default values, whose documentation should be checked to parameterize generation.
86
+ logits_processor (`LogitsProcessorList`, *optional*):
87
+ Custom logits processors that complement the default logits processors built from arguments and
88
+ generation config. If a logit processor is passed that is already created with the arguments or a
89
+ generation config an error is thrown. This feature is intended for advanced users.
90
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
91
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
92
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
93
+ generation config an error is thrown. This feature is intended for advanced users.
94
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
95
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
96
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
97
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
98
+ on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
99
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
100
+ Retrieval](https://arxiv.org/abs/2010.00904).
101
+ synced_gpus (`bool`, *optional*, defaults to `False`):
102
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
103
+ kwargs:
104
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
105
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
106
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
107
+
108
+ Return:
109
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
110
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
111
+
112
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
113
+ [`~utils.ModelOutput`] types are:
114
+
115
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
116
+ - [`~generation.SampleDecoderOnlyOutput`],
117
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
118
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
119
+
120
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
121
+ [`~utils.ModelOutput`] types are:
122
+
123
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
124
+ - [`~generation.SampleEncoderDecoderOutput`],
125
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
126
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
127
+ """
128
+ setup_seed(seed)
129
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
130
+ self._validate_model_class()
131
+
132
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
133
+ if generation_config is None:
134
+ # legacy: users may modify the model configuration to control generation -- update the generation config
135
+ # model attribute accordingly, if it was created from the model config
136
+ if self.generation_config._from_model_config:
137
+ new_generation_config = StreamGenerationConfig.from_model_config(
138
+ self.config
139
+ )
140
+ if new_generation_config != self.generation_config:
141
+ warnings.warn(
142
+ "You have modified the pretrained model configuration to control generation. This is a"
143
+ " deprecated strategy to control generation and will be removed soon, in a future version."
144
+ " Please use a generation configuration file (see"
145
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
146
+ )
147
+ self.generation_config = new_generation_config
148
+ generation_config = self.generation_config
149
+
150
+ generation_config = copy.deepcopy(generation_config)
151
+ model_kwargs = generation_config.update(
152
+ **kwargs
153
+ ) # All unused kwargs must be model kwargs
154
+ # self._validate_model_kwargs(model_kwargs.copy())
155
+
156
+ # 2. Set generation parameters if not already defined
157
+ logits_processor = (
158
+ logits_processor if logits_processor is not None else LogitsProcessorList()
159
+ )
160
+ stopping_criteria = (
161
+ stopping_criteria
162
+ if stopping_criteria is not None
163
+ else StoppingCriteriaList()
164
+ )
165
+
166
+ if (
167
+ generation_config.pad_token_id is None
168
+ and generation_config.eos_token_id is not None
169
+ ):
170
+ if model_kwargs.get("attention_mask", None) is None:
171
+ logger.warning(
172
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
173
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
174
+ )
175
+ eos_token_id = generation_config.eos_token_id
176
+ if isinstance(eos_token_id, list):
177
+ eos_token_id = eos_token_id[0]
178
+ logger.warning(
179
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
180
+ )
181
+ generation_config.pad_token_id = eos_token_id
182
+
183
+ # 3. Define model inputs
184
+ # inputs_tensor has to be defined
185
+ # model_input_name is defined if model-specific keyword input is passed
186
+ # otherwise model_input_name is None
187
+ # all model-specific keyword inputs are removed from `model_kwargs`
188
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
189
+ inputs, generation_config.bos_token_id, model_kwargs
190
+ )
191
+ batch_size = inputs_tensor.shape[0]
192
+
193
+ # 4. Define other model kwargs
194
+ model_kwargs["output_attentions"] = generation_config.output_attentions
195
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
196
+ model_kwargs["use_cache"] = generation_config.use_cache
197
+
198
+ accepts_attention_mask = "attention_mask" in set(
199
+ inspect.signature(self.forward).parameters.keys()
200
+ )
201
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
202
+
203
+ if (
204
+ model_kwargs.get("attention_mask", None) is None
205
+ and requires_attention_mask
206
+ and accepts_attention_mask
207
+ ):
208
+ model_kwargs[
209
+ "attention_mask"
210
+ ] = self._prepare_attention_mask_for_generation(
211
+ inputs_tensor,
212
+ generation_config.pad_token_id,
213
+ generation_config.eos_token_id,
214
+ )
215
+
216
+ # decoder-only models should use left-padding for generation
217
+ if not self.config.is_encoder_decoder:
218
+ if (
219
+ generation_config.pad_token_id is not None
220
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
221
+ > 0
222
+ ):
223
+ logger.warning(
224
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
225
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
226
+ )
227
+
228
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
229
+ # if model is encoder decoder encoder_outputs are created
230
+ # and added to `model_kwargs`
231
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
232
+ inputs_tensor, model_kwargs, model_input_name
233
+ )
234
+
235
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
236
+ if self.config.is_encoder_decoder:
237
+ input_ids = self._prepare_decoder_input_ids_for_generation(
238
+ batch_size,
239
+ decoder_start_token_id=generation_config.decoder_start_token_id,
240
+ bos_token_id=generation_config.bos_token_id,
241
+ model_kwargs=model_kwargs,
242
+ device=inputs_tensor.device,
243
+ )
244
+ else:
245
+ # if decoder-only then inputs_tensor has to be `input_ids`
246
+ input_ids = inputs_tensor
247
+
248
+ # 6. Prepare `max_length` depending on other stopping criteria.
249
+ input_ids_seq_length = input_ids.shape[-1]
250
+ has_default_max_length = (
251
+ kwargs.get("max_length") is None
252
+ and generation_config.max_length is not None
253
+ )
254
+ if has_default_max_length and generation_config.max_new_tokens is None:
255
+ warnings.warn(
256
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
257
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
258
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
259
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
260
+ UserWarning,
261
+ )
262
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
263
+ generation_config.max_length = (
264
+ generation_config.max_new_tokens + input_ids_seq_length
265
+ )
266
+ elif (
267
+ not has_default_max_length and generation_config.max_new_tokens is not None
268
+ ):
269
+ raise ValueError(
270
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
271
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
272
+ " documentation for more information. "
273
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
274
+ )
275
+
276
+ if (
277
+ generation_config.min_length is not None
278
+ and generation_config.min_length > generation_config.max_length
279
+ ):
280
+ raise ValueError(
281
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
282
+ f" the maximum length ({generation_config.max_length})"
283
+ )
284
+ if input_ids_seq_length >= generation_config.max_length:
285
+ input_ids_string = (
286
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
287
+ )
288
+ logger.warning(
289
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
290
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
291
+ " increasing `max_new_tokens`."
292
+ )
293
+
294
+ # 7. determine generation mode
295
+ is_constraint_gen_mode = (
296
+ generation_config.constraints is not None
297
+ or generation_config.force_words_ids is not None
298
+ )
299
+
300
+ is_contrastive_search_gen_mode = (
301
+ generation_config.top_k is not None
302
+ and generation_config.top_k > 1
303
+ and generation_config.do_sample is False
304
+ and generation_config.penalty_alpha is not None
305
+ and generation_config.penalty_alpha > 0
306
+ )
307
+
308
+ is_greedy_gen_mode = (
309
+ (generation_config.num_beams == 1)
310
+ and (generation_config.num_beam_groups == 1)
311
+ and generation_config.do_sample is False
312
+ and not is_constraint_gen_mode
313
+ and not is_contrastive_search_gen_mode
314
+ )
315
+ is_sample_gen_mode = (
316
+ (generation_config.num_beams == 1)
317
+ and (generation_config.num_beam_groups == 1)
318
+ and generation_config.do_sample is True
319
+ and generation_config.do_stream is False
320
+ and not is_constraint_gen_mode
321
+ and not is_contrastive_search_gen_mode
322
+ )
323
+ is_sample_gen_stream_mode = (
324
+ (generation_config.num_beams == 1)
325
+ and (generation_config.num_beam_groups == 1)
326
+ and generation_config.do_stream is True
327
+ and not is_constraint_gen_mode
328
+ and not is_contrastive_search_gen_mode
329
+ )
330
+ is_beam_gen_mode = (
331
+ (generation_config.num_beams > 1)
332
+ and (generation_config.num_beam_groups == 1)
333
+ and generation_config.do_sample is False
334
+ and not is_constraint_gen_mode
335
+ and not is_contrastive_search_gen_mode
336
+ )
337
+ is_beam_sample_gen_mode = (
338
+ (generation_config.num_beams > 1)
339
+ and (generation_config.num_beam_groups == 1)
340
+ and generation_config.do_sample is True
341
+ and not is_constraint_gen_mode
342
+ and not is_contrastive_search_gen_mode
343
+ )
344
+ is_group_beam_gen_mode = (
345
+ (generation_config.num_beams > 1)
346
+ and (generation_config.num_beam_groups > 1)
347
+ and not is_constraint_gen_mode
348
+ and not is_contrastive_search_gen_mode
349
+ )
350
+
351
+ if generation_config.num_beam_groups > generation_config.num_beams:
352
+ raise ValueError(
353
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
354
+ )
355
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
356
+ raise ValueError(
357
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
358
+ )
359
+
360
+ if self.device.type != input_ids.device.type:
361
+ warnings.warn(
362
+ "You are calling .generate() with the `input_ids` being on a device type different"
363
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
364
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
365
+ " Please make sure that you have put `input_ids` to the"
366
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
367
+ " running `.generate()`.",
368
+ UserWarning,
369
+ )
370
+ # 8. prepare distribution pre_processing samplers
371
+ logits_processor = self._get_logits_processor(
372
+ generation_config=generation_config,
373
+ input_ids_seq_length=input_ids_seq_length,
374
+ encoder_input_ids=inputs_tensor,
375
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
376
+ logits_processor=logits_processor,
377
+ )
378
+
379
+ # 9. prepare stopping criteria
380
+ stopping_criteria = self._get_stopping_criteria(
381
+ generation_config=generation_config, stopping_criteria=stopping_criteria
382
+ )
383
+ # 10. go into different generation modes
384
+ if is_greedy_gen_mode:
385
+ if generation_config.num_return_sequences > 1:
386
+ raise ValueError(
387
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
388
+ " greedy search."
389
+ )
390
+
391
+ # 11. run greedy search
392
+ return self.greedy_search(
393
+ input_ids,
394
+ logits_processor=logits_processor,
395
+ stopping_criteria=stopping_criteria,
396
+ pad_token_id=generation_config.pad_token_id,
397
+ eos_token_id=generation_config.eos_token_id,
398
+ output_scores=generation_config.output_scores,
399
+ return_dict_in_generate=generation_config.return_dict_in_generate,
400
+ synced_gpus=synced_gpus,
401
+ **model_kwargs,
402
+ )
403
+
404
+ elif is_contrastive_search_gen_mode:
405
+ if generation_config.num_return_sequences > 1:
406
+ raise ValueError(
407
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
408
+ " contrastive search."
409
+ )
410
+
411
+ return self.contrastive_search(
412
+ input_ids,
413
+ top_k=generation_config.top_k,
414
+ penalty_alpha=generation_config.penalty_alpha,
415
+ logits_processor=logits_processor,
416
+ stopping_criteria=stopping_criteria,
417
+ pad_token_id=generation_config.pad_token_id,
418
+ eos_token_id=generation_config.eos_token_id,
419
+ output_scores=generation_config.output_scores,
420
+ return_dict_in_generate=generation_config.return_dict_in_generate,
421
+ synced_gpus=synced_gpus,
422
+ **model_kwargs,
423
+ )
424
+
425
+ elif is_sample_gen_mode:
426
+ # 11. prepare logits warper
427
+ logits_warper = self._get_logits_warper(generation_config)
428
+
429
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
430
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
431
+ input_ids=input_ids,
432
+ expand_size=generation_config.num_return_sequences,
433
+ is_encoder_decoder=self.config.is_encoder_decoder,
434
+ **model_kwargs,
435
+ )
436
+
437
+ # 13. run sample
438
+ return self.sample(
439
+ input_ids,
440
+ logits_processor=logits_processor,
441
+ logits_warper=logits_warper,
442
+ stopping_criteria=stopping_criteria,
443
+ pad_token_id=generation_config.pad_token_id,
444
+ eos_token_id=generation_config.eos_token_id,
445
+ output_scores=generation_config.output_scores,
446
+ return_dict_in_generate=generation_config.return_dict_in_generate,
447
+ synced_gpus=synced_gpus,
448
+ **model_kwargs,
449
+ )
450
+ elif is_sample_gen_stream_mode:
451
+ # 11. prepare logits warper
452
+ logits_warper = self._get_logits_warper(generation_config)
453
+
454
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
455
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
456
+ input_ids=input_ids,
457
+ expand_size=generation_config.num_return_sequences,
458
+ is_encoder_decoder=self.config.is_encoder_decoder,
459
+ **model_kwargs,
460
+ )
461
+
462
+ # 13. run sample
463
+ return self.sample_stream(
464
+ input_ids,
465
+ logits_processor=logits_processor,
466
+ logits_warper=logits_warper,
467
+ stopping_criteria=stopping_criteria,
468
+ pad_token_id=generation_config.pad_token_id,
469
+ eos_token_id=generation_config.eos_token_id,
470
+ output_scores=generation_config.output_scores,
471
+ return_dict_in_generate=generation_config.return_dict_in_generate,
472
+ synced_gpus=synced_gpus,
473
+ **model_kwargs,
474
+ )
475
+ elif is_beam_gen_mode:
476
+ if generation_config.num_return_sequences > generation_config.num_beams:
477
+ raise ValueError(
478
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
479
+ )
480
+
481
+ if stopping_criteria.max_length is None:
482
+ raise ValueError(
483
+ "`max_length` needs to be a stopping_criteria for now."
484
+ )
485
+
486
+ # 11. prepare beam search scorer
487
+ beam_scorer = BeamSearchScorer(
488
+ batch_size=batch_size,
489
+ num_beams=generation_config.num_beams,
490
+ device=inputs_tensor.device,
491
+ length_penalty=generation_config.length_penalty,
492
+ do_early_stopping=generation_config.early_stopping,
493
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
494
+ )
495
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
496
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
497
+ input_ids=input_ids,
498
+ expand_size=generation_config.num_beams,
499
+ is_encoder_decoder=self.config.is_encoder_decoder,
500
+ **model_kwargs,
501
+ )
502
+ # 13. run beam search
503
+ return self.beam_search(
504
+ input_ids,
505
+ beam_scorer,
506
+ logits_processor=logits_processor,
507
+ stopping_criteria=stopping_criteria,
508
+ pad_token_id=generation_config.pad_token_id,
509
+ eos_token_id=generation_config.eos_token_id,
510
+ output_scores=generation_config.output_scores,
511
+ return_dict_in_generate=generation_config.return_dict_in_generate,
512
+ synced_gpus=synced_gpus,
513
+ **model_kwargs,
514
+ )
515
+
516
+ elif is_beam_sample_gen_mode:
517
+ # 11. prepare logits warper
518
+ logits_warper = self._get_logits_warper(generation_config)
519
+
520
+ if stopping_criteria.max_length is None:
521
+ raise ValueError(
522
+ "`max_length` needs to be a stopping_criteria for now."
523
+ )
524
+ # 12. prepare beam search scorer
525
+ beam_scorer = BeamSearchScorer(
526
+ batch_size=batch_size * generation_config.num_return_sequences,
527
+ num_beams=generation_config.num_beams,
528
+ device=inputs_tensor.device,
529
+ length_penalty=generation_config.length_penalty,
530
+ do_early_stopping=generation_config.early_stopping,
531
+ )
532
+
533
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
534
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
535
+ input_ids=input_ids,
536
+ expand_size=generation_config.num_beams
537
+ * generation_config.num_return_sequences,
538
+ is_encoder_decoder=self.config.is_encoder_decoder,
539
+ **model_kwargs,
540
+ )
541
+
542
+ # 14. run beam sample
543
+ return self.beam_sample(
544
+ input_ids,
545
+ beam_scorer,
546
+ logits_processor=logits_processor,
547
+ logits_warper=logits_warper,
548
+ stopping_criteria=stopping_criteria,
549
+ pad_token_id=generation_config.pad_token_id,
550
+ eos_token_id=generation_config.eos_token_id,
551
+ output_scores=generation_config.output_scores,
552
+ return_dict_in_generate=generation_config.return_dict_in_generate,
553
+ synced_gpus=synced_gpus,
554
+ **model_kwargs,
555
+ )
556
+
557
+ elif is_group_beam_gen_mode:
558
+ if generation_config.num_return_sequences > generation_config.num_beams:
559
+ raise ValueError(
560
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
561
+ )
562
+
563
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
564
+ raise ValueError(
565
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
566
+ )
567
+
568
+ if stopping_criteria.max_length is None:
569
+ raise ValueError(
570
+ "`max_length` needs to be a stopping_criteria for now."
571
+ )
572
+
573
+ has_default_typical_p = (
574
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
575
+ )
576
+ if not has_default_typical_p:
577
+ raise ValueError(
578
+ "Decoder argument `typical_p` is not supported with beam groups."
579
+ )
580
+
581
+ # 11. prepare beam search scorer
582
+ beam_scorer = BeamSearchScorer(
583
+ batch_size=batch_size,
584
+ num_beams=generation_config.num_beams,
585
+ max_length=stopping_criteria.max_length,
586
+ device=inputs_tensor.device,
587
+ length_penalty=generation_config.length_penalty,
588
+ do_early_stopping=generation_config.early_stopping,
589
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
590
+ num_beam_groups=generation_config.num_beam_groups,
591
+ )
592
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
593
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
594
+ input_ids=input_ids,
595
+ expand_size=generation_config.num_beams,
596
+ is_encoder_decoder=self.config.is_encoder_decoder,
597
+ **model_kwargs,
598
+ )
599
+ # 13. run beam search
600
+ return self.group_beam_search(
601
+ input_ids,
602
+ beam_scorer,
603
+ logits_processor=logits_processor,
604
+ stopping_criteria=stopping_criteria,
605
+ pad_token_id=generation_config.pad_token_id,
606
+ eos_token_id=generation_config.eos_token_id,
607
+ output_scores=generation_config.output_scores,
608
+ return_dict_in_generate=generation_config.return_dict_in_generate,
609
+ synced_gpus=synced_gpus,
610
+ **model_kwargs,
611
+ )
612
+
613
+ elif is_constraint_gen_mode:
614
+ if generation_config.num_return_sequences > generation_config.num_beams:
615
+ raise ValueError(
616
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
617
+ )
618
+
619
+ if stopping_criteria.max_length is None:
620
+ raise ValueError(
621
+ "`max_length` needs to be a stopping_criteria for now."
622
+ )
623
+
624
+ if generation_config.num_beams <= 1:
625
+ raise ValueError(
626
+ "`num_beams` needs to be greater than 1 for constrained generation."
627
+ )
628
+
629
+ if generation_config.do_sample:
630
+ raise ValueError(
631
+ "`do_sample` needs to be false for constrained generation."
632
+ )
633
+
634
+ if (
635
+ generation_config.num_beam_groups is not None
636
+ and generation_config.num_beam_groups > 1
637
+ ):
638
+ raise ValueError(
639
+ "`num_beam_groups` not supported yet for constrained generation."
640
+ )
641
+
642
+ final_constraints = []
643
+ if generation_config.constraints is not None:
644
+ final_constraints = generation_config.constraints
645
+
646
+ if generation_config.force_words_ids is not None:
647
+
648
+ def typeerror():
649
+ raise ValueError(
650
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
651
+ f"of positive integers, but is {generation_config.force_words_ids}."
652
+ )
653
+
654
+ if (
655
+ not isinstance(generation_config.force_words_ids, list)
656
+ or len(generation_config.force_words_ids) == 0
657
+ ):
658
+ typeerror()
659
+
660
+ for word_ids in generation_config.force_words_ids:
661
+ if isinstance(word_ids[0], list):
662
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
663
+ typeerror()
664
+ if any(
665
+ not isinstance(token_ids, list) for token_ids in word_ids
666
+ ):
667
+ typeerror()
668
+ if any(
669
+ any(
670
+ (not isinstance(token_id, int) or token_id < 0)
671
+ for token_id in token_ids
672
+ )
673
+ for token_ids in word_ids
674
+ ):
675
+ typeerror()
676
+
677
+ constraint = DisjunctiveConstraint(word_ids)
678
+ else:
679
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
680
+ typeerror()
681
+ if any(
682
+ (not isinstance(token_id, int) or token_id < 0)
683
+ for token_id in word_ids
684
+ ):
685
+ typeerror()
686
+
687
+ constraint = PhrasalConstraint(word_ids)
688
+ final_constraints.append(constraint)
689
+
690
+ # 11. prepare beam search scorer
691
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
692
+ constraints=final_constraints,
693
+ batch_size=batch_size,
694
+ num_beams=generation_config.num_beams,
695
+ device=inputs_tensor.device,
696
+ length_penalty=generation_config.length_penalty,
697
+ do_early_stopping=generation_config.early_stopping,
698
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
699
+ )
700
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
701
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
702
+ input_ids=input_ids,
703
+ expand_size=generation_config.num_beams,
704
+ is_encoder_decoder=self.config.is_encoder_decoder,
705
+ **model_kwargs,
706
+ )
707
+ # 13. run beam search
708
+ return self.constrained_beam_search(
709
+ input_ids,
710
+ constrained_beam_scorer=constrained_beam_scorer,
711
+ logits_processor=logits_processor,
712
+ stopping_criteria=stopping_criteria,
713
+ pad_token_id=generation_config.pad_token_id,
714
+ eos_token_id=generation_config.eos_token_id,
715
+ output_scores=generation_config.output_scores,
716
+ return_dict_in_generate=generation_config.return_dict_in_generate,
717
+ synced_gpus=synced_gpus,
718
+ **model_kwargs,
719
+ )
720
+
721
+ @torch.no_grad()
722
+ def sample_stream(
723
+ self,
724
+ input_ids: torch.LongTensor,
725
+ logits_processor: Optional[LogitsProcessorList] = None,
726
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
727
+ logits_warper: Optional[LogitsProcessorList] = None,
728
+ max_length: Optional[int] = None,
729
+ pad_token_id: Optional[int] = None,
730
+ eos_token_id: Optional[Union[int, List[int]]] = None,
731
+ output_attentions: Optional[bool] = None,
732
+ output_hidden_states: Optional[bool] = None,
733
+ output_scores: Optional[bool] = None,
734
+ return_dict_in_generate: Optional[bool] = None,
735
+ synced_gpus: Optional[bool] = False,
736
+ **model_kwargs,
737
+ ) -> Union[SampleOutput, torch.LongTensor]:
738
+ r"""
739
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
740
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
741
+
742
+ <Tip warning={true}>
743
+
744
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
745
+ For an overview of generation strategies and code examples, check the [following
746
+ guide](./generation_strategies).
747
+
748
+ </Tip>
749
+
750
+ Parameters:
751
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
752
+ The sequence used as a prompt for the generation.
753
+ logits_processor (`LogitsProcessorList`, *optional*):
754
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
755
+ used to modify the prediction scores of the language modeling head applied at each generation step.
756
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
757
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
758
+ used to tell if the generation loop should stop.
759
+ logits_warper (`LogitsProcessorList`, *optional*):
760
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
761
+ to warp the prediction score distribution of the language modeling head applied before multinomial
762
+ sampling at each generation step.
763
+ max_length (`int`, *optional*, defaults to 20):
764
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
765
+ tokens. The maximum length of the sequence to be generated.
766
+ pad_token_id (`int`, *optional*):
767
+ The id of the *padding* token.
768
+ eos_token_id (`int`, *optional*):
769
+ The id of the *end-of-sequence* token.
770
+ output_attentions (`bool`, *optional*, defaults to `False`):
771
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
772
+ returned tensors for more details.
773
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
774
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
775
+ for more details.
776
+ output_scores (`bool`, *optional*, defaults to `False`):
777
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
778
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
779
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
780
+ synced_gpus (`bool`, *optional*, defaults to `False`):
781
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
782
+ model_kwargs:
783
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
784
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
785
+
786
+ Return:
787
+ [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
788
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
789
+ [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
790
+ `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
791
+ `model.config.is_encoder_decoder=True`.
792
+
793
+ Examples:
794
+
795
+ ```python
796
+ >>> from transformers import (
797
+ ... AutoTokenizer,
798
+ ... AutoModelForCausalLM,
799
+ ... LogitsProcessorList,
800
+ ... MinLengthLogitsProcessor,
801
+ ... TopKLogitsWarper,
802
+ ... TemperatureLogitsWarper,
803
+ ... StoppingCriteriaList,
804
+ ... MaxLengthCriteria,
805
+ ... )
806
+ >>> import torch
807
+
808
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
809
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
810
+
811
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
812
+ >>> model.config.pad_token_id = model.config.eos_token_id
813
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
814
+
815
+ >>> input_prompt = "Today is a beautiful day, and"
816
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
817
+
818
+ >>> # instantiate logits processors
819
+ >>> logits_processor = LogitsProcessorList(
820
+ ... [
821
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
822
+ ... ]
823
+ ... )
824
+ >>> # instantiate logits processors
825
+ >>> logits_warper = LogitsProcessorList(
826
+ ... [
827
+ ... TopKLogitsWarper(50),
828
+ ... TemperatureLogitsWarper(0.7),
829
+ ... ]
830
+ ... )
831
+
832
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
833
+
834
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
835
+ >>> outputs = model.sample(
836
+ ... input_ids,
837
+ ... logits_processor=logits_processor,
838
+ ... logits_warper=logits_warper,
839
+ ... stopping_criteria=stopping_criteria,
840
+ ... )
841
+
842
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
843
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
844
+ ```"""
845
+ # init values
846
+ logits_processor = (
847
+ logits_processor if logits_processor is not None else LogitsProcessorList()
848
+ )
849
+ stopping_criteria = (
850
+ stopping_criteria
851
+ if stopping_criteria is not None
852
+ else StoppingCriteriaList()
853
+ )
854
+ if max_length is not None:
855
+ warnings.warn(
856
+ "`max_length` is deprecated in this function, use"
857
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
858
+ UserWarning,
859
+ )
860
+ stopping_criteria = validate_stopping_criteria(
861
+ stopping_criteria, max_length
862
+ )
863
+ logits_warper = (
864
+ logits_warper if logits_warper is not None else LogitsProcessorList()
865
+ )
866
+ pad_token_id = (
867
+ pad_token_id
868
+ if pad_token_id is not None
869
+ else self.generation_config.pad_token_id
870
+ )
871
+ eos_token_id = (
872
+ eos_token_id
873
+ if eos_token_id is not None
874
+ else self.generation_config.eos_token_id
875
+ )
876
+ if isinstance(eos_token_id, int):
877
+ eos_token_id = [eos_token_id]
878
+ output_scores = (
879
+ output_scores
880
+ if output_scores is not None
881
+ else self.generation_config.output_scores
882
+ )
883
+ output_attentions = (
884
+ output_attentions
885
+ if output_attentions is not None
886
+ else self.generation_config.output_attentions
887
+ )
888
+ output_hidden_states = (
889
+ output_hidden_states
890
+ if output_hidden_states is not None
891
+ else self.generation_config.output_hidden_states
892
+ )
893
+ return_dict_in_generate = (
894
+ return_dict_in_generate
895
+ if return_dict_in_generate is not None
896
+ else self.generation_config.return_dict_in_generate
897
+ )
898
+
899
+ # init attention / hidden states / scores tuples
900
+ scores = () if (return_dict_in_generate and output_scores) else None
901
+ decoder_attentions = (
902
+ () if (return_dict_in_generate and output_attentions) else None
903
+ )
904
+ cross_attentions = (
905
+ () if (return_dict_in_generate and output_attentions) else None
906
+ )
907
+ decoder_hidden_states = (
908
+ () if (return_dict_in_generate and output_hidden_states) else None
909
+ )
910
+
911
+ # keep track of which sequences are already finished
912
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
913
+
914
+ this_peer_finished = False # used by synced_gpus only
915
+ # auto-regressive generation
916
+ while True:
917
+ if synced_gpus:
918
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
919
+ # The following logic allows an early break if all peers finished generating their sequence
920
+ this_peer_finished_flag = torch.tensor(
921
+ 0.0 if this_peer_finished else 1.0
922
+ ).to(input_ids.device)
923
+ # send 0.0 if we finished, 1.0 otherwise
924
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
925
+ # did all peers finish? the reduced sum will be 0.0 then
926
+ if this_peer_finished_flag.item() == 0.0:
927
+ break
928
+
929
+ # prepare model inputs
930
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
931
+
932
+ # forward pass to get next token
933
+ outputs = self(
934
+ **model_inputs,
935
+ return_dict=True,
936
+ output_attentions=output_attentions,
937
+ output_hidden_states=output_hidden_states,
938
+ )
939
+
940
+ if synced_gpus and this_peer_finished:
941
+ continue # don't waste resources running the code we don't need
942
+
943
+ next_token_logits = outputs.logits[:, -1, :]
944
+
945
+ # pre-process distribution
946
+ next_token_scores = logits_processor(input_ids, next_token_logits)
947
+ next_token_scores = logits_warper(input_ids, next_token_scores)
948
+
949
+ # Store scores, attentions and hidden_states when required
950
+ if return_dict_in_generate:
951
+ if output_scores:
952
+ scores += (next_token_scores,)
953
+ if output_attentions:
954
+ decoder_attentions += (
955
+ (outputs.decoder_attentions,)
956
+ if self.config.is_encoder_decoder
957
+ else (outputs.attentions,)
958
+ )
959
+ if self.config.is_encoder_decoder:
960
+ cross_attentions += (outputs.cross_attentions,)
961
+
962
+ if output_hidden_states:
963
+ decoder_hidden_states += (
964
+ (outputs.decoder_hidden_states,)
965
+ if self.config.is_encoder_decoder
966
+ else (outputs.hidden_states,)
967
+ )
968
+
969
+ # sample
970
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
971
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
972
+
973
+ # finished sentences should have their next token be a padding token
974
+ if eos_token_id is not None:
975
+ if pad_token_id is None:
976
+ raise ValueError(
977
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
978
+ )
979
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
980
+ 1 - unfinished_sequences
981
+ )
982
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
983
+ # update generated ids, model inputs, and length for next step
984
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
985
+ model_kwargs = self._update_model_kwargs_for_generation(
986
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
987
+ )
988
+
989
+ # if eos_token was found in one sentence, set sentence to finished
990
+ if eos_token_id is not None:
991
+ unfinished_sequences = unfinished_sequences.mul(
992
+ (sum(next_tokens != i for i in eos_token_id)).long()
993
+ )
994
+
995
+ # stop when each sentence is finished, or if we exceed the maximum length
996
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
997
+ if not synced_gpus:
998
+ break
999
+ else:
1000
+ this_peer_finished = True
1001
+
1002
+
1003
+ def init_stream_support():
1004
+ """Overload PreTrainedModel for streaming."""
1005
+ PreTrainedModel.generate_stream = NewGenerationMixin.generate
1006
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ from transformers import PreTrainedModel
1011
+ from transformers import AutoTokenizer, AutoModelForCausalLM
1012
+
1013
+ PreTrainedModel.generate = NewGenerationMixin.generate
1014
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
1015
+ model = AutoModelForCausalLM.from_pretrained(
1016
+ "bigscience/bloom-560m", torch_dtype=torch.float16
1017
+ )
1018
+
1019
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
1020
+ model = model.to("cuda:0")
1021
+ model = model.eval()
1022
+ prompt_text = "hello? \n"
1023
+ input_ids = tokenizer(
1024
+ prompt_text, return_tensors="pt", add_special_tokens=False
1025
+ ).input_ids
1026
+ input_ids = input_ids.to("cuda:0")
1027
+
1028
+ with torch.no_grad():
1029
+ result = model.generate(
1030
+ input_ids,
1031
+ max_new_tokens=200,
1032
+ do_sample=True,
1033
+ top_k=30,
1034
+ top_p=0.85,
1035
+ temperature=0.35,
1036
+ repetition_penalty=1.2,
1037
+ early_stopping=True,
1038
+ seed=0,
1039
+ )
1040
+ print(tokenizer.decode(result, skip_special_tokens=True))
1041
+ generator = model.generate(
1042
+ input_ids,
1043
+ max_new_tokens=200,
1044
+ do_sample=True,
1045
+ top_k=30,
1046
+ top_p=0.85,
1047
+ temperature=0.35,
1048
+ repetition_penalty=1.2,
1049
+ early_stopping=True,
1050
+ seed=0,
1051
+ do_stream=True,
1052
+ )
1053
+ stream_result = ""
1054
+ for x in generator:
1055
+ chunk = tokenizer.decode(x, skip_special_tokens=True)
1056
+ stream_result += chunk
1057
+ print(stream_result)
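
Besides the `__main__` demo above, the module exposes `init_stream_support()`, which simply patches `PreTrainedModel` with `generate_stream`/`sample_stream`, and `setup_seed()`, which makes sampling reproducible. A small sketch of those two helpers, with behaviour taken directly from the code above:

import torch
from tortoise.models.stream_generator import init_stream_support, setup_seed

init_stream_support()      # PreTrainedModel now has .generate_stream / .sample_stream

setup_seed(42)
a = torch.rand(3)
setup_seed(42)
b = torch.rand(3)
assert torch.equal(a, b)   # same seed -> identical draws
setup_seed(-1)             # -1 is a no-op: leaves the global RNG state untouched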
tortoise/models/transformer.py ADDED
@@ -0,0 +1,219 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from rotary_embedding_torch import RotaryEmbedding, broadcat
7
+ from torch import nn
8
+
9
+
10
+ # helpers
11
+
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+
21
+ def cast_tuple(val, depth = 1):
22
+ if isinstance(val, list):
23
+ val = tuple(val)
24
+ return val if isinstance(val, tuple) else (val,) * depth
25
+
26
+
27
+ def max_neg_value(t):
28
+ return -torch.finfo(t.dtype).max
29
+
30
+
31
+ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
32
+ t = t / alpha
33
+ t = t - torch.amax(t, dim = dim, keepdim = True).detach()
34
+ return (t * alpha).softmax(dim = dim)
35
+
36
+
37
+ def route_args(router, args, depth):
38
+ routed_args = [(dict(), dict()) for _ in range(depth)]
39
+ matched_keys = [key for key in args.keys() if key in router]
40
+
41
+ for key in matched_keys:
42
+ val = args[key]
43
+ for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
44
+ new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
45
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
46
+ return routed_args
47
+
48
+
49
+ # classes
50
+ class SequentialSequence(nn.Module):
51
+ def __init__(self, layers, args_route = {}, layer_dropout = 0.):
52
+ super().__init__()
53
+ assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
54
+ self.layers = layers
55
+ self.args_route = args_route
56
+ self.layer_dropout = layer_dropout
57
+
58
+ def forward(self, x, **kwargs):
59
+ args = route_args(self.args_route, kwargs, len(self.layers))
60
+ layers_and_args = list(zip(self.layers, args))
61
+
62
+ for (f, g), (f_args, g_args) in layers_and_args:
63
+ x = x + f(x, **f_args)
64
+ x = x + g(x, **g_args)
65
+ return x
66
+
67
+
68
+ class DivideMax(nn.Module):
69
+ def __init__(self, dim):
70
+ super().__init__()
71
+ self.dim = dim
72
+
73
+ def forward(self, x):
74
+ maxes = x.amax(dim = self.dim, keepdim = True).detach()
75
+ return x / maxes
76
+
77
+
78
+ # https://arxiv.org/abs/2103.17239
79
+ class LayerScale(nn.Module):
80
+ def __init__(self, dim, depth, fn):
81
+ super().__init__()
82
+ if depth <= 18:
83
+ init_eps = 0.1
84
+ elif depth > 18 and depth <= 24:
85
+ init_eps = 1e-5
86
+ else:
87
+ init_eps = 1e-6
88
+
89
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
90
+ self.scale = nn.Parameter(scale)
91
+ self.fn = fn
92
+ def forward(self, x, **kwargs):
93
+ return self.fn(x, **kwargs) * self.scale
94
+
95
+ # layer norm
96
+
97
+
98
+ class PreNorm(nn.Module):
99
+ def __init__(self, dim, fn, sandwich = False):
100
+ super().__init__()
101
+ self.norm = nn.LayerNorm(dim)
102
+ self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
103
+ self.fn = fn
104
+
105
+ def forward(self, x, **kwargs):
106
+ x = self.norm(x)
107
+ x = self.fn(x, **kwargs)
108
+ return self.norm_out(x)
109
+
110
+ # feed forward
111
+
112
+
113
+ class GEGLU(nn.Module):
114
+ def forward(self, x):
115
+ x, gates = x.chunk(2, dim = -1)
116
+ return x * F.gelu(gates)
117
+
118
+
119
+ class FeedForward(nn.Module):
120
+ def __init__(self, dim, dropout = 0., mult = 4.):
121
+ super().__init__()
122
+ self.net = nn.Sequential(
123
+ nn.Linear(dim, dim * mult * 2),
124
+ GEGLU(),
125
+ nn.Dropout(dropout),
126
+ nn.Linear(dim * mult, dim)
127
+ )
128
+
129
+ def forward(self, x):
130
+ return self.net(x)
131
+
132
+ # Attention
133
+
134
+
135
+ class Attention(nn.Module):
136
+ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
137
+ super().__init__()
138
+ inner_dim = dim_head * heads
139
+ self.heads = heads
140
+ self.seq_len = seq_len
141
+ self.scale = dim_head ** -0.5
142
+
143
+ self.causal = causal
144
+
145
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
146
+ self.to_out = nn.Sequential(
147
+ nn.Linear(inner_dim, dim),
148
+ nn.Dropout(dropout)
149
+ )
150
+
151
+ def forward(self, x, mask = None):
152
+ b, n, _, h, device = *x.shape, self.heads, x.device
153
+ softmax = torch.softmax
154
+
155
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
156
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
157
+
158
+ q = q * self.scale
159
+
160
+ dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
161
+ mask_value = max_neg_value(dots)
162
+
163
+ if exists(mask):
164
+ mask = rearrange(mask, 'b j -> b () () j')
165
+ dots.masked_fill_(~mask, mask_value)
166
+ del mask
167
+
168
+ if self.causal:
169
+ i, j = dots.shape[-2:]
170
+ mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
171
+ dots.masked_fill_(mask, mask_value)
172
+
173
+ attn = softmax(dots, dim=-1)
174
+
175
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
176
+ out = rearrange(out, 'b h n d -> b n (h d)')
177
+ out = self.to_out(out)
178
+ return out
179
+
180
+
181
+ # main transformer class
182
+ class Transformer(nn.Module):
183
+ def __init__(
184
+ self,
185
+ *,
186
+ dim,
187
+ depth,
188
+ seq_len,
189
+ causal = True,
190
+ heads = 8,
191
+ dim_head = 64,
192
+ ff_mult = 4,
193
+ attn_dropout = 0.,
194
+ ff_dropout = 0.,
195
+ sparse_attn = False,
196
+ sandwich_norm = False,
197
+ ):
198
+ super().__init__()
199
+ layers = nn.ModuleList([])
200
+ sparse_layer = cast_tuple(sparse_attn, depth)
201
+
202
+ for ind, sparse_attn in zip(range(depth), sparse_layer):
203
+ attn = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
204
+
205
+ ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
206
+
207
+ layers.append(nn.ModuleList([
208
+ LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
209
+ LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
210
+ ]))
211
+
212
+ execute_type = SequentialSequence
213
+ route_attn = ((True, False),) * depth
214
+ attn_route_map = {'mask': route_attn}
215
+
216
+ self.layers = execute_type(layers, args_route = attn_route_map)
217
+
218
+ def forward(self, x, **kwargs):
219
+ return self.layers(x, **kwargs)
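
Editor's note: below is a minimal, hypothetical smoke test for the Transformer class defined above; it is not part of the uploaded file. It assumes the repository root is on PYTHONPATH (so tortoise.models.transformer is importable) and that torch, einops and rotary-embedding-torch are installed; the dimensions are illustrative only.

import torch
from tortoise.models.transformer import Transformer

model = Transformer(dim=512, depth=2, seq_len=32, heads=8, dim_head=64)
x = torch.randn(1, 32, 512)                 # (batch, seq_len, dim)
mask = torch.ones(1, 32, dtype=torch.bool)  # attend to every position
out = model(x, mask=mask)                   # causal self-attention + GEGLU feed-forward stack
print(out.shape)                            # torch.Size([1, 32, 512])

The 'mask' keyword is routed to the attention blocks only (via route_args), so the feed-forward layers receive no extra arguments.
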
tortoise/models/vocoder.py ADDED
@@ -0,0 +1,327 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+ class KernelPredictor(torch.nn.Module):
8
+ ''' Kernel predictor for the location-variable convolutions'''
9
+
10
+ def __init__(
11
+ self,
12
+ cond_channels,
13
+ conv_in_channels,
14
+ conv_out_channels,
15
+ conv_layers,
16
+ conv_kernel_size=3,
17
+ kpnet_hidden_channels=64,
18
+ kpnet_conv_size=3,
19
+ kpnet_dropout=0.0,
20
+ kpnet_nonlinear_activation="LeakyReLU",
21
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
22
+ ):
23
+ '''
24
+ Args:
25
+ cond_channels (int): number of channels for the conditioning sequence,
26
+ conv_in_channels (int): number of channels for the input sequence,
27
+ conv_out_channels (int): number of channels for the output sequence,
28
+ conv_layers (int): number of layers
29
+ '''
30
+ super().__init__()
31
+
32
+ self.conv_in_channels = conv_in_channels
33
+ self.conv_out_channels = conv_out_channels
34
+ self.conv_kernel_size = conv_kernel_size
35
+ self.conv_layers = conv_layers
36
+
37
+ kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
38
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
39
+
40
+ self.input_conv = nn.Sequential(
41
+ nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
42
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
43
+ )
44
+
45
+ self.residual_convs = nn.ModuleList()
46
+ padding = (kpnet_conv_size - 1) // 2
47
+ for _ in range(3):
48
+ self.residual_convs.append(
49
+ nn.Sequential(
50
+ nn.Dropout(kpnet_dropout),
51
+ nn.utils.weight_norm(
52
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
53
+ bias=True)),
54
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
55
+ nn.utils.weight_norm(
56
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
57
+ bias=True)),
58
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
59
+ )
60
+ )
61
+ self.kernel_conv = nn.utils.weight_norm(
62
+ nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
63
+ self.bias_conv = nn.utils.weight_norm(
64
+ nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
65
+
66
+ def forward(self, c):
67
+ '''
68
+ Args:
69
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
70
+ '''
71
+ batch, _, cond_length = c.shape
72
+ c = self.input_conv(c)
73
+ for residual_conv in self.residual_convs:
74
+ residual_conv.to(c.device)
75
+ c = c + residual_conv(c)
76
+ k = self.kernel_conv(c)
77
+ b = self.bias_conv(c)
78
+ kernels = k.contiguous().view(
79
+ batch,
80
+ self.conv_layers,
81
+ self.conv_in_channels,
82
+ self.conv_out_channels,
83
+ self.conv_kernel_size,
84
+ cond_length,
85
+ )
86
+ bias = b.contiguous().view(
87
+ batch,
88
+ self.conv_layers,
89
+ self.conv_out_channels,
90
+ cond_length,
91
+ )
92
+
93
+ return kernels, bias
94
+
95
+ def remove_weight_norm(self):
96
+ nn.utils.remove_weight_norm(self.input_conv[0])
97
+ nn.utils.remove_weight_norm(self.kernel_conv)
98
+ nn.utils.remove_weight_norm(self.bias_conv)
99
+ for block in self.residual_convs:
100
+ nn.utils.remove_weight_norm(block[1])
101
+ nn.utils.remove_weight_norm(block[3])
102
+
103
+
104
+ class LVCBlock(torch.nn.Module):
105
+ '''the location-variable convolutions'''
106
+
107
+ def __init__(
108
+ self,
109
+ in_channels,
110
+ cond_channels,
111
+ stride,
112
+ dilations=[1, 3, 9, 27],
113
+ lReLU_slope=0.2,
114
+ conv_kernel_size=3,
115
+ cond_hop_length=256,
116
+ kpnet_hidden_channels=64,
117
+ kpnet_conv_size=3,
118
+ kpnet_dropout=0.0,
119
+ ):
120
+ super().__init__()
121
+
122
+ self.cond_hop_length = cond_hop_length
123
+ self.conv_layers = len(dilations)
124
+ self.conv_kernel_size = conv_kernel_size
125
+
126
+ self.kernel_predictor = KernelPredictor(
127
+ cond_channels=cond_channels,
128
+ conv_in_channels=in_channels,
129
+ conv_out_channels=2 * in_channels,
130
+ conv_layers=len(dilations),
131
+ conv_kernel_size=conv_kernel_size,
132
+ kpnet_hidden_channels=kpnet_hidden_channels,
133
+ kpnet_conv_size=kpnet_conv_size,
134
+ kpnet_dropout=kpnet_dropout,
135
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
136
+ )
137
+
138
+ self.convt_pre = nn.Sequential(
139
+ nn.LeakyReLU(lReLU_slope),
140
+ nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
141
+ padding=stride // 2 + stride % 2, output_padding=stride % 2)),
142
+ )
143
+
144
+ self.conv_blocks = nn.ModuleList()
145
+ for dilation in dilations:
146
+ self.conv_blocks.append(
147
+ nn.Sequential(
148
+ nn.LeakyReLU(lReLU_slope),
149
+ nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
150
+ padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
151
+ nn.LeakyReLU(lReLU_slope),
152
+ )
153
+ )
154
+
155
+ def forward(self, x, c):
156
+ ''' forward propagation of the location-variable convolutions.
157
+ Args:
158
+ x (Tensor): the input sequence (batch, in_channels, in_length)
159
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
160
+
161
+ Returns:
162
+ Tensor: the output sequence (batch, in_channels, in_length)
163
+ '''
164
+ _, in_channels, _ = x.shape # (B, c_g, L')
165
+
166
+ x = self.convt_pre(x) # (B, c_g, stride * L')
167
+ kernels, bias = self.kernel_predictor(c)
168
+
169
+ for i, conv in enumerate(self.conv_blocks):
170
+ output = conv(x) # (B, c_g, stride * L')
171
+
172
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
173
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
174
+
175
+ output = self.location_variable_convolution(output, k, b,
176
+ hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
177
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
178
+ output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
179
+
180
+ return x
181
+
182
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
183
+ ''' Perform the location-variable convolution operation on the input sequence (x) using the local convolution kernel.
184
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
185
+ Args:
186
+ x (Tensor): the input sequence (batch, in_channels, in_length).
187
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
188
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
189
+ dilation (int): the dilation of convolution.
190
+ hop_size (int): the hop_size of the conditioning sequence.
191
+ Returns:
192
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
193
+ '''
194
+ batch, _, in_length = x.shape
195
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
196
+ assert in_length == (kernel_length * hop_size), "lengths of (x, kernel) do not match"
197
+
198
+ padding = dilation * int((kernel_size - 1) / 2)
199
+ x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
200
+ x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
201
+
202
+ if hop_size < dilation:
203
+ x = F.pad(x, (0, dilation), 'constant', 0)
204
+ x = x.unfold(3, dilation,
205
+ dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
206
+ x = x[:, :, :, :, :hop_size]
207
+ x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
208
+ x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
209
+
210
+ o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
211
+ o = o.to(memory_format=torch.channels_last_3d)
212
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
213
+ o = o + bias
214
+ o = o.contiguous().view(batch, out_channels, -1)
215
+
216
+ return o
217
+
218
+ def remove_weight_norm(self):
219
+ self.kernel_predictor.remove_weight_norm()
220
+ nn.utils.remove_weight_norm(self.convt_pre[1])
221
+ for block in self.conv_blocks:
222
+ nn.utils.remove_weight_norm(block[1])
223
+
224
+
225
+ class UnivNetGenerator(nn.Module):
226
+ """
227
+ UnivNet Generator
228
+
229
+ Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
230
+ """
231
+
232
+ def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
233
+ # Below are MEL configurations options that this generator requires.
234
+ hop_length=256, n_mel_channels=100):
235
+ super(UnivNetGenerator, self).__init__()
236
+ self.mel_channel = n_mel_channels
237
+ self.noise_dim = noise_dim
238
+ self.hop_length = hop_length
239
+ channel_size = channel_size
240
+ kpnet_conv_size = kpnet_conv_size
241
+
242
+ self.res_stack = nn.ModuleList()
243
+ hop_length = 1
244
+ for stride in strides:
245
+ hop_length = stride * hop_length
246
+ self.res_stack.append(
247
+ LVCBlock(
248
+ channel_size,
249
+ n_mel_channels,
250
+ stride=stride,
251
+ dilations=dilations,
252
+ lReLU_slope=lReLU_slope,
253
+ cond_hop_length=hop_length,
254
+ kpnet_conv_size=kpnet_conv_size
255
+ )
256
+ )
257
+
258
+ self.conv_pre = \
259
+ nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
260
+
261
+ self.conv_post = nn.Sequential(
262
+ nn.LeakyReLU(lReLU_slope),
263
+ nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
264
+ nn.Tanh(),
265
+ )
266
+
267
+ def forward(self, c, z):
268
+ '''
269
+ Args:
270
+ c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
271
+ z (Tensor): the noise sequence (batch, noise_dim, in_length)
272
+
273
+ '''
274
+ z = self.conv_pre(z) # (B, c_g, L)
275
+
276
+ for res_block in self.res_stack:
277
+ res_block.to(z.device)
278
+ z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
279
+
280
+ z = self.conv_post(z) # (B, 1, L * 256)
281
+
282
+ return z
283
+
284
+ def eval(self, inference=False):
285
+ super(UnivNetGenerator, self).eval()
286
+ # don't remove weight norm during validation inside the training loop
287
+ if inference:
288
+ self.remove_weight_norm()
289
+
290
+ def remove_weight_norm(self):
291
+ nn.utils.remove_weight_norm(self.conv_pre)
292
+
293
+ for layer in self.conv_post:
294
+ if len(layer.state_dict()) != 0:
295
+ nn.utils.remove_weight_norm(layer)
296
+
297
+ for res_block in self.res_stack:
298
+ res_block.remove_weight_norm()
299
+
300
+ def inference(self, c, z=None):
301
+ # pad input mel with zeros to cut artifact
302
+ # see https://github.com/seungwonpark/melgan/issues/8
303
+ zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
304
+ mel = torch.cat((c, zero), dim=2)
305
+
306
+ if z is None:
307
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
308
+
309
+ audio = self.forward(mel, z)
310
+ audio = audio[:, :, :-(self.hop_length * 10)]
311
+ audio = audio.clamp(min=-1, max=1)
312
+ return audio
313
+
314
+
315
+ if __name__ == '__main__':
316
+ model = UnivNetGenerator()
317
+
318
+ c = torch.randn(3, 100, 10)
319
+ z = torch.randn(3, 64, 10)
320
+ print(c.shape)
321
+
322
+ y = model(c, z)
323
+ print(y.shape)
324
+ assert y.shape == torch.Size([3, 1, 2560])
325
+
326
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
327
+ print(pytorch_total_params)
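
Editor's note: below is a hypothetical vocoding sketch for UnivNetGenerator.inference(); it is not part of the uploaded file. It assumes tortoise.models.vocoder is importable and feeds a random mel tensor purely to illustrate the expected shapes with the default hop_length of 256.

import torch
from tortoise.models.vocoder import UnivNetGenerator

vocoder = UnivNetGenerator()
vocoder.eval(inference=True)        # eval(inference=True) also strips weight norm

mel = torch.randn(1, 100, 80)       # (batch, n_mel_channels, frames) -- random stand-in for a real mel
with torch.no_grad():
    wav = vocoder.inference(mel)    # pads the mel, samples noise internally, trims the padding back off
print(wav.shape)                    # torch.Size([1, 1, 20480]) == frames * hop_length

The 10-frame pad at the log-mel floor of -11.5129 and the matching trim of hop_length * 10 samples implement the artifact workaround referenced in the comment inside inference().
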
tortoise/models/xtransformers.py ADDED
@@ -0,0 +1,1248 @@
 
 
 
 
 
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import nn, einsum
10
+
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple('Intermediates', [
14
+ 'pre_softmax_attn',
15
+ 'post_softmax_attn'
16
+ ])
17
+
18
+ LayerIntermediates = namedtuple('LayerIntermediates', [
19
+ 'hiddens',
20
+ 'attn_intermediates',
21
+ 'past_key_values',
22
+ ])
23
+
24
+
25
+ # helpers
26
+
27
+ def exists(val):
28
+ return val is not None
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def cast_tuple(val, depth):
38
+ return val if isinstance(val, tuple) else (val,) * depth
39
+
40
+
41
+ class always():
42
+ def __init__(self, val):
43
+ self.val = val
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ return self.val
47
+
48
+
49
+ class not_equals():
50
+ def __init__(self, val):
51
+ self.val = val
52
+
53
+ def __call__(self, x, *args, **kwargs):
54
+ return x != self.val
55
+
56
+
57
+ class equals():
58
+ def __init__(self, val):
59
+ self.val = val
60
+
61
+ def __call__(self, x, *args, **kwargs):
62
+ return x == self.val
63
+
64
+
65
+ def max_neg_value(tensor):
66
+ return -torch.finfo(tensor.dtype).max
67
+
68
+
69
+ def l2norm(t):
70
+ return F.normalize(t, p=2, dim=-1)
71
+
72
+
73
+ # init helpers
74
+
75
+ def init_zero_(layer):
76
+ nn.init.constant_(layer.weight, 0.)
77
+ if exists(layer.bias):
78
+ nn.init.constant_(layer.bias, 0.)
79
+
80
+
81
+ # keyword argument helpers
82
+
83
+ def pick_and_pop(keys, d):
84
+ values = list(map(lambda key: d.pop(key), keys))
85
+ return dict(zip(keys, values))
86
+
87
+
88
+ def group_dict_by_key(cond, d):
89
+ return_val = [dict(), dict()]
90
+ for key in d.keys():
91
+ match = bool(cond(key))
92
+ ind = int(not match)
93
+ return_val[ind][key] = d[key]
94
+ return (*return_val,)
95
+
96
+
97
+ def string_begins_with(prefix, str):
98
+ return str.startswith(prefix)
99
+
100
+
101
+ def group_by_key_prefix(prefix, d):
102
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
103
+
104
+
105
+ def groupby_prefix_and_trim(prefix, d):
106
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
107
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
108
+ return kwargs_without_prefix, kwargs
109
+
110
+
111
+ # activations
112
+
113
+ class ReluSquared(nn.Module):
114
+ def forward(self, x):
115
+ return F.relu(x) ** 2
116
+
117
+
118
+ # positional embeddings
119
+
120
+ class AbsolutePositionalEmbedding(nn.Module):
121
+ def __init__(self, dim, max_seq_len):
122
+ super().__init__()
123
+ self.scale = dim ** -0.5
124
+ self.emb = nn.Embedding(max_seq_len, dim)
125
+
126
+ def forward(self, x):
127
+ n = torch.arange(x.shape[1], device=x.device)
128
+ pos_emb = self.emb(n)
129
+ pos_emb = rearrange(pos_emb, 'n d -> () n d')
130
+ return pos_emb * self.scale
131
+
132
+
133
+ class FixedPositionalEmbedding(nn.Module):
134
+ def __init__(self, dim):
135
+ super().__init__()
136
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
137
+ self.register_buffer('inv_freq', inv_freq)
138
+
139
+ def forward(self, x, seq_dim=1, offset=0):
140
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
141
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
142
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
143
+ return rearrange(emb, 'n d -> () n d')
144
+
145
+
146
+ class RelativePositionBias(nn.Module):
147
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
148
+ super().__init__()
149
+ self.scale = scale
150
+ self.causal = causal
151
+ self.num_buckets = num_buckets
152
+ self.max_distance = max_distance
153
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
154
+
155
+ @staticmethod
156
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
157
+ ret = 0
158
+ n = -relative_position
159
+ if not causal:
160
+ num_buckets //= 2
161
+ ret += (n < 0).long() * num_buckets
162
+ n = torch.abs(n)
163
+ else:
164
+ n = torch.max(n, torch.zeros_like(n))
165
+
166
+ max_exact = num_buckets // 2
167
+ is_small = n < max_exact
168
+
169
+ val_if_large = max_exact + (
170
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
171
+ ).long()
172
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
173
+
174
+ ret += torch.where(is_small, n, val_if_large)
175
+ return ret
176
+
177
+ def forward(self, qk_dots):
178
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
179
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
180
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
181
+ rel_pos = k_pos[None, :] - q_pos[:, None]
182
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
183
+ max_distance=self.max_distance)
184
+ values = self.relative_attention_bias(rp_bucket)
185
+ bias = rearrange(values, 'i j h -> () h i j')
186
+ return qk_dots + (bias * self.scale)
187
+
188
+
189
+ class AlibiPositionalBias(nn.Module):
190
+ def __init__(self, heads, **kwargs):
191
+ super().__init__()
192
+ self.heads = heads
193
+ slopes = torch.Tensor(self._get_slopes(heads))
194
+ slopes = rearrange(slopes, 'h -> () h () ()')
195
+ self.register_buffer('slopes', slopes, persistent=False)
196
+ self.register_buffer('bias', None, persistent=False)
197
+
198
+ @staticmethod
199
+ def _get_slopes(heads):
200
+ def get_slopes_power_of_2(n):
201
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
202
+ ratio = start
203
+ return [start * ratio ** i for i in range(n)]
204
+
205
+ if math.log2(heads).is_integer():
206
+ return get_slopes_power_of_2(heads)
207
+
208
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
209
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
210
+ :heads - closest_power_of_2]
211
+
212
+ def forward(self, qk_dots):
213
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
214
+
215
+ if exists(self.bias) and self.bias.shape[-1] >= j:
216
+ return qk_dots + self.bias[..., :j]
217
+
218
+ bias = torch.arange(j, device=device)
219
+ bias = rearrange(bias, 'j -> () () () j')
220
+ bias = bias * self.slopes
221
+
222
+ num_heads_unalibied = h - bias.shape[1]
223
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
224
+
225
+ self.register_buffer('bias', bias, persistent=False)
226
+ return qk_dots + self.bias
227
+
228
+
229
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
230
+ def __init__(self, heads, bidirectional=False):
231
+ super().__init__(heads)
232
+ los_slopes = torch.log(self.slopes)
233
+ self.learned_logslopes = nn.Parameter(los_slopes)
234
+
235
+ self.bidirectional = bidirectional
236
+ if self.bidirectional:
237
+ self.learned_logslopes_future = nn.Parameter(los_slopes)
238
+
239
+ def forward(self, qk_dots):
240
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
241
+
242
+ def get_slopes(param):
243
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
244
+
245
+ if exists(self.bias) and self.bias.shape[-1] >= j:
246
+ bias = self.bias[..., :i, :j]
247
+ else:
248
+ i_arange = torch.arange(i, device=device)
249
+ j_arange = torch.arange(j, device=device)
250
+ bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
251
+ self.register_buffer('bias', bias, persistent=False)
252
+
253
+ if self.bidirectional:
254
+ past_slopes = get_slopes(self.learned_logslopes)
255
+ future_slopes = get_slopes(self.learned_logslopes_future)
256
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
257
+ else:
258
+ slopes = get_slopes(self.learned_logslopes)
259
+ bias = bias * slopes
260
+
261
+ return qk_dots + bias
262
+
263
+
264
+ class RotaryEmbedding(nn.Module):
265
+ def __init__(self, dim):
266
+ super().__init__()
267
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
268
+ self.register_buffer('inv_freq', inv_freq)
269
+
270
+ def forward(self, max_seq_len, device):
271
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
272
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
273
+ emb = torch.cat((freqs, freqs), dim=-1)
274
+ return rearrange(emb, 'n d -> () () n d')
275
+
276
+
277
+ def rotate_half(x):
278
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
279
+ x1, x2 = x.unbind(dim=-2)
280
+ return torch.cat((-x2, x1), dim=-1)
281
+
282
+
283
+ def apply_rotary_pos_emb(t, freqs):
284
+ seq_len = t.shape[-2]
285
+ freqs = freqs[:, :, -seq_len:]
286
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
287
+
288
+
289
+ # norms
290
+
291
+ class Scale(nn.Module):
292
+ def __init__(self, value, fn):
293
+ super().__init__()
294
+ self.value = value
295
+ self.fn = fn
296
+
297
+ def forward(self, x, **kwargs):
298
+ out = self.fn(x, **kwargs)
299
+ scale_fn = lambda t: t * self.value
300
+
301
+ if not isinstance(out, tuple):
302
+ return scale_fn(out)
303
+
304
+ return (scale_fn(out[0]), *out[1:])
305
+
306
+
307
+ class Rezero(nn.Module):
308
+ def __init__(self, fn):
309
+ super().__init__()
310
+ self.fn = fn
311
+ self.g = nn.Parameter(torch.zeros(1))
312
+
313
+ def forward(self, x, **kwargs):
314
+ out = self.fn(x, **kwargs)
315
+ rezero_fn = lambda t: t * self.g
316
+
317
+ if not isinstance(out, tuple):
318
+ return rezero_fn(out)
319
+
320
+ return (rezero_fn(out[0]), *out[1:])
321
+
322
+
323
+ class ScaleNorm(nn.Module):
324
+ def __init__(self, dim, eps=1e-5):
325
+ super().__init__()
326
+ self.scale = dim ** -0.5
327
+ self.eps = eps
328
+ self.g = nn.Parameter(torch.ones(1))
329
+
330
+ def forward(self, x):
331
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
332
+ return x / norm.clamp(min=self.eps) * self.g
333
+
334
+
335
+ class RMSNorm(nn.Module):
336
+ def __init__(self, dim, eps=1e-8):
337
+ super().__init__()
338
+ self.scale = dim ** -0.5
339
+ self.eps = eps
340
+ self.g = nn.Parameter(torch.ones(dim))
341
+
342
+ def forward(self, x):
343
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
344
+ return x / norm.clamp(min=self.eps) * self.g
345
+
346
+
347
+ class RMSScaleShiftNorm(nn.Module):
348
+ def __init__(self, dim, eps=1e-8):
349
+ super().__init__()
350
+ self.scale = dim ** -0.5
351
+ self.eps = eps
352
+ self.g = nn.Parameter(torch.ones(dim))
353
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
354
+
355
+ def forward(self, x, norm_scale_shift_inp):
356
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
357
+ norm = x / norm.clamp(min=self.eps) * self.g
358
+
359
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
360
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
361
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
362
+ return h
363
+
364
+
365
+ # residual and residual gates
366
+
367
+ class Residual(nn.Module):
368
+ def __init__(self, dim, scale_residual=False):
369
+ super().__init__()
370
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
371
+
372
+ def forward(self, x, residual):
373
+ if exists(self.residual_scale):
374
+ residual = residual * self.residual_scale
375
+
376
+ return x + residual
377
+
378
+
379
+ class GRUGating(nn.Module):
380
+ def __init__(self, dim, scale_residual=False):
381
+ super().__init__()
382
+ self.gru = nn.GRUCell(dim, dim)
383
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
384
+
385
+ def forward(self, x, residual):
386
+ if exists(self.residual_scale):
387
+ residual = residual * self.residual_scale
388
+
389
+ gated_output = self.gru(
390
+ rearrange(x, 'b n d -> (b n) d'),
391
+ rearrange(residual, 'b n d -> (b n) d')
392
+ )
393
+
394
+ return gated_output.reshape_as(x)
395
+
396
+
397
+ # token shifting
398
+
399
+ def shift(t, amount, mask=None):
400
+ if amount == 0:
401
+ return t
402
+
403
+ if exists(mask):
404
+ t = t.masked_fill(~mask[..., None], 0.)
405
+
406
+ return F.pad(t, (0, 0, amount, -amount), value=0.)
407
+
408
+
409
+ class ShiftTokens(nn.Module):
410
+ def __init__(self, shifts, fn):
411
+ super().__init__()
412
+ self.fn = fn
413
+ self.shifts = tuple(shifts)
414
+
415
+ def forward(self, x, **kwargs):
416
+ mask = kwargs.get('mask', None)
417
+ shifts = self.shifts
418
+ segments = len(shifts)
419
+ feats_per_shift = x.shape[-1] // segments
420
+ splitted = x.split(feats_per_shift, dim=-1)
421
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
422
+ segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
423
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
424
+ return self.fn(x, **kwargs)
425
+
426
+
427
+ # feedforward
428
+
429
+ class GLU(nn.Module):
430
+ def __init__(self, dim_in, dim_out, activation):
431
+ super().__init__()
432
+ self.act = activation
433
+ self.proj = nn.Linear(dim_in, dim_out * 2)
434
+
435
+ def forward(self, x):
436
+ x, gate = self.proj(x).chunk(2, dim=-1)
437
+ return x * self.act(gate)
438
+
439
+
440
+ class FeedForward(nn.Module):
441
+ def __init__(
442
+ self,
443
+ dim,
444
+ dim_out=None,
445
+ mult=4,
446
+ glu=False,
447
+ relu_squared=False,
448
+ post_act_ln=False,
449
+ dropout=0.,
450
+ zero_init_output=False
451
+ ):
452
+ super().__init__()
453
+ inner_dim = int(dim * mult)
454
+ dim_out = default(dim_out, dim)
455
+ activation = ReluSquared() if relu_squared else nn.GELU()
456
+
457
+ project_in = nn.Sequential(
458
+ nn.Linear(dim, inner_dim),
459
+ activation
460
+ ) if not glu else GLU(dim, inner_dim, activation)
461
+
462
+ self.net = nn.Sequential(
463
+ project_in,
464
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
465
+ nn.Dropout(dropout),
466
+ nn.Linear(inner_dim, dim_out)
467
+ )
468
+
469
+ # init last linear layer to 0
470
+ if zero_init_output:
471
+ init_zero_(self.net[-1])
472
+
473
+ def forward(self, x):
474
+ return self.net(x)
475
+
476
+
477
+ # attention.
478
+
479
+ class Attention(nn.Module):
480
+ def __init__(
481
+ self,
482
+ dim,
483
+ dim_head=DEFAULT_DIM_HEAD,
484
+ heads=8,
485
+ causal=False,
486
+ talking_heads=False,
487
+ head_scale=False,
488
+ collab_heads=False,
489
+ collab_compression=.3,
490
+ sparse_topk=None,
491
+ use_entmax15=False,
492
+ num_mem_kv=0,
493
+ dropout=0.,
494
+ on_attn=False,
495
+ gate_values=False,
496
+ zero_init_output=False,
497
+ max_attend_past=None,
498
+ qk_norm=False,
499
+ scale_init_value=None,
500
+ rel_pos_bias=False,
501
+ rel_pos_num_buckets=32,
502
+ rel_pos_max_distance=128,
503
+ ):
504
+ super().__init__()
505
+ self.scale = dim_head ** -0.5
506
+
507
+ self.heads = heads
508
+ self.causal = causal
509
+ self.max_attend_past = max_attend_past
510
+
511
+ qk_dim = v_dim = dim_head * heads
512
+
513
+ # collaborative heads
514
+ self.collab_heads = collab_heads
515
+ if self.collab_heads:
516
+ qk_dim = int(collab_compression * qk_dim)
517
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
518
+
519
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
520
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
521
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
522
+
523
+ self.dropout = nn.Dropout(dropout)
524
+
525
+ # add GLU gating for aggregated values, from alphafold2
526
+ self.to_v_gate = None
527
+ if gate_values:
528
+ self.to_v_gate = nn.Linear(dim, v_dim)
529
+ nn.init.constant_(self.to_v_gate.weight, 0)
530
+ nn.init.constant_(self.to_v_gate.bias, 1)
531
+
532
+ # cosine sim attention
533
+ self.qk_norm = qk_norm
534
+ if qk_norm:
535
+ scale_init_value = default(scale_init_value,
536
+ -3) # if not provided, initialize as though it were sequence length of 1024
537
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
538
+
539
+ # talking heads
540
+ self.talking_heads = talking_heads
541
+ if talking_heads:
542
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
543
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
544
+
545
+ # head scaling
546
+ self.head_scale = head_scale
547
+ if head_scale:
548
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
549
+
550
+ # explicit topk sparse attention
551
+ self.sparse_topk = sparse_topk
552
+
553
+ # entmax
554
+ self.attn_fn = F.softmax
555
+
556
+ # add memory key / values
557
+ self.num_mem_kv = num_mem_kv
558
+ if num_mem_kv > 0:
559
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
560
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
561
+
562
+ # attention on attention
563
+ self.attn_on_attn = on_attn
564
+ self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
565
+
566
+ self.rel_pos_bias = rel_pos_bias
567
+ if rel_pos_bias:
568
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
569
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
570
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
571
+
572
+ # init output projection 0
573
+ if zero_init_output:
574
+ init_zero_(self.to_out)
575
+
576
+ def forward(
577
+ self,
578
+ x,
579
+ context=None,
580
+ mask=None,
581
+ context_mask=None,
582
+ attn_mask=None,
583
+ sinusoidal_emb=None,
584
+ rotary_pos_emb=None,
585
+ prev_attn=None,
586
+ mem=None,
587
+ layer_past=None,
588
+ ):
589
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
590
+ context)
591
+ kv_input = default(context, x)
592
+
593
+ q_input = x
594
+ k_input = kv_input
595
+ v_input = kv_input
596
+
597
+ if exists(mem):
598
+ k_input = torch.cat((mem, k_input), dim=-2)
599
+ v_input = torch.cat((mem, v_input), dim=-2)
600
+
601
+ if exists(sinusoidal_emb):
602
+ # in shortformer, the query would start at a position offset depending on the past cached memory
603
+ offset = k_input.shape[-2] - q_input.shape[-2]
604
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
605
+ k_input = k_input + sinusoidal_emb(k_input)
606
+
607
+ q = self.to_q(q_input)
608
+ k = self.to_k(k_input)
609
+ v = self.to_v(v_input)
610
+
611
+ if not collab_heads:
612
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
613
+ else:
614
+ q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
615
+ k = rearrange(k, 'b n d -> b () n d')
616
+ v = rearrange(v, 'b n (h d) -> b h n d', h=h)
617
+
618
+ if layer_past is not None:
619
+ past_key, past_value = layer_past
620
+ k = torch.cat([past_key, k], dim=-2)
621
+ v = torch.cat([past_value, v], dim=-2)
622
+ k_cache = k
623
+ v_cache = v
624
+
625
+ if exists(rotary_pos_emb) and not has_context:
626
+ l = rotary_pos_emb.shape[-1]
627
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
628
+ ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
629
+ q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
630
+
631
+ input_mask = None
632
+ if any(map(exists, (mask, context_mask))):
633
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
634
+ k_mask = q_mask if not exists(context) else context_mask
635
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
636
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
637
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
638
+ input_mask = q_mask * k_mask
639
+
640
+ if self.num_mem_kv > 0:
641
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
642
+ k = torch.cat((mem_k, k), dim=-2)
643
+ v = torch.cat((mem_v, v), dim=-2)
644
+ if exists(input_mask):
645
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
646
+
647
+ if collab_heads:
648
+ k = k.expand(-1, h, -1, -1)
649
+
650
+ if self.qk_norm:
651
+ q, k = map(l2norm, (q, k))
652
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
653
+
654
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
655
+ mask_value = max_neg_value(dots)
656
+
657
+ if exists(prev_attn):
658
+ dots = dots + prev_attn
659
+
660
+ pre_softmax_attn = dots.clone()
661
+
662
+ if talking_heads:
663
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
664
+
665
+ if self.rel_pos_bias:
666
+ dots = self.rel_pos(dots)
667
+
668
+ if exists(input_mask):
669
+ dots.masked_fill_(~input_mask, mask_value)
670
+ del input_mask
671
+
672
+ if exists(attn_mask):
673
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have between 2 and 4 dimensions (inclusive)'
674
+ if attn_mask.ndim == 2:
675
+ attn_mask = rearrange(attn_mask, 'i j -> () () i j')
676
+ elif attn_mask.ndim == 3:
677
+ attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
678
+ dots.masked_fill_(~attn_mask, mask_value)
679
+
680
+ if exists(self.max_attend_past):
681
+ i, j = dots.shape[-2:]
682
+ range_q = torch.arange(j - i, j, device=device)
683
+ range_k = torch.arange(j, device=device)
684
+ dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
685
+ mask = dist > self.max_attend_past
686
+ dots.masked_fill_(mask, mask_value)
687
+ del mask
688
+
689
+ if self.causal:
690
+ i, j = dots.shape[-2:]
691
+ r = torch.arange(i, device=device)
692
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
693
+ mask = F.pad(mask, (j - i, 0), value=False)
694
+ dots.masked_fill_(mask, mask_value)
695
+ del mask
696
+
697
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
698
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
699
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
700
+ mask = dots < vk
701
+ dots.masked_fill_(mask, mask_value)
702
+ del mask
703
+
704
+ attn = self.attn_fn(dots, dim=-1)
705
+ post_softmax_attn = attn.clone()
706
+
707
+ attn = self.dropout(attn)
708
+
709
+ if talking_heads:
710
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
711
+
712
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
713
+
714
+ if head_scale:
715
+ out = out * self.head_scale_params
716
+
717
+ out = rearrange(out, 'b h n d -> b n (h d)')
718
+
719
+ if exists(self.to_v_gate):
720
+ gates = self.to_v_gate(x)
721
+ out = out * gates.sigmoid()
722
+
723
+ intermediates = Intermediates(
724
+ pre_softmax_attn=pre_softmax_attn,
725
+ post_softmax_attn=post_softmax_attn
726
+ )
727
+
728
+ return self.to_out(out), intermediates, k_cache, v_cache
729
+
730
+
731
+ class AttentionLayers(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim,
735
+ depth,
736
+ heads=8,
737
+ causal=False,
738
+ cross_attend=False,
739
+ only_cross=False,
740
+ use_scalenorm=False,
741
+ use_rms_scaleshift_norm=False,
742
+ use_rmsnorm=False,
743
+ use_rezero=False,
744
+ alibi_pos_bias=False,
745
+ alibi_num_heads=None,
746
+ alibi_learned=False,
747
+ position_infused_attn=False,
748
+ rotary_pos_emb=False,
749
+ rotary_emb_dim=None,
750
+ custom_layers=None,
751
+ sandwich_coef=None,
752
+ par_ratio=None,
753
+ residual_attn=False,
754
+ cross_residual_attn=False,
755
+ macaron=False,
756
+ pre_norm=True,
757
+ gate_residual=False,
758
+ scale_residual=False,
759
+ shift_tokens=0,
760
+ sandwich_norm=False,
761
+ use_qk_norm_attn=False,
762
+ qk_norm_attn_seq_len=None,
763
+ zero_init_branch_output=False,
764
+ **kwargs
765
+ ):
766
+ super().__init__()
767
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
768
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
769
+
770
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
771
+
772
+ self.dim = dim
773
+ self.depth = depth
774
+ self.layers = nn.ModuleList([])
775
+ self.causal = causal
776
+
777
+ rel_pos_bias = 'rel_pos_bias' in attn_kwargs
778
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
779
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
780
+
781
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
782
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
783
+
784
+ assert not (
785
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
786
+
787
+ if alibi_pos_bias:
788
+ alibi_num_heads = default(alibi_num_heads, heads)
789
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than or equal to the total number of heads'
790
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
791
+ self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
792
+ else:
793
+ self.rel_pos = None
794
+
795
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
796
+ self.pre_norm = pre_norm
797
+ self.sandwich_norm = sandwich_norm
798
+
799
+ self.residual_attn = residual_attn
800
+ self.cross_residual_attn = cross_residual_attn
801
+ self.cross_attend = cross_attend
802
+
803
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
804
+ norm_class = RMSNorm if use_rmsnorm else norm_class
805
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
806
+ norm_fn = partial(norm_class, dim)
807
+
808
+ norm_fn = nn.Identity if use_rezero else norm_fn
809
+ branch_fn = Rezero if use_rezero else None
810
+
811
+ if cross_attend and not only_cross:
812
+ default_block = ('a', 'c', 'f')
813
+ elif cross_attend and only_cross:
814
+ default_block = ('c', 'f')
815
+ else:
816
+ default_block = ('a', 'f')
817
+
818
+ if macaron:
819
+ default_block = ('f',) + default_block
820
+
821
+ # qk normalization
822
+
823
+ if use_qk_norm_attn:
824
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
825
+ qk_norm_attn_seq_len) else None
826
+ attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
827
+
828
+ # zero init
829
+
830
+ if zero_init_branch_output:
831
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
832
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
833
+
834
+ # calculate layer block order
835
+
836
+ if exists(custom_layers):
837
+ layer_types = custom_layers
838
+ elif exists(par_ratio):
839
+ par_depth = depth * len(default_block)
840
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
841
+ default_block = tuple(filter(not_equals('f'), default_block))
842
+ par_attn = par_depth // par_ratio
843
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
844
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
845
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
846
+ par_block = default_block + ('f',) * (par_width - len(default_block))
847
+ par_head = par_block * par_attn
848
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
849
+ elif exists(sandwich_coef):
850
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be positive and no greater than the depth'
851
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
852
+ else:
853
+ layer_types = default_block * depth
854
+
855
+ self.layer_types = layer_types
856
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
857
+
858
+ # calculate token shifting
859
+
860
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
861
+
862
+ # iterate and construct layers
863
+
864
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
865
+ is_last_layer = ind == (len(self.layer_types) - 1)
866
+
867
+ if layer_type == 'a':
868
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
869
+ elif layer_type == 'c':
870
+ layer = Attention(dim, heads=heads, **attn_kwargs)
871
+ elif layer_type == 'f':
872
+ layer = FeedForward(dim, **ff_kwargs)
873
+ layer = layer if not macaron else Scale(0.5, layer)
874
+ else:
875
+ raise Exception(f'invalid layer type {layer_type}')
876
+
877
+ if layer_shift_tokens > 0:
878
+ shift_range_upper = layer_shift_tokens + 1
879
+ shift_range_lower = -layer_shift_tokens if not causal else 0
880
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
881
+
882
+ if exists(branch_fn):
883
+ layer = branch_fn(layer)
884
+
885
+ residual_fn = GRUGating if gate_residual else Residual
886
+ residual = residual_fn(dim, scale_residual=scale_residual)
887
+
888
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
889
+
890
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
891
+ post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
892
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
893
+
894
+ norms = nn.ModuleList([
895
+ pre_branch_norm,
896
+ post_branch_norm,
897
+ post_main_norm
898
+ ])
899
+
900
+ self.layers.append(nn.ModuleList([
901
+ norms,
902
+ layer,
903
+ residual
904
+ ]))
905
+
906
+ def forward(
907
+ self,
908
+ x,
909
+ context=None,
910
+ full_context=None, # for passing a list of hidden states from an encoder
911
+ mask=None,
912
+ context_mask=None,
913
+ attn_mask=None,
914
+ mems=None,
915
+ return_hiddens=False,
916
+ norm_scale_shift_inp=None,
917
+ past_key_values=None,
918
+ expected_seq_len=None,
919
+ ):
920
+
921
+ assert not (self.cross_attend ^ (exists(context) or exists(
922
+ full_context))), 'context must be passed in if cross_attend is set to True'
923
+ assert context is None or full_context is None, 'only one of full_context or context can be provided'
924
+
925
+ hiddens = []
926
+ intermediates = []
927
+ prev_attn = None
928
+ prev_cross_attn = None
929
+
930
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
931
+ norm_args = {}
932
+ if exists(norm_scale_shift_inp):
933
+ norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
934
+
935
+ rotary_pos_emb = None
936
+ if exists(self.rotary_pos_emb):
937
+ if not self.training and self.causal:
938
+ assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
939
+ elif expected_seq_len is None:
940
+ expected_seq_len = 0
941
+ seq_len = x.shape[1]
942
+ if past_key_values is not None:
943
+ seq_len += past_key_values[0][0].shape[-2]
944
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
945
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
946
+
947
+ present_key_values = []
948
+ cross_attn_count = 0
949
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
950
+ if layer_type == 'a':
951
+ layer_mem = mems.pop(0) if mems else None
952
+
953
+ residual = x
954
+
955
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
956
+
957
+ if exists(pre_branch_norm):
958
+ x = pre_branch_norm(x, **norm_args)
959
+
960
+ if layer_type == 'a' or layer_type == 'c':
961
+ if past_key_values is not None:
962
+ layer_kv = past_key_values.pop(0)
963
+ layer_past = tuple(s.to(x.device) for s in layer_kv)
964
+ else:
965
+ layer_past = None
966
+
967
+ if layer_type == 'a':
968
+ out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
969
+ prev_attn, layer_mem, layer_past)
970
+ elif layer_type == 'c':
971
+ if exists(full_context):
972
+ out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
973
+ None, prev_attn, None, layer_past)
974
+ else:
975
+ out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
976
+ elif layer_type == 'f':
977
+ out = block(x)
978
+
979
+ if layer_type in ('a', 'c') and present_key_values is not None:
980
+ present_key_values.append((k.detach(), v.detach()))
981
+
982
+ if exists(post_branch_norm):
983
+ out = post_branch_norm(out, **norm_args)
984
+
985
+ x = residual_fn(out, residual)
986
+
987
+ if layer_type in ('a', 'c'):
988
+ intermediates.append(inter)
989
+
990
+ if layer_type == 'a' and self.residual_attn:
991
+ prev_attn = inter.pre_softmax_attn
992
+ elif layer_type == 'c' and self.cross_residual_attn:
993
+ prev_cross_attn = inter.pre_softmax_attn
994
+
995
+ if exists(post_main_norm):
996
+ x = post_main_norm(x, **norm_args)
997
+
998
+ if layer_type == 'c':
999
+ cross_attn_count += 1
1000
+
1001
+ if layer_type == 'f':
1002
+ hiddens.append(x)
1003
+
1004
+ if return_hiddens:
1005
+ intermediates = LayerIntermediates(
1006
+ hiddens=hiddens,
1007
+ attn_intermediates=intermediates,
1008
+ past_key_values=present_key_values
1009
+ )
1010
+
1011
+ return x, intermediates
1012
+
1013
+ return x
1014
+
1015
+
1016
+ class Encoder(AttentionLayers):
1017
+ def __init__(self, **kwargs):
1018
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
1019
+ super().__init__(causal=False, **kwargs)
1020
+
1021
+
1022
+ class Decoder(AttentionLayers):
1023
+ def __init__(self, **kwargs):
1024
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1025
+ super().__init__(causal=True, **kwargs)
1026
+
1027
+
1028
+ class CrossAttender(AttentionLayers):
1029
+ def __init__(self, **kwargs):
1030
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1031
+
1032
+
1033
+ class ViTransformerWrapper(nn.Module):
1034
+ def __init__(
1035
+ self,
1036
+ *,
1037
+ image_size,
1038
+ patch_size,
1039
+ attn_layers,
1040
+ num_classes=None,
1041
+ dropout=0.,
1042
+ emb_dropout=0.
1043
+ ):
1044
+ super().__init__()
1045
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1046
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1047
+ dim = attn_layers.dim
1048
+ num_patches = (image_size // patch_size) ** 2
1049
+ patch_dim = 3 * patch_size ** 2
1050
+
1051
+ self.patch_size = patch_size
1052
+
1053
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1054
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1055
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1056
+ self.dropout = nn.Dropout(emb_dropout)
1057
+
1058
+ self.attn_layers = attn_layers
1059
+ self.norm = nn.LayerNorm(dim)
1060
+ self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
1061
+
1062
+ def forward(
1063
+ self,
1064
+ img,
1065
+ return_embeddings=False
1066
+ ):
1067
+ p = self.patch_size
1068
+
1069
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
1070
+ x = self.patch_to_embedding(x)
1071
+ b, n, _ = x.shape
1072
+
1073
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
1074
+ x = torch.cat((cls_tokens, x), dim=1)
1075
+ x = x + self.pos_embedding[:, :(n + 1)]
1076
+ x = self.dropout(x)
1077
+
1078
+ x = self.attn_layers(x)
1079
+ x = self.norm(x)
1080
+
1081
+ if not exists(self.mlp_head) or return_embeddings:
1082
+ return x
1083
+
1084
+ return self.mlp_head(x[:, 0])
1085
+
1086
+
1087
+ class TransformerWrapper(nn.Module):
1088
+ def __init__(
1089
+ self,
1090
+ *,
1091
+ num_tokens,
1092
+ max_seq_len,
1093
+ attn_layers,
1094
+ emb_dim=None,
1095
+ max_mem_len=0.,
1096
+ shift_mem_down=0,
1097
+ emb_dropout=0.,
1098
+ num_memory_tokens=None,
1099
+ tie_embedding=False,
1100
+ use_pos_emb=True
1101
+ ):
1102
+ super().__init__()
1103
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1104
+
1105
+ dim = attn_layers.dim
1106
+ emb_dim = default(emb_dim, dim)
1107
+
1108
+ self.max_seq_len = max_seq_len
1109
+ self.max_mem_len = max_mem_len
1110
+ self.shift_mem_down = shift_mem_down
1111
+
1112
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1113
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
1114
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1115
+ self.emb_dropout = nn.Dropout(emb_dropout)
1116
+
1117
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1118
+ self.attn_layers = attn_layers
1119
+ self.norm = nn.LayerNorm(dim)
1120
+
1121
+ self.init_()
1122
+
1123
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1124
+
1125
+ # memory tokens (like [cls]) from Memory Transformers paper
1126
+ num_memory_tokens = default(num_memory_tokens, 0)
1127
+ self.num_memory_tokens = num_memory_tokens
1128
+ if num_memory_tokens > 0:
1129
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1130
+
1131
+ def init_(self):
1132
+ nn.init.kaiming_normal_(self.token_emb.weight)
1133
+
1134
+ def forward(
1135
+ self,
1136
+ x,
1137
+ return_embeddings=False,
1138
+ mask=None,
1139
+ return_hiddens=False,
1140
+ return_attn=False,
1141
+ mems=None,
1142
+ use_cache=False,
1143
+ **kwargs
1144
+ ):
1145
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1146
+ x = self.token_emb(x)
1147
+ x = x + self.pos_emb(x)
1148
+ x = self.emb_dropout(x)
1149
+
1150
+ x = self.project_emb(x)
1151
+
1152
+ if num_mem > 0:
1153
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
1154
+ x = torch.cat((mem, x), dim=1)
1155
+
1156
+ # auto-handle masking after appending memory tokens
1157
+ if exists(mask):
1158
+ mask = F.pad(mask, (num_mem, 0), value=True)
1159
+
1160
+ if self.shift_mem_down and exists(mems):
1161
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1162
+ mems = [*mems_r, *mems_l]
1163
+
1164
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1165
+ x = self.norm(x)
1166
+
1167
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1168
+
1169
+ out = self.to_logits(x) if not return_embeddings else x
1170
+
1171
+ if return_hiddens:
1172
+ hiddens = intermediates.hiddens
1173
+ return out, hiddens
1174
+
1175
+ res = [out]
1176
+ if return_attn:
1177
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1178
+ res.append(attn_maps)
1179
+ if use_cache:
1180
+ res.append(intermediates.past_key_values)
1181
+
1182
+ if len(res) > 1:
1183
+ return tuple(res)
1184
+ return res[0]
1185
+
1186
+
1187
+ class ContinuousTransformerWrapper(nn.Module):
1188
+ def __init__(
1189
+ self,
1190
+ *,
1191
+ max_seq_len,
1192
+ attn_layers,
1193
+ dim_in=None,
1194
+ dim_out=None,
1195
+ emb_dim=None,
1196
+ emb_dropout=0.,
1197
+ use_pos_emb=True
1198
+ ):
1199
+ super().__init__()
1200
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1201
+
1202
+ dim = attn_layers.dim
1203
+
1204
+ self.max_seq_len = max_seq_len
1205
+
1206
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
1207
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1208
+ self.emb_dropout = nn.Dropout(emb_dropout)
1209
+
1210
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1211
+
1212
+ self.attn_layers = attn_layers
1213
+ self.norm = nn.LayerNorm(dim)
1214
+
1215
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ x,
1220
+ return_embeddings=False,
1221
+ mask=None,
1222
+ return_attn=False,
1223
+ mems=None,
1224
+ use_cache=False,
1225
+ **kwargs
1226
+ ):
1227
+ b, n, _, device = *x.shape, x.device
1228
+
1229
+ x = self.project_in(x)
1230
+ x = x + self.pos_emb(x)
1231
+ x = self.emb_dropout(x)
1232
+
1233
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1234
+ x = self.norm(x)
1235
+
1236
+ out = self.project_out(x) if not return_embeddings else x
1237
+
1238
+ res = [out]
1239
+ if return_attn:
1240
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1241
+ res.append(attn_maps)
1242
+ if use_cache:
1243
+ res.append(intermediates.past_key_values)
1244
+
1245
+ if len(res) > 1:
1246
+ return tuple(res)
1247
+ return res[0]
1248
+
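A minimal usage sketch for the wrappers above, assuming the Decoder class defined earlier in tortoise/models/xtransformers.py; the sizes and token counts are illustrative placeholders, not part of this upload:

import torch
from tortoise.models.xtransformers import TransformerWrapper, Decoder  # Decoder assumed from earlier in this file

# Illustrative hyperparameters only.
model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=1024,
    attn_layers=Decoder(dim=512, depth=4, heads=8),
)
tokens = torch.randint(0, 256, (1, 64))           # (batch, sequence)
logits = model(tokens)                            # -> (1, 64, 256) token logits
hidden = model(tokens, return_embeddings=True)    # -> (1, 64, 512) normalized hidden states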
tortoise/read.py ADDED
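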
@@ -0,0 +1,101 @@
1
+ import argparse
2
+ import os
3
+ from time import time
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from api import TextToSpeech, MODELS_DIR
9
+ from utils.audio import load_audio, load_voices
10
+ from utils.text import split_and_recombine_text
11
+
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt")
16
+ parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
17
+ 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
18
+ parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
19
+ parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav')
20
+ parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
21
+ parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
22
+ parser.add_argument('--candidates', type=int, help='How many output candidates to produce per voice. Only the first candidate is used in the final product; the others can be reviewed manually.', default=1)
23
+ parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this '
24
+ 'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
25
+ parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
26
+ parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
27
+ parser.add_argument('--use_deepspeed', type=bool, help='Use DeepSpeed to speed up inference.', default=False)
28
+ parser.add_argument('--kv_cache', type=bool, help='Enable the key/value cache; disabling it makes generation take much longer.', default=True)
29
+ parser.add_argument('--half', type=bool, help='Use float16 (half) precision for inference; it is faster and uses less VRAM and RAM.', default=True)
30
+
31
+
32
+ args = parser.parse_args()
33
+ if torch.backends.mps.is_available():
34
+ args.use_deepspeed = False
35
+ tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
36
+
37
+ outpath = args.output_path
38
+ outname = args.output_name
39
+ selected_voices = args.voice.split(',')
40
+ regenerate = args.regenerate
41
+ if regenerate is not None:
42
+ regenerate = [int(e) for e in regenerate.split(',')]
43
+
44
+ # Process text
45
+ with open(args.textfile, 'r', encoding='utf-8') as f:
46
+ text = ' '.join([l for l in f.readlines()])
47
+ if '|' in text:
48
+ print("Found the '|' character in your text, which I will use as a cue for where to split it up. If this was not "
49
+ "your intent, please remove all '|' characters from the input.")
50
+ texts = text.split('|')
51
+ else:
52
+ texts = split_and_recombine_text(text)
53
+
54
+ seed = int(time()) if args.seed is None else args.seed
55
+ for selected_voice in selected_voices:
56
+ voice_outpath = os.path.join(outpath, selected_voice)
57
+ os.makedirs(voice_outpath, exist_ok=True)
58
+
59
+ if '&' in selected_voice:
60
+ voice_sel = selected_voice.split('&')
61
+ else:
62
+ voice_sel = [selected_voice]
63
+
64
+ voice_samples, conditioning_latents = load_voices(voice_sel)
65
+ all_parts = []
66
+ for j, text in enumerate(texts):
67
+ if regenerate is not None and j not in regenerate:
68
+ all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
69
+ continue
70
+ gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
71
+ preset=args.preset, k=args.candidates, use_deterministic_seed=seed)
72
+ if args.candidates == 1:
73
+ audio_ = gen.squeeze(0).cpu()
74
+ torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), audio_, 24000)
75
+ else:
76
+ candidate_dir = os.path.join(voice_outpath, str(j))
77
+ os.makedirs(candidate_dir, exist_ok=True)
78
+ for k, g in enumerate(gen):
79
+ torchaudio.save(os.path.join(candidate_dir, f'{k}.wav'), g.squeeze(0).cpu(), 24000)
80
+ audio_ = gen[0].squeeze(0).cpu()
81
+ all_parts.append(audio_)
82
+
83
+ if args.candidates == 1:
84
+ full_audio = torch.cat(all_parts, dim=-1)
85
+ torchaudio.save(os.path.join(voice_outpath, f"{outname}.wav"), full_audio, 24000)
86
+
87
+ if args.produce_debug_state:
88
+ os.makedirs('debug_states', exist_ok=True)
89
+ dbg_state = (seed, texts, voice_samples, conditioning_latents)
90
+ torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth')
91
+
92
+ # Combine each candidate's audio clips.
93
+ if args.candidates > 1:
94
+ audio_clips = []
95
+ for candidate in range(args.candidates):
96
+ for line in range(len(texts)):
97
+ wav_file = os.path.join(voice_outpath, str(line), f"{candidate}.wav")
98
+ audio_clips.append(load_audio(wav_file, 24000))
99
+ audio_clips = torch.cat(audio_clips, dim=-1)
100
+ torchaudio.save(os.path.join(voice_outpath, f"{outname}_{candidate:02d}.wav"), audio_clips, 24000)
101
+ audio_clips = []
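A hedged sketch of the same pipeline driven programmatically, using only calls the script above already makes (it assumes it runs from the tortoise/ directory like the script; the voice name and text are placeholders, and constructor defaults may differ):

# Hypothetical, minimal version of the read.py loop.
import torch
import torchaudio
from api import TextToSpeech
from utils.audio import load_voices
from utils.text import split_and_recombine_text

tts = TextToSpeech()
voice_samples, conditioning_latents = load_voices(['daniel'])
parts = []
for chunk in split_and_recombine_text('Some long passage of text to narrate.'):
    gen = tts.tts_with_preset(chunk, voice_samples=voice_samples,
                              conditioning_latents=conditioning_latents,
                              preset='standard')
    parts.append(gen.squeeze(0).cpu())
torchaudio.save('combined.wav', torch.cat(parts, dim=-1), 24000)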
tortoise/utils/__init__.py ADDED
File without changes
tortoise/utils/audio.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ from glob import glob
3
+
4
+ import librosa
5
+ import torch
6
+ import torchaudio
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+
10
+ from tortoise.utils.stft import STFT
11
+
12
+
13
+ BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../voices')
14
+
15
+
16
+ def load_wav_to_torch(full_path):
17
+ sampling_rate, data = read(full_path)
18
+ if data.dtype == np.int32:
19
+ norm_fix = 2 ** 31
20
+ elif data.dtype == np.int16:
21
+ norm_fix = 2 ** 15
22
+ elif data.dtype == np.float16 or data.dtype == np.float32:
23
+ norm_fix = 1.
24
+ else:
25
+ raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
26
+ return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
27
+
28
+
29
+ def load_audio(audiopath, sampling_rate):
30
+ if audiopath[-4:] == '.wav':
31
+ audio, lsr = load_wav_to_torch(audiopath)
32
+ elif audiopath[-4:] == '.mp3':
33
+ audio, lsr = librosa.load(audiopath, sr=sampling_rate)
34
+ audio = torch.FloatTensor(audio)
35
+ else:
36
+ assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
37
+
38
+ # Remove any channel data.
39
+ if len(audio.shape) > 1:
40
+ if audio.shape[0] < 5:
41
+ audio = audio[0]
42
+ else:
43
+ assert audio.shape[1] < 5
44
+ audio = audio[:, 0]
45
+
46
+ if lsr != sampling_rate:
47
+ audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
48
+
49
+ # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
50
+ # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
51
+ if torch.any(audio > 2) or not torch.any(audio < 0):
52
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
53
+ audio.clip_(-1, 1)
54
+
55
+ return audio.unsqueeze(0)
56
+
57
+
58
+ TACOTRON_MEL_MAX = 2.3143386840820312
59
+ TACOTRON_MEL_MIN = -11.512925148010254
60
+
61
+
62
+ def denormalize_tacotron_mel(norm_mel):
63
+ return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
64
+
65
+
66
+ def normalize_tacotron_mel(mel):
67
+ return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
68
+
69
+
70
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
71
+ """
72
+ PARAMS
73
+ ------
74
+ C: compression factor
75
+ """
76
+ return torch.log(torch.clamp(x, min=clip_val) * C)
77
+
78
+
79
+ def dynamic_range_decompression(x, C=1):
80
+ """
81
+ PARAMS
82
+ ------
83
+ C: compression factor used to compress
84
+ """
85
+ return torch.exp(x) / C
86
+
87
+
88
+ def get_voices(extra_voice_dirs=[]):
89
+ dirs = [BUILTIN_VOICES_DIR] + extra_voice_dirs
90
+ voices = {}
91
+ for d in dirs:
92
+ subs = os.listdir(d)
93
+ for sub in subs:
94
+ subj = os.path.join(d, sub)
95
+ if os.path.isdir(subj):
96
+ voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.pth'))
97
+ return voices
98
+
99
+
100
+ def load_voice(voice, extra_voice_dirs=[]):
101
+ if voice == 'random':
102
+ return None, None
103
+
104
+ voices = get_voices(extra_voice_dirs)
105
+ paths = voices[voice]
106
+ if len(paths) == 1 and paths[0].endswith('.pth'):
107
+ return None, torch.load(paths[0])
108
+ else:
109
+ conds = []
110
+ for cond_path in paths:
111
+ c = load_audio(cond_path, 22050)
112
+ conds.append(c)
113
+ return conds, None
114
+
115
+
116
+ def load_voices(voices, extra_voice_dirs=[]):
117
+ latents = []
118
+ clips = []
119
+ for voice in voices:
120
+ if voice == 'random':
121
+ if len(voices) > 1:
122
+ print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
123
+ return None, None
124
+ clip, latent = load_voice(voice, extra_voice_dirs)
125
+ if latent is None:
126
+ assert len(latents) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
127
+ clips.extend(clip)
128
+ elif clip is None:
129
+ assert len(clips) == 0, "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
130
+ latents.append(latent)
131
+ if len(latents) == 0:
132
+ return clips, None
133
+ else:
134
+ latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0)
135
+ latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0)
136
+ latents = (latents_0,latents_1)
137
+ return None, latents
138
+
139
+
140
+ class TacotronSTFT(torch.nn.Module):
141
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
142
+ n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
143
+ mel_fmax=8000.0):
144
+ super(TacotronSTFT, self).__init__()
145
+ self.n_mel_channels = n_mel_channels
146
+ self.sampling_rate = sampling_rate
147
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
148
+ from librosa.filters import mel as librosa_mel_fn
149
+ mel_basis = librosa_mel_fn(
150
+ sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax)
151
+ mel_basis = torch.from_numpy(mel_basis).float()
152
+ self.register_buffer('mel_basis', mel_basis)
153
+
154
+ def spectral_normalize(self, magnitudes):
155
+ output = dynamic_range_compression(magnitudes)
156
+ return output
157
+
158
+ def spectral_de_normalize(self, magnitudes):
159
+ output = dynamic_range_decompression(magnitudes)
160
+ return output
161
+
162
+ def mel_spectrogram(self, y):
163
+ """Computes mel-spectrograms from a batch of waves
164
+ PARAMS
165
+ ------
166
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
167
+
168
+ RETURNS
169
+ -------
170
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
171
+ """
172
+ assert(torch.min(y.data) >= -10)
173
+ assert(torch.max(y.data) <= 10)
174
+ y = torch.clip(y, min=-1, max=1)
175
+
176
+ magnitudes, phases = self.stft_fn.transform(y)
177
+ magnitudes = magnitudes.data
178
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
179
+ mel_output = self.spectral_normalize(mel_output)
180
+ return mel_output
181
+
182
+
183
+ def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
184
+ stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
185
+ stft = stft.to(device)
186
+ mel = stft.mel_spectrogram(wav)
187
+ if do_normalization:
188
+ mel = normalize_tacotron_mel(mel)
189
+ return mel
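The two mel helpers above are inverse affine maps between [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] and [-1, 1]; a quick, purely illustrative sanity check (not part of the upload):

import torch
from tortoise.utils.audio import (
    normalize_tacotron_mel, denormalize_tacotron_mel,
    TACOTRON_MEL_MIN, TACOTRON_MEL_MAX,
)

mel = torch.empty(2, 80, 50).uniform_(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX)
norm = normalize_tacotron_mel(mel)                        # values now lie in [-1, 1]
assert norm.min() >= -1 and norm.max() <= 1
assert torch.allclose(denormalize_tacotron_mel(norm), mel, atol=1e-5)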
tortoise/utils/diffusion.py ADDED
@@ -0,0 +1,1250 @@
1
+ """
2
+ This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself:
3
+
4
+ This code started out as a PyTorch port of Ho et al's diffusion models:
5
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
6
+
7
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
8
+ """
9
+
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch as th
16
+ from tqdm import tqdm
17
+
18
+
19
+ def normal_kl(mean1, logvar1, mean2, logvar2):
20
+ """
21
+ Compute the KL divergence between two gaussians.
22
+
23
+ Shapes are automatically broadcasted, so batches can be compared to
24
+ scalars, among other use cases.
25
+ """
26
+ tensor = None
27
+ for obj in (mean1, logvar1, mean2, logvar2):
28
+ if isinstance(obj, th.Tensor):
29
+ tensor = obj
30
+ break
31
+ assert tensor is not None, "at least one argument must be a Tensor"
32
+
33
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
34
+ # Tensors, but it does not work for th.exp().
35
+ logvar1, logvar2 = [
36
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
37
+ for x in (logvar1, logvar2)
38
+ ]
39
+
40
+ return 0.5 * (
41
+ -1.0
42
+ + logvar2
43
+ - logvar1
44
+ + th.exp(logvar1 - logvar2)
45
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
46
+ )
47
+
48
+
49
+ def approx_standard_normal_cdf(x):
50
+ """
51
+ A fast approximation of the cumulative distribution function of the
52
+ standard normal.
53
+ """
54
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
55
+
56
+
57
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
58
+ """
59
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
60
+ given image.
61
+
62
+ :param x: the target images. It is assumed that this was uint8 values,
63
+ rescaled to the range [-1, 1].
64
+ :param means: the Gaussian mean Tensor.
65
+ :param log_scales: the Gaussian log stddev Tensor.
66
+ :return: a tensor like x of log probabilities (in nats).
67
+ """
68
+ assert x.shape == means.shape == log_scales.shape
69
+ centered_x = x - means
70
+ inv_stdv = th.exp(-log_scales)
71
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
72
+ cdf_plus = approx_standard_normal_cdf(plus_in)
73
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
74
+ cdf_min = approx_standard_normal_cdf(min_in)
75
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
76
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
77
+ cdf_delta = cdf_plus - cdf_min
78
+ log_probs = th.where(
79
+ x < -0.999,
80
+ log_cdf_plus,
81
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
82
+ )
83
+ assert log_probs.shape == x.shape
84
+ return log_probs
85
+
86
+
87
+ def mean_flat(tensor):
88
+ """
89
+ Take the mean over all non-batch dimensions.
90
+ """
91
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
92
+
93
+
94
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
+ """
96
+ Get a pre-defined beta schedule for the given name.
97
+
98
+ The beta schedule library consists of beta schedules which remain similar
99
+ in the limit of num_diffusion_timesteps.
100
+ Beta schedules may be added, but should not be removed or changed once
101
+ they are committed to maintain backwards compatibility.
102
+ """
103
+ if schedule_name == "linear":
104
+ # Linear schedule from Ho et al, extended to work for any number of
105
+ # diffusion steps.
106
+ scale = 1000 / num_diffusion_timesteps
107
+ beta_start = scale * 0.0001
108
+ beta_end = scale * 0.02
109
+ return np.linspace(
110
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
111
+ )
112
+ elif schedule_name == "cosine":
113
+ return betas_for_alpha_bar(
114
+ num_diffusion_timesteps,
115
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
+ )
117
+ else:
118
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
+
120
+
121
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
+ """
123
+ Create a beta schedule that discretizes the given alpha_t_bar function,
124
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
125
+
126
+ :param num_diffusion_timesteps: the number of betas to produce.
127
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
128
+ produces the cumulative product of (1-beta) up to that
129
+ part of the diffusion process.
130
+ :param max_beta: the maximum beta to use; use values lower than 1 to
131
+ prevent singularities.
132
+ """
133
+ betas = []
134
+ for i in range(num_diffusion_timesteps):
135
+ t1 = i / num_diffusion_timesteps
136
+ t2 = (i + 1) / num_diffusion_timesteps
137
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
138
+ return np.array(betas)
139
+
140
+
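An illustrative call to the schedule helpers above, assuming they are in scope (for example, imported from this module); the cosine schedule discretizes alpha_bar(t) = cos(((t + 0.008) / 1.008) * pi / 2) ** 2 into per-step betas:

import numpy as np

betas = get_named_beta_schedule("cosine", 10)    # 10 diffusion steps, shape (10,)
alphas_cumprod = np.cumprod(1.0 - betas)         # discrete approximation of alpha_bar(t)
assert np.all(betas > 0) and np.all(betas <= 0.999)
assert np.all(np.diff(alphas_cumprod) < 0)       # signal level decays monotonically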
141
+ class ModelMeanType(enum.Enum):
142
+ """
143
+ Which type of output the model predicts.
144
+ """
145
+
146
+ PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
147
+ START_X = 'start_x' # the model predicts x_0
148
+ EPSILON = 'epsilon' # the model predicts epsilon
149
+
150
+
151
+ class ModelVarType(enum.Enum):
152
+ """
153
+ What is used as the model's output variance.
154
+
155
+ The LEARNED_RANGE option has been added to allow the model to predict
156
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
157
+ """
158
+
159
+ LEARNED = 'learned'
160
+ FIXED_SMALL = 'fixed_small'
161
+ FIXED_LARGE = 'fixed_large'
162
+ LEARNED_RANGE = 'learned_range'
163
+
164
+
165
+ class LossType(enum.Enum):
166
+ MSE = 'mse' # use raw MSE loss (and KL when learning variances)
167
+ RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
168
+ KL = 'kl' # use the variational lower-bound
169
+ RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
170
+
171
+ def is_vb(self):
172
+ return self == LossType.KL or self == LossType.RESCALED_KL
173
+
174
+
175
+ class GaussianDiffusion:
176
+ """
177
+ Utilities for training and sampling diffusion models.
178
+
179
+ Ported directly from the repository below, and then adapted over time for further experimentation.
180
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
181
+
182
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
183
+ starting at T and going to 1.
184
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
185
+ :param model_var_type: a ModelVarType determining how variance is output.
186
+ :param loss_type: a LossType determining the loss function to use.
187
+ :param rescale_timesteps: if True, pass floating point timesteps into the
188
+ model so that they are always scaled like in the
189
+ original paper (0 to 1000).
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ *,
195
+ betas,
196
+ model_mean_type,
197
+ model_var_type,
198
+ loss_type,
199
+ rescale_timesteps=False,
200
+ conditioning_free=False,
201
+ conditioning_free_k=1,
202
+ ramp_conditioning_free=True,
203
+ ):
204
+ self.model_mean_type = ModelMeanType(model_mean_type)
205
+ self.model_var_type = ModelVarType(model_var_type)
206
+ self.loss_type = LossType(loss_type)
207
+ self.rescale_timesteps = rescale_timesteps
208
+ self.conditioning_free = conditioning_free
209
+ self.conditioning_free_k = conditioning_free_k
210
+ self.ramp_conditioning_free = ramp_conditioning_free
211
+
212
+ # Use float64 for accuracy.
213
+ betas = np.array(betas, dtype=np.float64)
214
+ self.betas = betas
215
+ assert len(betas.shape) == 1, "betas must be 1-D"
216
+ assert (betas > 0).all() and (betas <= 1).all()
217
+
218
+ self.num_timesteps = int(betas.shape[0])
219
+
220
+ alphas = 1.0 - betas
221
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
222
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
223
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
224
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
225
+
226
+ # calculations for diffusion q(x_t | x_{t-1}) and others
227
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
228
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
229
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
230
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
231
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
232
+
233
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
234
+ self.posterior_variance = (
235
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
236
+ )
237
+ # log calculation clipped because the posterior variance is 0 at the
238
+ # beginning of the diffusion chain.
239
+ self.posterior_log_variance_clipped = np.log(
240
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
241
+ )
242
+ self.posterior_mean_coef1 = (
243
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
244
+ )
245
+ self.posterior_mean_coef2 = (
246
+ (1.0 - self.alphas_cumprod_prev)
247
+ * np.sqrt(alphas)
248
+ / (1.0 - self.alphas_cumprod)
249
+ )
250
+
251
+ def q_mean_variance(self, x_start, t):
252
+ """
253
+ Get the distribution q(x_t | x_0).
254
+
255
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
256
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
257
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
258
+ """
259
+ mean = (
260
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
261
+ )
262
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
263
+ log_variance = _extract_into_tensor(
264
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
265
+ )
266
+ return mean, variance, log_variance
267
+
268
+ def q_sample(self, x_start, t, noise=None):
269
+ """
270
+ Diffuse the data for a given number of diffusion steps.
271
+
272
+ In other words, sample from q(x_t | x_0).
273
+
274
+ :param x_start: the initial data batch.
275
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
276
+ :param noise: if specified, the split-out normal noise.
277
+ :return: A noisy version of x_start.
278
+ """
279
+ if noise is None:
280
+ noise = th.randn_like(x_start)
281
+ assert noise.shape == x_start.shape
282
+ return (
283
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
284
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
285
+ * noise
286
+ )
287
+
288
+ def q_posterior_mean_variance(self, x_start, x_t, t):
289
+ """
290
+ Compute the mean and variance of the diffusion posterior:
291
+
292
+ q(x_{t-1} | x_t, x_0)
293
+
294
+ """
295
+ assert x_start.shape == x_t.shape
296
+ posterior_mean = (
297
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
298
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
299
+ )
300
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
301
+ posterior_log_variance_clipped = _extract_into_tensor(
302
+ self.posterior_log_variance_clipped, t, x_t.shape
303
+ )
304
+ assert (
305
+ posterior_mean.shape[0]
306
+ == posterior_variance.shape[0]
307
+ == posterior_log_variance_clipped.shape[0]
308
+ == x_start.shape[0]
309
+ )
310
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
311
+
312
+ def p_mean_variance(
313
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
314
+ ):
315
+ """
316
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
317
+ the initial x, x_0.
318
+
319
+ :param model: the model, which takes a signal and a batch of timesteps
320
+ as input.
321
+ :param x: the [N x C x ...] tensor at time t.
322
+ :param t: a 1-D Tensor of timesteps.
323
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
324
+ :param denoised_fn: if not None, a function which applies to the
325
+ x_start prediction before it is used to sample. Applies before
326
+ clip_denoised.
327
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
328
+ pass to the model. This can be used for conditioning.
329
+ :return: a dict with the following keys:
330
+ - 'mean': the model mean output.
331
+ - 'variance': the model variance output.
332
+ - 'log_variance': the log of 'variance'.
333
+ - 'pred_xstart': the prediction for x_0.
334
+ """
335
+ if model_kwargs is None:
336
+ model_kwargs = {}
337
+
338
+ B, C = x.shape[:2]
339
+ assert t.shape == (B,)
340
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
341
+ if self.conditioning_free:
342
+ model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
343
+
344
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
345
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
346
+ model_output, model_var_values = th.split(model_output, C, dim=1)
347
+ if self.conditioning_free:
348
+ model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
349
+ if self.model_var_type == ModelVarType.LEARNED:
350
+ model_log_variance = model_var_values
351
+ model_variance = th.exp(model_log_variance)
352
+ else:
353
+ min_log = _extract_into_tensor(
354
+ self.posterior_log_variance_clipped, t, x.shape
355
+ )
356
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
357
+ # The model_var_values is [-1, 1] for [min_var, max_var].
358
+ frac = (model_var_values + 1) / 2
359
+ model_log_variance = frac * max_log + (1 - frac) * min_log
360
+ model_variance = th.exp(model_log_variance)
361
+ else:
362
+ model_variance, model_log_variance = {
363
+ # for fixedlarge, we set the initial (log-)variance like so
364
+ # to get a better decoder log likelihood.
365
+ ModelVarType.FIXED_LARGE: (
366
+ np.append(self.posterior_variance[1], self.betas[1:]),
367
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
368
+ ),
369
+ ModelVarType.FIXED_SMALL: (
370
+ self.posterior_variance,
371
+ self.posterior_log_variance_clipped,
372
+ ),
373
+ }[self.model_var_type]
374
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
375
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
376
+
377
+ if self.conditioning_free:
378
+ if self.ramp_conditioning_free:
379
+ assert t.shape[0] == 1 # This should only be used in inference.
380
+ cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
381
+ else:
382
+ cfk = self.conditioning_free_k
383
+ model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
384
+
385
+ def process_xstart(x):
386
+ if denoised_fn is not None:
387
+ x = denoised_fn(x)
388
+ if clip_denoised:
389
+ return x.clamp(-1, 1)
390
+ return x
391
+
392
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
393
+ pred_xstart = process_xstart(
394
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
395
+ )
396
+ model_mean = model_output
397
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
398
+ if self.model_mean_type == ModelMeanType.START_X:
399
+ pred_xstart = process_xstart(model_output)
400
+ else:
401
+ pred_xstart = process_xstart(
402
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
403
+ )
404
+ model_mean, _, _ = self.q_posterior_mean_variance(
405
+ x_start=pred_xstart, x_t=x, t=t
406
+ )
407
+ else:
408
+ raise NotImplementedError(self.model_mean_type)
409
+
410
+ assert (
411
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
412
+ )
413
+ return {
414
+ "mean": model_mean,
415
+ "variance": model_variance,
416
+ "log_variance": model_log_variance,
417
+ "pred_xstart": pred_xstart,
418
+ }
419
+
420
+ def _predict_xstart_from_eps(self, x_t, t, eps):
421
+ assert x_t.shape == eps.shape
422
+ return (
423
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
424
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
425
+ )
426
+
427
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
428
+ assert x_t.shape == xprev.shape
429
+ return ( # (xprev - coef2*x_t) / coef1
430
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
431
+ - _extract_into_tensor(
432
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
433
+ )
434
+ * x_t
435
+ )
436
+
437
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
438
+ return (
439
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
440
+ - pred_xstart
441
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
442
+
443
+ def _scale_timesteps(self, t):
444
+ if self.rescale_timesteps:
445
+ return t.float() * (1000.0 / self.num_timesteps)
446
+ return t
447
+
448
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
449
+ """
450
+ Compute the mean for the previous step, given a function cond_fn that
451
+ computes the gradient of a conditional log probability with respect to
452
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
453
+ condition on y.
454
+
455
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
456
+ """
457
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
458
+ new_mean = (
459
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
460
+ )
461
+ return new_mean
462
+
463
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
464
+ """
465
+ Compute what the p_mean_variance output would have been, should the
466
+ model's score function be conditioned by cond_fn.
467
+
468
+ See condition_mean() for details on cond_fn.
469
+
470
+ Unlike condition_mean(), this instead uses the conditioning strategy
471
+ from Song et al (2020).
472
+ """
473
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
474
+
475
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
476
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
477
+ x, self._scale_timesteps(t), **model_kwargs
478
+ )
479
+
480
+ out = p_mean_var.copy()
481
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
482
+ out["mean"], _, _ = self.q_posterior_mean_variance(
483
+ x_start=out["pred_xstart"], x_t=x, t=t
484
+ )
485
+ return out
486
+
487
+ def p_sample(
488
+ self,
489
+ model,
490
+ x,
491
+ t,
492
+ clip_denoised=True,
493
+ denoised_fn=None,
494
+ cond_fn=None,
495
+ model_kwargs=None,
496
+ ):
497
+ """
498
+ Sample x_{t-1} from the model at the given timestep.
499
+
500
+ :param model: the model to sample from.
501
+ :param x: the current tensor at x_{t-1}.
502
+ :param t: the value of t, starting at 0 for the first diffusion step.
503
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
504
+ :param denoised_fn: if not None, a function which applies to the
505
+ x_start prediction before it is used to sample.
506
+ :param cond_fn: if not None, this is a gradient function that acts
507
+ similarly to the model.
508
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
509
+ pass to the model. This can be used for conditioning.
510
+ :return: a dict containing the following keys:
511
+ - 'sample': a random sample from the model.
512
+ - 'pred_xstart': a prediction of x_0.
513
+ """
514
+ out = self.p_mean_variance(
515
+ model,
516
+ x,
517
+ t,
518
+ clip_denoised=clip_denoised,
519
+ denoised_fn=denoised_fn,
520
+ model_kwargs=model_kwargs,
521
+ )
522
+ noise = th.randn_like(x)
523
+ nonzero_mask = (
524
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
525
+ ) # no noise when t == 0
526
+ if cond_fn is not None:
527
+ out["mean"] = self.condition_mean(
528
+ cond_fn, out, x, t, model_kwargs=model_kwargs
529
+ )
530
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
531
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
532
+
533
+ def p_sample_loop(
534
+ self,
535
+ model,
536
+ shape,
537
+ noise=None,
538
+ clip_denoised=True,
539
+ denoised_fn=None,
540
+ cond_fn=None,
541
+ model_kwargs=None,
542
+ device=None,
543
+ progress=False,
544
+ ):
545
+ """
546
+ Generate samples from the model.
547
+
548
+ :param model: the model module.
549
+ :param shape: the shape of the samples, (N, C, H, W).
550
+ :param noise: if specified, the noise from the encoder to sample.
551
+ Should be of the same shape as `shape`.
552
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
553
+ :param denoised_fn: if not None, a function which applies to the
554
+ x_start prediction before it is used to sample.
555
+ :param cond_fn: if not None, this is a gradient function that acts
556
+ similarly to the model.
557
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
558
+ pass to the model. This can be used for conditioning.
559
+ :param device: if specified, the device to create the samples on.
560
+ If not specified, use a model parameter's device.
561
+ :param progress: if True, show a tqdm progress bar.
562
+ :return: a non-differentiable batch of samples.
563
+ """
564
+ final = None
565
+ for sample in self.p_sample_loop_progressive(
566
+ model,
567
+ shape,
568
+ noise=noise,
569
+ clip_denoised=clip_denoised,
570
+ denoised_fn=denoised_fn,
571
+ cond_fn=cond_fn,
572
+ model_kwargs=model_kwargs,
573
+ device=device,
574
+ progress=progress,
575
+ ):
576
+ final = sample
577
+ return final["sample"]
578
+
579
+ def p_sample_loop_progressive(
580
+ self,
581
+ model,
582
+ shape,
583
+ noise=None,
584
+ clip_denoised=True,
585
+ denoised_fn=None,
586
+ cond_fn=None,
587
+ model_kwargs=None,
588
+ device=None,
589
+ progress=False,
590
+ ):
591
+ """
592
+ Generate samples from the model and yield intermediate samples from
593
+ each timestep of diffusion.
594
+
595
+ Arguments are the same as p_sample_loop().
596
+ Returns a generator over dicts, where each dict is the return value of
597
+ p_sample().
598
+ """
599
+ if device is None:
600
+ device = next(model.parameters()).device
601
+ assert isinstance(shape, (tuple, list))
602
+ if noise is not None:
603
+ img = noise
604
+ else:
605
+ img = th.randn(*shape, device=device)
606
+ indices = list(range(self.num_timesteps))[::-1]
607
+
608
+ for i in tqdm(indices, disable=not progress):
609
+ t = th.tensor([i] * shape[0], device=device)
610
+ with th.no_grad():
611
+ out = self.p_sample(
612
+ model,
613
+ img,
614
+ t,
615
+ clip_denoised=clip_denoised,
616
+ denoised_fn=denoised_fn,
617
+ cond_fn=cond_fn,
618
+ model_kwargs=model_kwargs,
619
+ )
620
+ yield out
621
+ img = out["sample"]
622
+
623
+ def ddim_sample(
624
+ self,
625
+ model,
626
+ x,
627
+ t,
628
+ clip_denoised=True,
629
+ denoised_fn=None,
630
+ cond_fn=None,
631
+ model_kwargs=None,
632
+ eta=0.0,
633
+ ):
634
+ """
635
+ Sample x_{t-1} from the model using DDIM.
636
+
637
+ Same usage as p_sample().
638
+ """
639
+ out = self.p_mean_variance(
640
+ model,
641
+ x,
642
+ t,
643
+ clip_denoised=clip_denoised,
644
+ denoised_fn=denoised_fn,
645
+ model_kwargs=model_kwargs,
646
+ )
647
+ if cond_fn is not None:
648
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
649
+
650
+ # Usually our model outputs epsilon, but we re-derive it
651
+ # in case we used x_start or x_prev prediction.
652
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
653
+
654
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
655
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
656
+ sigma = (
657
+ eta
658
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
659
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
660
+ )
661
+ # Equation 12.
662
+ noise = th.randn_like(x)
663
+ mean_pred = (
664
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
665
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
666
+ )
667
+ nonzero_mask = (
668
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
669
+ ) # no noise when t == 0
670
+ sample = mean_pred + nonzero_mask * sigma * noise
671
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
672
+
673
+ def ddim_reverse_sample(
674
+ self,
675
+ model,
676
+ x,
677
+ t,
678
+ clip_denoised=True,
679
+ denoised_fn=None,
680
+ model_kwargs=None,
681
+ eta=0.0,
682
+ ):
683
+ """
684
+ Sample x_{t+1} from the model using DDIM reverse ODE.
685
+ """
686
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
687
+ out = self.p_mean_variance(
688
+ model,
689
+ x,
690
+ t,
691
+ clip_denoised=clip_denoised,
692
+ denoised_fn=denoised_fn,
693
+ model_kwargs=model_kwargs,
694
+ )
695
+ # Usually our model outputs epsilon, but we re-derive it
696
+ # in case we used x_start or x_prev prediction.
697
+ eps = (
698
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
699
+ - out["pred_xstart"]
700
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
701
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
702
+
703
+ # Equation 12. reversed
704
+ mean_pred = (
705
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
706
+ + th.sqrt(1 - alpha_bar_next) * eps
707
+ )
708
+
709
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
710
+
711
+ def ddim_sample_loop(
712
+ self,
713
+ model,
714
+ shape,
715
+ noise=None,
716
+ clip_denoised=True,
717
+ denoised_fn=None,
718
+ cond_fn=None,
719
+ model_kwargs=None,
720
+ device=None,
721
+ progress=False,
722
+ eta=0.0,
723
+ ):
724
+ """
725
+ Generate samples from the model using DDIM.
726
+
727
+ Same usage as p_sample_loop().
728
+ """
729
+ final = None
730
+ for sample in self.ddim_sample_loop_progressive(
731
+ model,
732
+ shape,
733
+ noise=noise,
734
+ clip_denoised=clip_denoised,
735
+ denoised_fn=denoised_fn,
736
+ cond_fn=cond_fn,
737
+ model_kwargs=model_kwargs,
738
+ device=device,
739
+ progress=progress,
740
+ eta=eta,
741
+ ):
742
+ final = sample
743
+ return final["sample"]
744
+
745
+ def ddim_sample_loop_progressive(
746
+ self,
747
+ model,
748
+ shape,
749
+ noise=None,
750
+ clip_denoised=True,
751
+ denoised_fn=None,
752
+ cond_fn=None,
753
+ model_kwargs=None,
754
+ device=None,
755
+ progress=False,
756
+ eta=0.0,
757
+ ):
758
+ """
759
+ Use DDIM to sample from the model and yield intermediate samples from
760
+ each timestep of DDIM.
761
+
762
+ Same usage as p_sample_loop_progressive().
763
+ """
764
+ if device is None:
765
+ device = next(model.parameters()).device
766
+ assert isinstance(shape, (tuple, list))
767
+ if noise is not None:
768
+ img = noise
769
+ else:
770
+ img = th.randn(*shape, device=device)
771
+ indices = list(range(self.num_timesteps))[::-1]
772
+
773
+ if progress:
774
+ # Lazy import so that we don't depend on tqdm.
775
+ from tqdm.auto import tqdm
776
+
777
+ indices = tqdm(indices, disable=not progress)
778
+
779
+ for i in indices:
780
+ t = th.tensor([i] * shape[0], device=device)
781
+ with th.no_grad():
782
+ out = self.ddim_sample(
783
+ model,
784
+ img,
785
+ t,
786
+ clip_denoised=clip_denoised,
787
+ denoised_fn=denoised_fn,
788
+ cond_fn=cond_fn,
789
+ model_kwargs=model_kwargs,
790
+ eta=eta,
791
+ )
792
+ yield out
793
+ img = out["sample"]
794
+
795
+ def _vb_terms_bpd(
796
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
797
+ ):
798
+ """
799
+ Get a term for the variational lower-bound.
800
+
801
+ The resulting units are bits (rather than nats, as one might expect).
802
+ This allows for comparison to other papers.
803
+
804
+ :return: a dict with the following keys:
805
+ - 'output': a shape [N] tensor of NLLs or KLs.
806
+ - 'pred_xstart': the x_0 predictions.
807
+ """
808
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
809
+ x_start=x_start, x_t=x_t, t=t
810
+ )
811
+ out = self.p_mean_variance(
812
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
813
+ )
814
+ kl = normal_kl(
815
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
816
+ )
817
+ kl = mean_flat(kl) / np.log(2.0)
818
+
819
+ decoder_nll = -discretized_gaussian_log_likelihood(
820
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
821
+ )
822
+ assert decoder_nll.shape == x_start.shape
823
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
824
+
825
+ # At the first timestep return the decoder NLL,
826
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
827
+ output = th.where((t == 0), decoder_nll, kl)
828
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
829
+
830
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
831
+ """
832
+ Compute training losses for a single timestep.
833
+
834
+ :param model: the model to evaluate loss on.
835
+ :param x_start: the [N x C x ...] tensor of inputs.
836
+ :param t: a batch of timestep indices.
837
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
838
+ pass to the model. This can be used for conditioning.
839
+ :param noise: if specified, the specific Gaussian noise to try to remove.
840
+ :return: a dict with the key "loss" containing a tensor of shape [N].
841
+ Some mean or variance settings may also have other keys.
842
+ """
843
+ if model_kwargs is None:
844
+ model_kwargs = {}
845
+ if noise is None:
846
+ noise = th.randn_like(x_start)
847
+ x_t = self.q_sample(x_start, t, noise=noise)
848
+
849
+ terms = {}
850
+
851
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
852
+ # TODO: support multiple model outputs for this mode.
853
+ terms["loss"] = self._vb_terms_bpd(
854
+ model=model,
855
+ x_start=x_start,
856
+ x_t=x_t,
857
+ t=t,
858
+ clip_denoised=False,
859
+ model_kwargs=model_kwargs,
860
+ )["output"]
861
+ if self.loss_type == LossType.RESCALED_KL:
862
+ terms["loss"] *= self.num_timesteps
863
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
864
+ model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
865
+ if isinstance(model_outputs, tuple):
866
+ model_output = model_outputs[0]
867
+ terms['extra_outputs'] = model_outputs[1:]
868
+ else:
869
+ model_output = model_outputs
870
+
871
+ if self.model_var_type in [
872
+ ModelVarType.LEARNED,
873
+ ModelVarType.LEARNED_RANGE,
874
+ ]:
875
+ B, C = x_t.shape[:2]
876
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
877
+ model_output, model_var_values = th.split(model_output, C, dim=1)
878
+ # Learn the variance using the variational bound, but don't let
879
+ # it affect our mean prediction.
880
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
881
+ terms["vb"] = self._vb_terms_bpd(
882
+ model=lambda *args, r=frozen_out: r,
883
+ x_start=x_start,
884
+ x_t=x_t,
885
+ t=t,
886
+ clip_denoised=False,
887
+ )["output"]
888
+ if self.loss_type == LossType.RESCALED_MSE:
889
+ # Divide by 1000 for equivalence with initial implementation.
890
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
891
+ terms["vb"] *= self.num_timesteps / 1000.0
892
+
893
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
894
+ target = self.q_posterior_mean_variance(
895
+ x_start=x_start, x_t=x_t, t=t
896
+ )[0]
897
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
898
+ elif self.model_mean_type == ModelMeanType.START_X:
899
+ target = x_start
900
+ x_start_pred = model_output
901
+ elif self.model_mean_type == ModelMeanType.EPSILON:
902
+ target = noise
903
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
904
+ else:
905
+ raise NotImplementedError(self.model_mean_type)
906
+ assert model_output.shape == target.shape == x_start.shape
907
+ terms["mse"] = mean_flat((target - model_output) ** 2)
908
+ terms["x_start_predicted"] = x_start_pred
909
+ if "vb" in terms:
910
+ terms["loss"] = terms["mse"] + terms["vb"]
911
+ else:
912
+ terms["loss"] = terms["mse"]
913
+ else:
914
+ raise NotImplementedError(self.loss_type)
915
+
916
+ return terms
917
+
918
+ def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
919
+ """
920
+ Compute training losses for a single timestep.
921
+
922
+ :param model: the model to evaluate loss on.
923
+ :param x_start: the [N x C x ...] tensor of inputs.
924
+ :param t: a batch of timestep indices.
925
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
926
+ pass to the model. This can be used for conditioning.
927
+ :param noise: if specified, the specific Gaussian noise to try to remove.
928
+ :return: a dict with the key "loss" containing a tensor of shape [N].
929
+ Some mean or variance settings may also have other keys.
930
+ """
931
+ if model_kwargs is None:
932
+ model_kwargs = {}
933
+ if noise is None:
934
+ noise = th.randn_like(x_start)
935
+ x_t = self.q_sample(x_start, t, noise=noise)
936
+ terms = {}
937
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
938
+ assert False # not currently supported for this type of diffusion.
939
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
940
+ model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
941
+ terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
942
+ model_output = terms[gd_out_key]
943
+ if self.model_var_type in [
944
+ ModelVarType.LEARNED,
945
+ ModelVarType.LEARNED_RANGE,
946
+ ]:
947
+ B, C = x_t.shape[:2]
948
+ assert model_output.shape == (B, C, 2, *x_t.shape[2:])
949
+ model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
950
+ # Learn the variance using the variational bound, but don't let
951
+ # it affect our mean prediction.
952
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
953
+ terms["vb"] = self._vb_terms_bpd(
954
+ model=lambda *args, r=frozen_out: r,
955
+ x_start=x_start,
956
+ x_t=x_t,
957
+ t=t,
958
+ clip_denoised=False,
959
+ )["output"]
960
+ if self.loss_type == LossType.RESCALED_MSE:
961
+ # Divide by 1000 for equivalence with initial implementation.
962
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
963
+ terms["vb"] *= self.num_timesteps / 1000.0
964
+
965
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
966
+ target = self.q_posterior_mean_variance(
967
+ x_start=x_start, x_t=x_t, t=t
968
+ )[0]
969
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
970
+ elif self.model_mean_type == ModelMeanType.START_X:
971
+ target = x_start
972
+ x_start_pred = model_output
973
+ elif self.model_mean_type == ModelMeanType.EPSILON:
974
+ target = noise
975
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
976
+ else:
977
+ raise NotImplementedError(self.model_mean_type)
978
+ assert model_output.shape == target.shape == x_start.shape
979
+ terms["mse"] = mean_flat((target - model_output) ** 2)
980
+ terms["x_start_predicted"] = x_start_pred
981
+ if "vb" in terms:
982
+ terms["loss"] = terms["mse"] + terms["vb"]
983
+ else:
984
+ terms["loss"] = terms["mse"]
985
+ else:
986
+ raise NotImplementedError(self.loss_type)
987
+
988
+ return terms
989
+
990
+ def _prior_bpd(self, x_start):
991
+ """
992
+ Get the prior KL term for the variational lower-bound, measured in
993
+ bits-per-dim.
994
+
995
+ This term can't be optimized, as it only depends on the encoder.
996
+
997
+ :param x_start: the [N x C x ...] tensor of inputs.
998
+ :return: a batch of [N] KL values (in bits), one per batch element.
999
+ """
1000
+ batch_size = x_start.shape[0]
1001
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1002
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1003
+ kl_prior = normal_kl(
1004
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1005
+ )
1006
+ return mean_flat(kl_prior) / np.log(2.0)
1007
+
1008
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1009
+ """
1010
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1011
+ as well as other related quantities.
1012
+
1013
+ :param model: the model to evaluate loss on.
1014
+ :param x_start: the [N x C x ...] tensor of inputs.
1015
+ :param clip_denoised: if True, clip denoised samples.
1016
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1017
+ pass to the model. This can be used for conditioning.
1018
+
1019
+ :return: a dict containing the following keys:
1020
+ - total_bpd: the total variational lower-bound, per batch element.
1021
+ - prior_bpd: the prior term in the lower-bound.
1022
+ - vb: an [N x T] tensor of terms in the lower-bound.
1023
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1024
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1025
+ """
1026
+ device = x_start.device
1027
+ batch_size = x_start.shape[0]
1028
+
1029
+ vb = []
1030
+ xstart_mse = []
1031
+ mse = []
1032
+ for t in list(range(self.num_timesteps))[::-1]:
1033
+ t_batch = th.tensor([t] * batch_size, device=device)
1034
+ noise = th.randn_like(x_start)
1035
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1036
+ # Calculate VLB term at the current timestep
1037
+ with th.no_grad():
1038
+ out = self._vb_terms_bpd(
1039
+ model,
1040
+ x_start=x_start,
1041
+ x_t=x_t,
1042
+ t=t_batch,
1043
+ clip_denoised=clip_denoised,
1044
+ model_kwargs=model_kwargs,
1045
+ )
1046
+ vb.append(out["output"])
1047
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1048
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1049
+ mse.append(mean_flat((eps - noise) ** 2))
1050
+
1051
+ vb = th.stack(vb, dim=1)
1052
+ xstart_mse = th.stack(xstart_mse, dim=1)
1053
+ mse = th.stack(mse, dim=1)
1054
+
1055
+ prior_bpd = self._prior_bpd(x_start)
1056
+ total_bpd = vb.sum(dim=1) + prior_bpd
1057
+ return {
1058
+ "total_bpd": total_bpd,
1059
+ "prior_bpd": prior_bpd,
1060
+ "vb": vb,
1061
+ "xstart_mse": xstart_mse,
1062
+ "mse": mse,
1063
+ }
1064
+
1065
+
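A hedged sketch of how training_losses is typically consumed in a training step; the model, optimizer, and tensor shapes are placeholders and not part of this upload:

import torch

def training_step(diffusion, model, optimizer, x_start):
    # Sample a random timestep per batch element and optimize the diffusion loss.
    t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device)
    losses = diffusion.training_losses(model, x_start, t)
    loss = losses["loss"].mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()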
1066
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
1067
+ """
1068
+ Get a pre-defined beta schedule for the given name.
1069
+
1070
+ The beta schedule library consists of beta schedules which remain similar
1071
+ in the limit of num_diffusion_timesteps.
1072
+ Beta schedules may be added, but should not be removed or changed once
1073
+ they are committed to maintain backwards compatibility.
1074
+ """
1075
+ if schedule_name == "linear":
1076
+ # Linear schedule from Ho et al, extended to work for any number of
1077
+ # diffusion steps.
1078
+ scale = 1000 / num_diffusion_timesteps
1079
+ beta_start = scale * 0.0001
1080
+ beta_end = scale * 0.02
1081
+ return np.linspace(
1082
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
1083
+ )
1084
+ elif schedule_name == "cosine":
1085
+ return betas_for_alpha_bar(
1086
+ num_diffusion_timesteps,
1087
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
1088
+ )
1089
+ else:
1090
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1091
+
1092
+
1093
+ class SpacedDiffusion(GaussianDiffusion):
1094
+ """
1095
+ A diffusion process which can skip steps in a base diffusion process.
1096
+
1097
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
1098
+ original diffusion process to retain.
1099
+ :param kwargs: the kwargs to create the base diffusion process.
1100
+ """
1101
+
1102
+ def __init__(self, use_timesteps, **kwargs):
1103
+ self.use_timesteps = set(use_timesteps)
1104
+ self.timestep_map = []
1105
+ self.original_num_steps = len(kwargs["betas"])
1106
+
1107
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
1108
+ last_alpha_cumprod = 1.0
1109
+ new_betas = []
1110
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
1111
+ if i in self.use_timesteps:
1112
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
1113
+ last_alpha_cumprod = alpha_cumprod
1114
+ self.timestep_map.append(i)
1115
+ kwargs["betas"] = np.array(new_betas)
1116
+ super().__init__(**kwargs)
1117
+
1118
+ def p_mean_variance(
1119
+ self, model, *args, **kwargs
1120
+ ): # pylint: disable=signature-differs
1121
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
1122
+
1123
+ def training_losses(
1124
+ self, model, *args, **kwargs
1125
+ ): # pylint: disable=signature-differs
1126
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
1127
+
1128
+ def autoregressive_training_losses(
1129
+ self, model, *args, **kwargs
1130
+ ): # pylint: disable=signature-differs
1131
+ return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
1132
+
1133
+ def condition_mean(self, cond_fn, *args, **kwargs):
1134
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
1135
+
1136
+ def condition_score(self, cond_fn, *args, **kwargs):
1137
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
1138
+
1139
+ def _wrap_model(self, model, autoregressive=False):
1140
+ if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
1141
+ return model
1142
+ mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
1143
+ return mod(
1144
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
1145
+ )
1146
+
1147
+ def _scale_timesteps(self, t):
1148
+ # Scaling is done by the wrapped model.
1149
+ return t
1150
+
1151
+
1152
+ def space_timesteps(num_timesteps, section_counts):
1153
+ """
1154
+ Create a list of timesteps to use from an original diffusion process,
1155
+ given the number of timesteps we want to take from equally-sized portions
1156
+ of the original process.
1157
+
1158
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
1159
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
1160
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
1161
+
1162
+ If the stride is a string starting with "ddim", then the fixed striding
1163
+ from the DDIM paper is used, and only one section is allowed.
1164
+
1165
+ :param num_timesteps: the number of diffusion steps in the original
1166
+ process to divide up.
1167
+ :param section_counts: either a list of numbers, or a string containing
1168
+ comma-separated numbers, indicating the step count
1169
+ per section. As a special case, use "ddimN" where N
1170
+ is a number of steps to use the striding from the
1171
+ DDIM paper.
1172
+ :return: a set of diffusion steps from the original process to use.
1173
+ """
1174
+ if isinstance(section_counts, str):
1175
+ if section_counts.startswith("ddim"):
1176
+ desired_count = int(section_counts[len("ddim") :])
1177
+ for i in range(1, num_timesteps):
1178
+ if len(range(0, num_timesteps, i)) == desired_count:
1179
+ return set(range(0, num_timesteps, i))
1180
+ raise ValueError(
1181
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
1182
+ )
1183
+ section_counts = [int(x) for x in section_counts.split(",")]
1184
+ size_per = num_timesteps // len(section_counts)
1185
+ extra = num_timesteps % len(section_counts)
1186
+ start_idx = 0
1187
+ all_steps = []
1188
+ for i, section_count in enumerate(section_counts):
1189
+ size = size_per + (1 if i < extra else 0)
1190
+ if size < section_count:
1191
+ raise ValueError(
1192
+ f"cannot divide section of {size} steps into {section_count}"
1193
+ )
1194
+ if section_count <= 1:
1195
+ frac_stride = 1
1196
+ else:
1197
+ frac_stride = (size - 1) / (section_count - 1)
1198
+ cur_idx = 0.0
1199
+ taken_steps = []
1200
+ for _ in range(section_count):
1201
+ taken_steps.append(start_idx + round(cur_idx))
1202
+ cur_idx += frac_stride
1203
+ all_steps += taken_steps
1204
+ start_idx += size
1205
+ return set(all_steps)
1206
+
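+ # Illustrative sketch (not part of the uploaded file): keeping 30 of 1000 steps for
+ # sampling. The remaining GaussianDiffusion constructor arguments are elided with "...";
+ # they are the ones defined elsewhere in this file.
+ # use_timesteps = space_timesteps(1000, [30])        # evenly strided subset
+ # use_timesteps = space_timesteps(1000, "ddim30")    # fixed DDIM-style stride instead
+ # diffusion = SpacedDiffusion(use_timesteps, betas=get_named_beta_schedule("linear", 1000), ...)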
1207
+
1208
+ class _WrappedModel:
1209
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1210
+ self.model = model
1211
+ self.timestep_map = timestep_map
1212
+ self.rescale_timesteps = rescale_timesteps
1213
+ self.original_num_steps = original_num_steps
1214
+
1215
+ def __call__(self, x, ts, **kwargs):
1216
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1217
+ new_ts = map_tensor[ts]
1218
+ if self.rescale_timesteps:
1219
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1220
+ return self.model(x, new_ts, **kwargs)
1221
+
1222
+
1223
+ class _WrappedAutoregressiveModel:
1224
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1225
+ self.model = model
1226
+ self.timestep_map = timestep_map
1227
+ self.rescale_timesteps = rescale_timesteps
1228
+ self.original_num_steps = original_num_steps
1229
+
1230
+ def __call__(self, x, x0, ts, **kwargs):
1231
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1232
+ new_ts = map_tensor[ts]
1233
+ if self.rescale_timesteps:
1234
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1235
+ return self.model(x, x0, new_ts, **kwargs)
1236
+
1237
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1238
+ """
1239
+ Extract values from a 1-D numpy array for a batch of indices.
1240
+
1241
+ :param arr: the 1-D numpy array.
1242
+ :param timesteps: a tensor of indices into the array to extract.
1243
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1244
+ dimension equal to the length of timesteps.
1245
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1246
+ """
1247
+ res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
1248
+ while len(res.shape) < len(broadcast_shape):
1249
+ res = res[..., None]
1250
+ return res.expand(broadcast_shape)
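+ # Illustrative sketch (not part of the uploaded file): with `alphas_cumprod` of shape
+ # (1000,), `t` of shape (B,) and `x` of shape (B, C, H, W), the call below yields one
+ # coefficient per batch element, broadcast against `x`.
+ # coeff = _extract_into_tensor(diffusion.alphas_cumprod, t, x.shape)   # (B, 1, 1, 1) expanded to x.shape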
tortoise/utils/stft.py ADDED
@@ -0,0 +1,193 @@
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) 2017, Prem Seetharaman
5
+ All rights reserved.
6
+
7
+ * Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice, this
14
+ list of conditions and the following disclaimer in the
15
+ documentation and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from this
19
+ software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import torch
34
+ import numpy as np
35
+ import torch.nn.functional as F
36
+ from torch.autograd import Variable
37
+ from scipy.signal import get_window
38
+ from librosa.util import pad_center, tiny
39
+ import librosa.util as librosa_util
40
+
41
+
42
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
43
+ n_fft=800, dtype=np.float32, norm=None):
44
+ """
45
+ # from librosa 0.6
46
+ Compute the sum-square envelope of a window function at a given hop length.
47
+
48
+ This is used to estimate modulation effects induced by windowing
49
+ observations in short-time fourier transforms.
50
+
51
+ Parameters
52
+ ----------
53
+ window : string, tuple, number, callable, or list-like
54
+ Window specification, as in `get_window`
55
+
56
+ n_frames : int > 0
57
+ The number of analysis frames
58
+
59
+ hop_length : int > 0
60
+ The number of samples to advance between frames
61
+
62
+ win_length : [optional]
63
+ The length of the window function. By default, this matches `n_fft`.
64
+
65
+ n_fft : int > 0
66
+ The length of each analysis frame.
67
+
68
+ dtype : np.dtype
69
+ The data type of the output
70
+
71
+ Returns
72
+ -------
73
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
74
+ The sum-squared envelope of the window function
75
+ """
76
+ if win_length is None:
77
+ win_length = n_fft
78
+
79
+ n = n_fft + hop_length * (n_frames - 1)
80
+ x = np.zeros(n, dtype=dtype)
81
+
82
+ # Compute the squared window at the desired length
83
+ win_sq = get_window(window, win_length, fftbins=True)
84
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
85
+ win_sq = librosa_util.pad_center(win_sq, size=n_fft)  # keyword form, matching the pad_center usage below
86
+
87
+ # Fill the envelope
88
+ for i in range(n_frames):
89
+ sample = i * hop_length
90
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
91
+ return x
92
+
93
+
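+ # Illustrative sketch (not part of the uploaded file): the envelope STFT.inverse uses
+ # below to undo windowing modulation, with this module's default analysis settings.
+ # env = window_sumsquare('hann', n_frames=100, hop_length=200, win_length=800, n_fft=800)
+ # env.shape  -> (800 + 200 * 99,)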
94
+ class STFT(torch.nn.Module):
95
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
96
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
97
+ window='hann'):
98
+ super(STFT, self).__init__()
99
+ self.filter_length = filter_length
100
+ self.hop_length = hop_length
101
+ self.win_length = win_length
102
+ self.window = window
103
+ self.forward_transform = None
104
+ scale = self.filter_length / self.hop_length
105
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
106
+
107
+ cutoff = int((self.filter_length / 2 + 1))
108
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
109
+ np.imag(fourier_basis[:cutoff, :])])
110
+
111
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
112
+ inverse_basis = torch.FloatTensor(
113
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
114
+
115
+ if window is not None:
116
+ assert(filter_length >= win_length)
117
+ # get window and zero center pad it to filter_length
118
+ fft_window = get_window(window, win_length, fftbins=True)
119
+ fft_window = pad_center(fft_window, size=filter_length)
120
+ fft_window = torch.from_numpy(fft_window).float()
121
+
122
+ # window the bases
123
+ forward_basis *= fft_window
124
+ inverse_basis *= fft_window
125
+
126
+ self.register_buffer('forward_basis', forward_basis.float())
127
+ self.register_buffer('inverse_basis', inverse_basis.float())
128
+
129
+ def transform(self, input_data):
130
+ num_batches = input_data.size(0)
131
+ num_samples = input_data.size(1)
132
+
133
+ self.num_samples = num_samples
134
+
135
+ # similar to librosa, reflect-pad the input
136
+ input_data = input_data.view(num_batches, 1, num_samples)
137
+ input_data = F.pad(
138
+ input_data.unsqueeze(1),
139
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
140
+ mode='reflect')
141
+ input_data = input_data.squeeze(1)
142
+
143
+ forward_transform = F.conv1d(
144
+ input_data,
145
+ Variable(self.forward_basis, requires_grad=False),
146
+ stride=self.hop_length,
147
+ padding=0)
148
+
149
+ cutoff = int((self.filter_length / 2) + 1)
150
+ real_part = forward_transform[:, :cutoff, :]
151
+ imag_part = forward_transform[:, cutoff:, :]
152
+
153
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
154
+ phase = torch.autograd.Variable(
155
+ torch.atan2(imag_part.data, real_part.data))
156
+
157
+ return magnitude, phase
158
+
159
+ def inverse(self, magnitude, phase):
160
+ recombine_magnitude_phase = torch.cat(
161
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
162
+
163
+ inverse_transform = F.conv_transpose1d(
164
+ recombine_magnitude_phase,
165
+ Variable(self.inverse_basis, requires_grad=False),
166
+ stride=self.hop_length,
167
+ padding=0)
168
+
169
+ if self.window is not None:
170
+ window_sum = window_sumsquare(
171
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
172
+ win_length=self.win_length, n_fft=self.filter_length,
173
+ dtype=np.float32)
174
+ # remove modulation effects
175
+ approx_nonzero_indices = torch.from_numpy(
176
+ np.where(window_sum > tiny(window_sum))[0])
177
+ window_sum = torch.autograd.Variable(
178
+ torch.from_numpy(window_sum), requires_grad=False)
179
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
180
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
181
+
182
+ # scale by hop ratio
183
+ inverse_transform *= float(self.filter_length) / self.hop_length
184
+
185
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
186
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
187
+
188
+ return inverse_transform
189
+
190
+ def forward(self, input_data):
191
+ self.magnitude, self.phase = self.transform(input_data)
192
+ reconstruction = self.inverse(self.magnitude, self.phase)
193
+ return reconstruction
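+ # Illustrative usage sketch (not part of the uploaded file); the random clip below is a
+ # stand-in for real audio.
+ # stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')
+ # magnitude, phase = stft.transform(torch.randn(1, 22050))   # input is (batch, samples)
+ # reconstruction = stft.inverse(magnitude, phase)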
tortoise/utils/text.py ADDED
@@ -0,0 +1,132 @@
1
+ import re
2
+
3
+
4
+ def split_and_recombine_text(text, desired_length=200, max_length=300):
5
+ """Split text it into chunks of a desired length trying to keep sentences intact."""
6
+ # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
7
+ text = re.sub(r'\n\n+', '\n', text)
8
+ text = re.sub(r'\s+', ' ', text)
9
+ text = re.sub(r'[“”]', '"', text)
10
+
11
+ rv = []
12
+ in_quote = False
13
+ current = ""
14
+ split_pos = []
15
+ pos = -1
16
+ end_pos = len(text) - 1
17
+
18
+ def seek(delta):
19
+ nonlocal pos, in_quote, current
20
+ is_neg = delta < 0
21
+ for _ in range(abs(delta)):
22
+ if is_neg:
23
+ pos -= 1
24
+ current = current[:-1]
25
+ else:
26
+ pos += 1
27
+ current += text[pos]
28
+ if text[pos] == '"':
29
+ in_quote = not in_quote
30
+ return text[pos]
31
+
32
+ def peek(delta):
33
+ p = pos + delta
34
+ return text[p] if p < end_pos and p >= 0 else ""
35
+
36
+ def commit():
37
+ nonlocal rv, current, split_pos
38
+ rv.append(current)
39
+ current = ""
40
+ split_pos = []
41
+
42
+ while pos < end_pos:
43
+ c = seek(1)
44
+ # do we need to force a split?
45
+ if len(current) >= max_length:
46
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
47
+ # we have at least one sentence and we are over half the desired length, seek back to the last split
48
+ d = pos - split_pos[-1]
49
+ seek(-d)
50
+ else:
51
+ # no full sentences, seek back until we are not in the middle of a word and split there
52
+ while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
53
+ c = seek(-1)
54
+ commit()
55
+ # check for sentence boundaries
56
+ elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
57
+ # seek forward if we have consecutive boundary markers but still within the max length
58
+ while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
59
+ c = seek(1)
60
+ split_pos.append(pos)
61
+ if len(current) >= desired_length:
62
+ commit()
63
+ # treat end of quote as a boundary if it's followed by a space or newline
64
+ elif in_quote and peek(1) == '"' and peek(2) in '\n ':
65
+ seek(2)
66
+ split_pos.append(pos)
67
+ rv.append(current)
68
+
69
+ # clean up, remove lines with only whitespace or punctuation
70
+ rv = [s.strip() for s in rv]
71
+ rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
72
+
73
+ return rv
74
+
75
+
76
+ if __name__ == '__main__':
77
+ import os
78
+ import unittest
79
+
80
+ class Test(unittest.TestCase):
81
+ def test_split_and_recombine_text(self):
82
+ text = """
83
+ This is a sample sentence.
84
+ This is another sample sentence.
85
+ This is a longer sample sentence that should force a split inthemiddlebutinotinthislongword.
86
+ "Don't split my quote... please"
87
+ """
88
+ self.assertEqual(split_and_recombine_text(text, desired_length=20, max_length=40),
89
+ ['This is a sample sentence.',
90
+ 'This is another sample sentence.',
91
+ 'This is a longer sample sentence that',
92
+ 'should force a split',
93
+ 'inthemiddlebutinotinthislongword.',
94
+ '"Don\'t split my quote... please"'])
95
+
96
+ def test_split_and_recombine_text_2(self):
97
+ text = """
98
+ When you are really angry sometimes you use consecutive exclamation marks!!!!!! Is this a good thing to do?!?!?!
99
+ I don't know but we should handle this situation..........................
100
+ """
101
+ self.assertEqual(split_and_recombine_text(text, desired_length=30, max_length=50),
102
+ ['When you are really angry sometimes you use',
103
+ 'consecutive exclamation marks!!!!!!',
104
+ 'Is this a good thing to do?!?!?!',
105
+ 'I don\'t know but we should handle this situation.'])
106
+
107
+ def test_split_and_recombine_text_3(self):
108
+ text_src = os.path.join(os.path.dirname(__file__), '../data/riding_hood.txt')
109
+ with open(text_src, 'r') as f:
110
+ text = f.read()
111
+ self.assertEqual(
112
+ split_and_recombine_text(text),
113
+ [
114
+ 'Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her.',
115
+ 'It suited the girl so extremely well that everybody called her Little Red Riding Hood. One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."',
116
+ 'Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village. As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest.',
117
+ 'He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother." "Does she live far off?" said the wolf "Oh I say,"',
118
+ 'answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village." "Well," said the wolf, "and I\'ll go and see her too. I\'ll go this way and go you that, and we shall see who will be there first."',
119
+ 'The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers.',
120
+ 'It was not long before the wolf arrived at the old woman\'s house. He knocked at the door: tap, tap. "Who\'s there?" "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."',
121
+ 'The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."',
122
+ 'The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten.',
123
+ 'He then shut the door and got into the grandmother\'s bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap. "Who\'s there?"',
124
+ 'Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."',
125
+ 'The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up." Little Red Riding Hood pulled the bobbin, and the door opened.',
126
+ 'The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me." Little Red Riding Hood took off her clothes and got into bed.',
127
+ 'She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!" "All the better to hug you with, my dear." "Grandmother, what big legs you have!" "All the better to run with, my child." "Grandmother, what big ears you have!"',
128
+ '"All the better to hear with, my child." "Grandmother, what big eyes you have!" "All the better to see with, my child." "Grandmother, what big teeth you have got!" "All the better to eat you up with." And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.',
129
+ ]
130
+ )
131
+
132
+ unittest.main()
tortoise/utils/tokenizer.py ADDED
@@ -0,0 +1,194 @@
1
+ import os
2
+ import re
3
+
4
+ import inflect
5
+ import torch
6
+ from tokenizers import Tokenizer
7
+
8
+
9
+ # Regular expression matching whitespace:
10
+ from unidecode import unidecode
11
+
12
+ _whitespace_re = re.compile(r'\s+')
13
+
14
+
15
+ # List of (regular expression, replacement) pairs for abbreviations:
16
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
17
+ ('mrs', 'misess'),
18
+ ('mr', 'mister'),
19
+ ('dr', 'doctor'),
20
+ ('st', 'saint'),
21
+ ('co', 'company'),
22
+ ('jr', 'junior'),
23
+ ('maj', 'major'),
24
+ ('gen', 'general'),
25
+ ('drs', 'doctors'),
26
+ ('rev', 'reverend'),
27
+ ('lt', 'lieutenant'),
28
+ ('hon', 'honorable'),
29
+ ('sgt', 'sergeant'),
30
+ ('capt', 'captain'),
31
+ ('esq', 'esquire'),
32
+ ('ltd', 'limited'),
33
+ ('col', 'colonel'),
34
+ ('ft', 'fort'),
35
+ ]]
36
+
37
+
38
+ def expand_abbreviations(text):
39
+ for regex, replacement in _abbreviations:
40
+ text = re.sub(regex, replacement, text)
41
+ return text
42
+
43
+
44
+ _inflect = inflect.engine()
45
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
46
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
47
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
48
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
49
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
50
+ _number_re = re.compile(r'[0-9]+')
51
+
52
+
53
+ def _remove_commas(m):
54
+ return m.group(1).replace(',', '')
55
+
56
+
57
+ def _expand_decimal_point(m):
58
+ return m.group(1).replace('.', ' point ')
59
+
60
+
61
+ def _expand_dollars(m):
62
+ match = m.group(1)
63
+ parts = match.split('.')
64
+ if len(parts) > 2:
65
+ return match + ' dollars' # Unexpected format
66
+ dollars = int(parts[0]) if parts[0] else 0
67
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
68
+ if dollars and cents:
69
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
70
+ cent_unit = 'cent' if cents == 1 else 'cents'
71
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
72
+ elif dollars:
73
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
74
+ return '%s %s' % (dollars, dollar_unit)
75
+ elif cents:
76
+ cent_unit = 'cent' if cents == 1 else 'cents'
77
+ return '%s %s' % (cents, cent_unit)
78
+ else:
79
+ return 'zero dollars'
80
+
81
+
82
+ def _expand_ordinal(m):
83
+ return _inflect.number_to_words(m.group(0))
84
+
85
+
86
+ def _expand_number(m):
87
+ num = int(m.group(0))
88
+ if num > 1000 and num < 3000:
89
+ if num == 2000:
90
+ return 'two thousand'
91
+ elif num > 2000 and num < 2010:
92
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
93
+ elif num % 100 == 0:
94
+ return _inflect.number_to_words(num // 100) + ' hundred'
95
+ else:
96
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
97
+ else:
98
+ return _inflect.number_to_words(num, andword='')
99
+
100
+
101
+ def normalize_numbers(text):
102
+ text = re.sub(_comma_number_re, _remove_commas, text)
103
+ text = re.sub(_pounds_re, r'\1 pounds', text)
104
+ text = re.sub(_dollars_re, _expand_dollars, text)
105
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
106
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
107
+ text = re.sub(_number_re, _expand_number, text)
108
+ return text
109
+
110
+
111
+ def expand_numbers(text):
112
+ return normalize_numbers(text)
113
+
114
+
115
+ def lowercase(text):
116
+ return text.lower()
117
+
118
+
119
+ def collapse_whitespace(text):
120
+ return re.sub(_whitespace_re, ' ', text)
121
+
122
+
123
+ def convert_to_ascii(text):
124
+ return unidecode(text)
125
+
126
+
127
+ def basic_cleaners(text):
128
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
129
+ text = lowercase(text)
130
+ text = collapse_whitespace(text)
131
+ return text
132
+
133
+
134
+ def transliteration_cleaners(text):
135
+ '''Pipeline for non-English text that transliterate to ASCII.'''
136
+ text = convert_to_ascii(text)
137
+ text = lowercase(text)
138
+ text = collapse_whitespace(text)
139
+ return text
140
+
141
+
142
+ def english_cleaners(text):
143
+ '''Pipeline for English text, including number and abbreviation expansion.'''
144
+ text = convert_to_ascii(text)
145
+ text = lowercase(text)
146
+ text = expand_numbers(text)
147
+ text = expand_abbreviations(text)
148
+ text = collapse_whitespace(text)
149
+ text = text.replace('"', '')
150
+ return text
151
+
152
+
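+ # Illustrative sketch (not part of the uploaded file); output shown approximately, since
+ # inflect's exact wording can vary.
+ # english_cleaners('Dr. Who was born in 1963.')
+ # -> 'doctor who was born in nineteen sixty-three.'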
153
+ def lev_distance(s1, s2):
154
+ if len(s1) > len(s2):
155
+ s1, s2 = s2, s1
156
+
157
+ distances = range(len(s1) + 1)
158
+ for i2, c2 in enumerate(s2):
159
+ distances_ = [i2 + 1]
160
+ for i1, c1 in enumerate(s1):
161
+ if c1 == c2:
162
+ distances_.append(distances[i1])
163
+ else:
164
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
165
+ distances = distances_
166
+ return distances[-1]
167
+
168
+
169
+ DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/tokenizer.json')
170
+
171
+
172
+ class VoiceBpeTokenizer:
173
+ def __init__(self, vocab_file=None, use_basic_cleaners=False):
174
+ self.tokenizer = Tokenizer.from_file(
175
+ DEFAULT_VOCAB_FILE if vocab_file is None else vocab_file
176
+ )
177
+ if use_basic_cleaners:
178
+ self.preprocess_text = basic_cleaners
179
+ else:
180
+ self.preprocess_text = english_cleaners
181
+
182
+ def encode(self, txt):
183
+ txt = self.preprocess_text(txt)
184
+ txt = txt.replace(' ', '[SPACE]')
185
+ return self.tokenizer.encode(txt).ids
186
+
187
+ def decode(self, seq):
188
+ if isinstance(seq, torch.Tensor):
189
+ seq = seq.cpu().numpy()
190
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
191
+ txt = txt.replace('[SPACE]', ' ')
192
+ txt = txt.replace('[STOP]', '')
193
+ txt = txt.replace('[UNK]', '')
194
+ return txt
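+ # Illustrative sketch (not part of the uploaded file): round-tripping text through the
+ # default tokenizer (the bundled ../data/tokenizer.json).
+ # tok = VoiceBpeTokenizer()
+ # ids = tok.encode("Hello world.")   # cleaned, lowercased, spaces mapped to [SPACE]
+ # tok.decode(ids)                    # -> 'hello world.'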
tortoise/utils/typical_sampling.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ from transformers import LogitsWarper
3
+
4
+
5
+ class TypicalLogitsWarper(LogitsWarper):
6
+ def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
7
+ self.filter_value = filter_value
8
+ self.mass = mass
9
+ self.min_tokens_to_keep = min_tokens_to_keep
10
+
11
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
12
+ # calculate entropy
13
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
14
+ p = torch.exp(normalized)
15
+ ent = -(normalized * p).nansum(-1, keepdim=True)
16
+
17
+ # shift and sort
18
+ shifted_scores = torch.abs((-normalized) - ent)
19
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
20
+ sorted_logits = scores.gather(-1, sorted_indices)
21
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
22
+
23
+ # Remove tokens with cumulative mass above the threshold
24
+ last_ind = (cumulative_probs < self.mass).sum(dim=1)
25
+ last_ind[last_ind < 0] = 0
26
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
27
+ if self.min_tokens_to_keep > 1:
28
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
29
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
30
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
31
+
32
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
33
+ return scores
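+ # Illustrative sketch (not part of the uploaded file): the warper can be chained like any
+ # other HuggingFace logits processor; `input_ids` and `logits` are assumed tensors from a
+ # generation loop.
+ # from transformers import LogitsProcessorList
+ # warpers = LogitsProcessorList([TypicalLogitsWarper(mass=0.9)])
+ # filtered_logits = warpers(input_ids, logits)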
tortoise/utils/wav2vec_alignment.py ADDED
@@ -0,0 +1,150 @@
1
+ import re
2
+
3
+ import torch
4
+ import torchaudio
5
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
6
+
7
+ from tortoise.utils.audio import load_audio
8
+
9
+
10
+ def max_alignment(s1, s2, skip_character='~', record=None):
11
+ """
12
+ A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
13
+ used to replace that character.
14
+
15
+ Finally got to use my DP skills!
16
+ """
17
+ if record is None:
18
+ record = {}
19
+ assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
20
+ if len(s1) == 0:
21
+ return ''
22
+ if len(s2) == 0:
23
+ return skip_character * len(s1)
24
+ if s1 == s2:
25
+ return s1
26
+ if s1[0] == s2[0]:
27
+ return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
28
+
29
+ take_s1_key = (len(s1), len(s2) - 1)
30
+ if take_s1_key in record:
31
+ take_s1, take_s1_score = record[take_s1_key]
32
+ else:
33
+ take_s1 = max_alignment(s1, s2[1:], skip_character, record)
34
+ take_s1_score = len(take_s1.replace(skip_character, ''))
35
+ record[take_s1_key] = (take_s1, take_s1_score)
36
+
37
+ take_s2_key = (len(s1) - 1, len(s2))
38
+ if take_s2_key in record:
39
+ take_s2, take_s2_score = record[take_s2_key]
40
+ else:
41
+ take_s2 = max_alignment(s1[1:], s2, skip_character, record)
42
+ take_s2_score = len(take_s2.replace(skip_character, ''))
43
+ record[take_s2_key] = (take_s2, take_s2_score)
44
+
45
+ return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
46
+
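+ # Illustrative sketch (not part of the uploaded file): characters of s1 that cannot be
+ # matched in order against s2 become the skip character.
+ # max_alignment("abc", "ac")  -> "a~c"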
47
+
48
+ class Wav2VecAlignment:
49
+ """
50
+ Uses wav2vec2 to perform audio<->text alignment.
51
+ """
52
+ def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
53
+ self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
54
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
55
+ self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
56
+ self.device = device
57
+
58
+ def align(self, audio, expected_text, audio_sample_rate=24000):
59
+ orig_len = audio.shape[-1]
60
+
61
+ with torch.no_grad():
62
+ self.model = self.model.to(self.device)
63
+ audio = audio.to(self.device)
64
+ audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
65
+ clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
66
+ logits = self.model(clip_norm).logits
67
+ self.model = self.model.cpu()
68
+
69
+ logits = logits[0]
70
+ pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
71
+
72
+ fixed_expectation = max_alignment(expected_text.lower(), pred_string)
73
+ w2v_compression = orig_len // logits.shape[0]
74
+ expected_tokens = self.tokenizer.encode(fixed_expectation)
75
+ expected_chars = list(fixed_expectation)
76
+ if len(expected_tokens) == 1:
77
+ return [0] # The alignment is simple; there is only one token.
78
+ expected_tokens.pop(0) # The first token is a given.
79
+ expected_chars.pop(0)
80
+
81
+ alignments = [0]
82
+ def pop_till_you_win():
83
+ if len(expected_tokens) == 0:
84
+ return None
85
+ popped = expected_tokens.pop(0)
86
+ popped_char = expected_chars.pop(0)
87
+ while popped_char == '~':
88
+ alignments.append(-1)
89
+ if len(expected_tokens) == 0:
90
+ return None
91
+ popped = expected_tokens.pop(0)
92
+ popped_char = expected_chars.pop(0)
93
+ return popped
94
+
95
+ next_expected_token = pop_till_you_win()
96
+ for i, logit in enumerate(logits):
97
+ top = logit.argmax()
98
+ if next_expected_token == top:
99
+ alignments.append(i * w2v_compression)
100
+ if len(expected_tokens) > 0:
101
+ next_expected_token = pop_till_you_win()
102
+ else:
103
+ break
104
+
105
+ pop_till_you_win()
106
+ if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
107
+ torch.save([audio, expected_text], 'alignment_debug.pth')
108
+ assert False, "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to" \
109
+ "your current working directory. Please report this along with the file so it can get fixed."
110
+
111
+ # Now fix up alignments. Anything with -1 should be interpolated.
112
+ alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
113
+ for i in range(len(alignments)):
114
+ if alignments[i] == -1:
115
+ for j in range(i+1, len(alignments)):
116
+ if alignments[j] != -1:
117
+ next_found_token = j
118
+ break
119
+ for j in range(i, next_found_token):
120
+ gap = alignments[next_found_token] - alignments[i-1]
121
+ alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]
122
+
123
+ return alignments[:-1]
124
+
125
+ def redact(self, audio, expected_text, audio_sample_rate=24000):
126
+ if '[' not in expected_text:
127
+ return audio
128
+ splitted = expected_text.split('[')
129
+ fully_split = [splitted[0]]
130
+ for spl in splitted[1:]:
131
+ assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
132
+ fully_split.extend(spl.split(']'))
133
+
134
+ # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
135
+ non_redacted_intervals = []
136
+ last_point = 0
137
+ for i in range(len(fully_split)):
138
+ if i % 2 == 0 and fully_split[i] != "": # Check for empty string fixes index error
139
+ end_interval = max(0, last_point + len(fully_split[i]) - 1)
140
+ non_redacted_intervals.append((last_point, end_interval))
141
+ last_point += len(fully_split[i])
142
+
143
+ bare_text = ''.join(fully_split)
144
+ alignments = self.align(audio, bare_text, audio_sample_rate)
145
+
146
+ output_audio = []
147
+ for nri in non_redacted_intervals:
148
+ start, stop = nri
149
+ output_audio.append(audio[:, alignments[start]:alignments[stop]])
150
+ return torch.cat(output_audio, dim=-1)
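+ # Illustrative sketch (not part of the uploaded file): trimming the bracketed span out of
+ # a generated clip. The path 'clip.wav' is a placeholder; load_audio is imported above.
+ # aligner = Wav2VecAlignment()
+ # clip = load_audio('clip.wav', 24000)                      # (1, samples) at 24 kHz
+ # trimmed = aligner.redact(clip, 'Keep this [but drop this] part.')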
tortoise/voices/angie/1.wav ADDED
Binary file (625 kB). View file
 
tortoise/voices/angie/2.wav ADDED
Binary file (551 kB). View file
 
tortoise/voices/angie/3.wav ADDED
Binary file (827 kB). View file
 
tortoise/voices/applejack/1.wav ADDED
Binary file (328 kB). View file
 
tortoise/voices/applejack/2.wav ADDED
Binary file (331 kB). View file
 
tortoise/voices/applejack/3.wav ADDED
Binary file (327 kB). View file
 
tortoise/voices/atkins/1.wav ADDED
Binary file (397 kB). View file
 
tortoise/voices/atkins/2.wav ADDED
Binary file (309 kB). View file
 
tortoise/voices/daniel/1.wav ADDED
Binary file (618 kB). View file
 
tortoise/voices/daniel/2.wav ADDED
Binary file (329 kB). View file
 
tortoise/voices/daniel/3.wav ADDED
Binary file (371 kB). View file
 
tortoise/voices/daniel/4.wav ADDED
Binary file (275 kB). View file
 
tortoise/voices/daws/1.mp3 ADDED
Binary file (36.1 kB). View file
 
tortoise/voices/daws/2.mp3 ADDED
Binary file (35.3 kB). View file
 
tortoise/voices/daws/3.mp3 ADDED
Binary file (35.6 kB). View file
 
tortoise/voices/deniro/1.wav ADDED
Binary file (407 kB). View file
 
tortoise/voices/deniro/2.wav ADDED
Binary file (610 kB). View file