prakashr7d committed
Commit
b0bf39f
1 Parent(s): 9d1ae0d

written handler

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. __pycache__/constants.cpython-310.pyc +0 -0
  2. __pycache__/constants.cpython-38.pyc +0 -0
  3. __pycache__/handler.cpython-38.pyc +0 -0
  4. __pycache__/serve.cpython-310.pyc +0 -0
  5. __pycache__/serve.cpython-38.pyc +0 -0
  6. __pycache__/server.cpython-38.pyc +0 -0
  7. __pycache__/try.cpython-310.pyc +0 -0
  8. __pycache__/utils.cpython-310.pyc +0 -0
  9. __pycache__/utils.cpython-38.pyc +0 -0
  10. config-model.yaml +12 -0
  11. constants.py +12 -0
  12. handler.py +479 -0
  13. requirements.txt +18 -0
  14. ruth_tts_transformer/.gitignore +1 -0
  15. ruth_tts_transformer/__init__.py +2 -0
  16. ruth_tts_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  17. ruth_tts_transformer/__pycache__/__init__.cpython-37.pyc +0 -0
  18. ruth_tts_transformer/__pycache__/__init__.cpython-38.pyc +0 -0
  19. ruth_tts_transformer/__pycache__/__init__.cpython-39.pyc +0 -0
  20. ruth_tts_transformer/data/latents.pkl +0 -0
  21. ruth_tts_transformer/data/layman.txt +0 -0
  22. ruth_tts_transformer/data/mel_norms.pth +0 -0
  23. ruth_tts_transformer/data/riding_hood.txt +54 -0
  24. ruth_tts_transformer/data/seal_copypasta.txt +1 -0
  25. ruth_tts_transformer/data/tokenizer.json +1 -0
  26. ruth_tts_transformer/models/__init__.py +0 -0
  27. ruth_tts_transformer/models/__pycache__/__init__.cpython-310.pyc +0 -0
  28. ruth_tts_transformer/models/__pycache__/__init__.cpython-38.pyc +0 -0
  29. ruth_tts_transformer/models/__pycache__/arch_util.cpython-310.pyc +0 -0
  30. ruth_tts_transformer/models/__pycache__/arch_util.cpython-38.pyc +0 -0
  31. ruth_tts_transformer/models/__pycache__/autoregressive.cpython-310.pyc +0 -0
  32. ruth_tts_transformer/models/__pycache__/autoregressive.cpython-38.pyc +0 -0
  33. ruth_tts_transformer/models/__pycache__/clvp.cpython-310.pyc +0 -0
  34. ruth_tts_transformer/models/__pycache__/clvp.cpython-38.pyc +0 -0
  35. ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-310.pyc +0 -0
  36. ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-38.pyc +0 -0
  37. ruth_tts_transformer/models/__pycache__/transformer.cpython-310.pyc +0 -0
  38. ruth_tts_transformer/models/__pycache__/transformer.cpython-38.pyc +0 -0
  39. ruth_tts_transformer/models/__pycache__/vocoder.cpython-310.pyc +0 -0
  40. ruth_tts_transformer/models/__pycache__/vocoder.cpython-38.pyc +0 -0
  41. ruth_tts_transformer/models/__pycache__/xtransformers.cpython-310.pyc +0 -0
  42. ruth_tts_transformer/models/__pycache__/xtransformers.cpython-38.pyc +0 -0
  43. ruth_tts_transformer/models/arch_util.py +371 -0
  44. ruth_tts_transformer/models/autoregressive.py +528 -0
  45. ruth_tts_transformer/models/clvp.py +155 -0
  46. ruth_tts_transformer/models/diffusion_decoder.py +349 -0
  47. ruth_tts_transformer/models/transformer.py +221 -0
  48. ruth_tts_transformer/models/vocoder.py +323 -0
  49. ruth_tts_transformer/models/xtransformers.py +1248 -0
  50. ruth_tts_transformer/utils/__init__.py +0 -0
__pycache__/constants.cpython-310.pyc ADDED
Binary file (538 Bytes). View file
 
__pycache__/constants.cpython-38.pyc ADDED
Binary file (530 Bytes). View file
 
__pycache__/handler.cpython-38.pyc ADDED
Binary file (13.2 kB). View file
 
__pycache__/serve.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
__pycache__/serve.cpython-38.pyc ADDED
Binary file (13.6 kB). View file
 
__pycache__/server.cpython-38.pyc ADDED
Binary file (13.1 kB). View file
 
__pycache__/try.cpython-310.pyc ADDED
Binary file (1.17 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.7 kB). View file
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.66 kB). View file
 
config-model.yaml ADDED
@@ -0,0 +1,12 @@
+ gpt:
+   num_autoregressive_samples: 16
+   top_p: 0.8
+   temperature: 0.8
+   length_penalty: 1
+   max_mel_tokens: 500
+   repetition_penalty: 2.0
+   autoregressive_batch_size: 16
+ clvp:
+   k: 1
+ diffusion:
+   diffusion_temperature: 1.0
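For orientation (not part of this commit): handler.py reads this file through get_config_file from utils.py, which is not shown in this view. The sketch below assumes get_config_file is a plain YAML load into a nested dict, keyed by the string constants defined in constants.py.

    import yaml
    from pathlib import Path

    def get_config_file(path: Path) -> dict:
        # Assumed behaviour: parse the YAML into a nested dict such as
        # {"gpt": {"num_autoregressive_samples": 16, ...}, "clvp": {"k": 1}, "diffusion": {...}}
        with path.open() as stream:
            return yaml.safe_load(stream)

    config = get_config_file(Path("config-model.yaml"))
    num_samples = config["gpt"]["num_autoregressive_samples"]  # 16 in this commit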
constants.py ADDED
@@ -0,0 +1,12 @@
+ NUM_AUTOREGRESSIVE_SAMPLES = "num_autoregressive_samples"
+ TOP_P = "top_p"
+ TEMPERATURE = "temperature"
+ LENGTH_PENALTY = "length_penalty"
+ REPETITION_PENALTY = "repetition_penalty"
+ MAX_MEL_TOKENS = "max_mel_tokens"
+ AUTO_REGRESSIVE_BATCH_SIZE = "autoregressive_batch_size"
+ DIFFUSION_TEMPERATURE = "diffusion_temperature"
+ # MODELS
+ GPT = "gpt"
+ CLVP_const = "clvp"
+ DIFFUSION = "diffusion"
handler.py ADDED
@@ -0,0 +1,479 @@
+ import base64
+ import hashlib
+ from io import BytesIO
+ import random
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ from copy import copy
+ from datetime import datetime
+ from fastapi import FastAPI
+ from fastapi.responses import FileResponse
+ from pathlib import Path
+ from pydantic import BaseModel
+
+ from time import time
+ from typing import Any, Dict, List, Text, Tuple
+
+ from constants import (
+     AUTO_REGRESSIVE_BATCH_SIZE,
+     DIFFUSION,
+     DIFFUSION_TEMPERATURE,
+     GPT,
+     LENGTH_PENALTY,
+     MAX_MEL_TOKENS,
+     NUM_AUTOREGRESSIVE_SAMPLES,
+     REPETITION_PENALTY,
+     TEMPERATURE,
+     TOP_P,
+     CLVP_const,
+ )
+ from ruth_tts_transformer.models.autoregressive import UnifiedVoice
+ from ruth_tts_transformer.models.clvp import CLVP
+ from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts
+ from ruth_tts_transformer.models.vocoder import UnivNetGenerator
+ from ruth_tts_transformer.utils.audio import load_voice
+ from ruth_tts_transformer.utils.tokenizer import VoiceBpeTokenizer
+ from ruth_tts_transformer.utils.wav2vec_alignment import Wav2VecAlignment
+ from utils import (
+     MODELS_DIR,
+     get_config_file,
+     get_model_path,
+     load_discrete_vocoder_diffuser,
+ )
+
+ app = FastAPI()
+
+
+ class Item(BaseModel):
+     text: str
+     voice: str
+     seed: int = 3
+
+
+ # Autoregressive stage: UnifiedVoice proposes candidate mel-code sequences for the text.
+ class Gpt:
+     def __init__(
+         self,
+         num_autoregressive_samples: int,
+         top_p: float,
+         temperature: float,
+         length_penalty: int,
+         repetition_penalty: float,
+         max_mel_tokens: int,
+         autoregressive_batch_size: int,
+     ):
+         self.num_autoregressive_samples = num_autoregressive_samples
+         self.top_p = top_p
+         self.temperature = temperature
+         self.length_penalty = length_penalty
+         self.repetition_penalty = repetition_penalty
+         self.max_mel_tokens = max_mel_tokens
+         self.autoregressive_batch_size = autoregressive_batch_size
+         self.gpt = (
+             UnifiedVoice(
+                 max_mel_tokens=604,
+                 max_text_tokens=402,
+                 max_conditioning_inputs=2,
+                 layers=30,
+                 model_dim=1024,
+                 heads=16,
+                 number_text_tokens=255,
+                 start_text_token=255,
+                 checkpointing=False,
+                 train_solo_embeddings=False,
+             )
+             .cpu()
+             .eval()
+         )
+         self.gpt.load_state_dict(
+             torch.load(get_model_path("autoregressive.pth", MODELS_DIR))
+         )
+         self.gpt = self.gpt.to("cuda")
+
+     def __num_batches(self):
+         return self.num_autoregressive_samples // self.autoregressive_batch_size
+
+     @staticmethod
+     def deterministic_state(seed=None):
+         seed = int(time()) if seed is None else seed
+         torch.manual_seed(seed)
+         random.seed(seed)
+         return seed
+
+     def parse(self, auto_conditioning, text_tokens, best_results, seed, k=1):
+         self.deterministic_state(seed=seed)
+         auto_conditioning = copy(auto_conditioning).to("cuda")
+         text_tokens = copy(text_tokens).to("cuda")
+         best_results = copy(best_results).to("cuda")
+         best_latents = self.gpt(
+             auto_conditioning.repeat(k, 1),
+             text_tokens.repeat(k, 1),
+             torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
+             best_results,
+             torch.tensor(
+                 [best_results.shape[-1] * self.gpt.mel_length_compression],
+                 device=text_tokens.device,
+             ),
+             return_latent=True,
+             clip_inputs=False,
+         )
+         # return best_latents.cpu().detach().numpy()
+         return best_latents
+
+     def parse_inference(
+         self, auto_conditioning: torch.Tensor, text_tokens: torch.Tensor, seed
+     ) -> Tuple[List[torch.Tensor], int]:
+         self.deterministic_state(seed=seed)
+         auto_conditioning = copy(auto_conditioning).to("cuda")
+         text_tokens = copy(text_tokens).to("cuda")
+         with torch.no_grad():
+             samples = []
+             num_batches = self.__num_batches()
+             for b in range(num_batches):
+                 codes = self.gpt.inference_speech(
+                     auto_conditioning,
+                     text_tokens,
+                     do_sample=True,
+                     top_p=self.top_p,
+                     temperature=self.temperature,
+                     num_return_sequences=self.autoregressive_batch_size,
+                     length_penalty=self.length_penalty,
+                     repetition_penalty=self.repetition_penalty,
+                     max_generate_length=self.max_mel_tokens,
+                 )
+                 padding_needed = self.max_mel_tokens - codes.shape[1]
+                 codes = F.pad(codes, (0, padding_needed), value=self.gpt.stop_mel_token)
+                 # samples.append(codes.cpu().detach().numpy())
+                 samples.append(codes)
+
+         return samples, self.gpt.stop_mel_token
+
+
+ # CLVP stage: scores the autoregressive candidates against the text and keeps the top K.
+ class clvp:
+     def __init__(self, K):
+
+         self.clvp = (
+             CLVP(
+                 dim_text=768,
+                 dim_speech=768,
+                 dim_latent=768,
+                 num_text_tokens=256,
+                 text_enc_depth=20,
+                 text_seq_len=350,
+                 text_heads=12,
+                 num_speech_tokens=8192,
+                 speech_enc_depth=20,
+                 speech_heads=12,
+                 speech_seq_len=430,
+                 use_xformers=True,
+             )
+             .cpu()
+             .eval()
+         )
+         self.clvp.load_state_dict(torch.load(get_model_path("clvp2.pth", MODELS_DIR)))
+         self.clvp.to("cuda")
+         self.K = K
+
+     @staticmethod
+     def fix_gpt_output(codes, stop_token, complain=True):
+         stop_token_indices = (codes == stop_token).nonzero()
+         if len(stop_token_indices) == 0:
+             if complain:
+                 print(
+                     "No stop tokens found in one of the generated voice clips. This typically means the spoken audio "
+                     "is "
+                     "too long. In some cases, the output will still be good, though. Listen to it and if it is "
+                     "missing words, "
+                     "try breaking up your input text."
+                 )
+             return codes
+         else:
+             codes[stop_token_indices] = 83
+         stm = stop_token_indices.min().item()
+         codes[stm:] = 83
+         if stm - 3 < codes.shape[0]:
+             codes[-3] = 45
+             codes[-2] = 45
+             codes[-1] = 248
+
+         return codes
+
+     def parse(
+         self,
+         text_tokens: torch.Tensor,
+         samples: List[torch.Tensor],
+         stop_mel_token: int,
+         seed: int,
+     ) -> torch.Tensor:
+         self.deterministic_state(seed=seed)
+         clip_results = []
+         text_tokens = copy(text_tokens).to("cuda")
+         samples = [copy(batch).to("cuda") for batch in samples]
+         for batch in samples:
+             for i in range(batch.shape[0]):
+                 batch[i] = self.fix_gpt_output(batch[i], stop_mel_token)
+
+             clvp = self.clvp(
+                 text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
+             )
+             clip_results.append(clvp)
+
+         clip_results = torch.cat(clip_results, dim=0)
+         samples = torch.cat(samples, dim=0)
+         # return samples[torch.topk(clip_results, self.K).indices].cpu().detach().numpy()
+         return samples[torch.topk(clip_results, self.K).indices]
+
+     @staticmethod
+     def deterministic_state(seed=None):
+         seed = int(time()) if seed is None else seed
+         torch.manual_seed(seed)
+         random.seed(seed)
+         return seed
+
+
+ # Diffusion stage: decodes the winning codes to a mel spectrogram and vocodes it with UnivNet.
+ class Diffusion:
+     def __init__(
+         self,
+         diffusion_temperature,
+         diffusion_iterations=30,
+         cond_free=True,
+         cond_free_k=2,
+     ):
+         self.diffusion_temperature = diffusion_temperature
+         self.diffusion = (
+             DiffusionTts(
+                 model_channels=1024,
+                 num_layers=10,
+                 in_channels=100,
+                 out_channels=200,
+                 in_latent_channels=1024,
+                 in_tokens=8193,
+                 dropout=0,
+                 use_fp16=False,
+                 num_heads=16,
+                 layer_drop=0,
+                 unconditioned_percentage=0,
+             )
+             .cpu()
+             .eval()
+         )
+         self.diffusion.load_state_dict(
+             torch.load(get_model_path("diffusion_decoder.pth", MODELS_DIR))
+         )
+         self.diffuser = load_discrete_vocoder_diffuser(
+             desired_diffusion_steps=diffusion_iterations,
+             cond_free=cond_free,
+             cond_free_k=cond_free_k,
+         )
+
+         self.vocoder = UnivNetGenerator().cpu()
+         self.vocoder.load_state_dict(
+             torch.load(
+                 get_model_path("vocoder.pth", MODELS_DIR),
+                 map_location=torch.device("cpu"),
+             )["model_g"]
+         )
+         self.vocoder.eval(inference=True)
+         self.diffusion.to("cuda")
+         self.vocoder.to("cuda")
+         self.aligner = Wav2VecAlignment()
+         # state = self.deterministic_state(seed=0) #Remove after testing
+         self.TACOTRON_MEL_MAX = 2.3143386840820312
+         self.TACOTRON_MEL_MIN = -11.512925148010254
+
+     def denormalize_tacotron_mel(self, norm_mel):
+         return ((norm_mel + 1) / 2) * (
+             self.TACOTRON_MEL_MAX - self.TACOTRON_MEL_MIN
+         ) + self.TACOTRON_MEL_MIN
+
+     def potentially_redact(self, clip, text):
+         return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
+
+     @staticmethod
+     def deterministic_state(seed=None):
+         seed = int(time()) if seed is None else seed
+         torch.manual_seed(seed)
+         random.seed(seed)
+         return seed
+
+     def do_spectrogram_diffusion(
+         self,
+         diffusion_model,
+         diffuser,
+         latents,
+         conditioning_latents,
+         seed,
+         temperature=1,
+         verbose=False,
+     ):
+         self.deterministic_state(seed=seed)
+         with torch.no_grad():
+             output_seq_len = (
+                 latents.shape[1] * 4 * 24000 // 22050
+             )  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+             output_shape = (latents.shape[0], 100, output_seq_len)
+             precomputed_embeddings = diffusion_model.timestep_independent(
+                 latents, conditioning_latents, output_seq_len, False
+             )
+
+             noise = torch.randn(output_shape, device=latents.device) * temperature
+             mel = diffuser.p_sample_loop(
+                 diffusion_model,
+                 output_shape,
+                 noise=noise,
+                 model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
+                 progress=verbose,
+             )
+             return self.denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
+
+     def parse(
+         self, best_results, best_latents, calm_token, diffusion_conditioning, text, seed
+     ):
+         self.deterministic_state(seed=seed)
+         best_results = copy(best_results).to("cuda")
+         best_latents = copy(best_latents).to("cuda")
+         diffusion_conditioning = copy(diffusion_conditioning).to("cuda")
+         wav_candidates = []
+         for b in range(best_results.shape[0]):
+
+             codes = best_results[b].unsqueeze(0)
+             latents = best_latents[b].unsqueeze(0)
+
+             ctokens = 0
+             for k in range(codes.shape[-1]):
+                 if codes[0, k] == calm_token:
+                     ctokens += 1
+                 else:
+                     ctokens = 0
+                 if ctokens > 8:
+                     latents = latents[:, :k]
+                     break
+
+             mel = self.do_spectrogram_diffusion(
+                 self.diffusion,
+                 self.diffuser,
+                 latents,
+                 diffusion_conditioning,
+                 seed,
+                 temperature=self.diffusion_temperature,
+                 verbose=False,
+             )
+             wav = self.vocoder.inference(mel)
+             wav_candidates.append(wav)
+         # wav_candidates = [self.potentially_redact(wav_candidate, text).cpu().detach().numpy() for wav_candidate in
+         #                   wav_candidates]
+         # TODO: Check whether wav candidates should be in numpy
+         wav_candidates = [
+             self.potentially_redact(wav_candidate, text)
+             for wav_candidate in wav_candidates
+         ]
+         return wav_candidates
+
+
+ # Orchestrates the three stages above behind a single inference entry point.
+ class EndpointHandler():
+     def __init__(self, path="config-model.yaml"):
+         config = get_config_file(Path(path))
+         self.calm_token = 83
+         self.tokenizer = VoiceBpeTokenizer()
+         _, conditioning_latent_1 = load_voice("gabby_reading", map_location="cpu")
+         _, conditioning_latent_2 = load_voice("gabby_conversation", map_location="cpu")
+
+         # self.conditioning_latents1 = (latent.cpu().detach().numpy() for latent in conditioning_latent_1)
+         # self.conditioning_latents2 = (latent.cpu().detach().numpy() for latent in conditioning_latent_2)
+         self.conditioning_latents1 = (latent for latent in conditioning_latent_1)
+         self.conditioning_latents2 = (latent for latent in conditioning_latent_2)
+         (
+             self.auto_conditioning1,
+             self.diffusion_conditioning1,
+         ) = self.conditioning_latents1
+         (
+             self.auto_conditioning2,
+             self.diffusion_conditioning2,
+         ) = self.conditioning_latents2
+
+         self.auto_conditioning = None
+         self.diffusion_conditioning = None
+         self.gpt = Gpt(
+             config[GPT][NUM_AUTOREGRESSIVE_SAMPLES],
+             config[GPT][TOP_P],
+             config[GPT][TEMPERATURE],
+             config[GPT][LENGTH_PENALTY],
+             config[GPT][REPETITION_PENALTY],
+             config[GPT][MAX_MEL_TOKENS],
+             config[GPT][AUTO_REGRESSIVE_BATCH_SIZE],
+         )
+         self.clvp = clvp(config[CLVP_const]["k"])
+         self.diffusion = Diffusion(config[DIFFUSION][DIFFUSION_TEMPERATURE])
+         self.calm_token = 83
+         print("orchestrator setup completed")
+
+     @staticmethod
+     def __check_for_long_sentence(text_tokens):
+         assert (
+             text_tokens.shape[-1] < 400
+         ), "Too much text provided. Break the text up into separate segments and re-try inference."
+         # TODO: split the text into several pieces and do the generation and combine them last
+
+     @staticmethod
+     def deterministic_state(seed=None):
+         seed = int(time()) if seed is None else seed
+         torch.manual_seed(seed)
+         random.seed(seed)
+         return seed
+
+     def preprocess_text(self, text: Text):
+         torch_tensor = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
+         return torch_tensor
+
+     def parse(self, res):
+         print("parsing")
+         file_name = hashlib.sha1(str(datetime.now()).encode("UTF-8"))
+         res = [torch.Tensor(copy(split)).squeeze(0).cpu() for split in res]
+         res = [torch.flatten(split) for split in res]
+         merged_audio_tensor = torch.cat(res).reshape(1, -1)
+         torchaudio.save(f"./{file_name.hexdigest()}.wav", merged_audio_tensor, 24000)
+         # torchaudio.save(f"./{file_name.hexdigest()}.wav", torch.Tensor(copy(res)).squeeze(0).cpu(), 24000)
+         return file_name.hexdigest()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         voice = data["voice"]
+         text = data["text"]
+         seed = data["seed"]
+         if voice == "gabby_reading":
+             self.auto_conditioning = self.auto_conditioning1
+             self.diffusion_conditioning = self.diffusion_conditioning1
+         elif voice == "gabby_conversation":
+             self.auto_conditioning = self.auto_conditioning2
+             self.diffusion_conditioning = self.diffusion_conditioning2
+
+         self.deterministic_state(seed=seed)
+         text_tokens = self.preprocess_text(
+             text
+         )  # preprocess the in-coming text into tokens
+         self.__check_for_long_sentence(text_tokens)
+         # text_tokens = text_tokens.cpu().detach().numpy()
+         samples, stop_mel_token = self.gpt.parse_inference(
+             self.auto_conditioning, text_tokens, seed
+         )
+         best_sample = self.clvp.parse(text_tokens, samples, stop_mel_token, seed)
+         best_latent = self.gpt.parse(
+             self.auto_conditioning, text_tokens, best_sample, seed
+         )
+         wav_candidates = self.diffusion.parse(
+             best_sample,
+             best_latent,
+             self.calm_token,
+             self.diffusion_conditioning,
+             text,
+             seed,
+         )
+         if len(wav_candidates) > 1:
+             res = wav_candidates
+         else:
+             res = wav_candidates[0]
+
+         # parse(res) writes ./{sha1}.wav and returns the hash; read that file back
+         # so the response carries the rendered audio.
+         file_name = self.parse(res)
+         with open(f"./{file_name}.wav", "rb") as wav_file:
+             audio_str = base64.b64encode(wav_file.read())
+
+         # postprocess the prediction
+         return {"audio": audio_str.decode()}
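A minimal invocation sketch (not part of this commit) showing the request and response shape the handler expects; it assumes the script runs from the repository root, the model weights are available under MODELS_DIR, and a CUDA device is present.

    import base64

    from handler import EndpointHandler

    handler = EndpointHandler(path="config-model.yaml")
    response = handler({
        "text": "Once upon a time there lived a little country girl.",
        "voice": "gabby_reading",  # or "gabby_conversation"
        "seed": 3,
    })

    # The handler returns the rendered speech as a base64-encoded 24 kHz wav.
    with open("sample.wav", "wb") as out:
        out.write(base64.b64decode(response["audio"]))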
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ tqdm~=4.64.0
+ rotary_embedding_torch
+ transformers~=4.21.2
+ tokenizers~=0.12.1
+ inflect~=6.0.0
+ progressbar~=2.5
+ einops~=0.4.1
+ unidecode~=1.3.4
+ scipy~=1.9.1
+ librosa~=0.9.2
+ numba==0.48.0
+ ffmpeg
+ fastapi~=0.81.0
+ ray[serve]~=2.0.0
+ PyYAML~=6.0
+ starlette~=0.19.1
+ numpy~=1.23.2
+ setuptools~=60.2.0
ruth_tts_transformer/.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
ruth_tts_transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
+ VERSION = "0.0.27"
+
ruth_tts_transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
ruth_tts_transformer/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (178 Bytes). View file
 
ruth_tts_transformer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (156 Bytes). View file
 
ruth_tts_transformer/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (182 Bytes). View file
 
ruth_tts_transformer/data/latents.pkl ADDED
Binary file (510 kB). View file
 
ruth_tts_transformer/data/layman.txt ADDED
File without changes
ruth_tts_transformer/data/mel_norms.pth ADDED
Binary file (1.07 kB). View file
 
ruth_tts_transformer/data/riding_hood.txt ADDED
@@ -0,0 +1,54 @@
+ Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood.
+ One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."
+
+ Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village.
+
+ As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother."
+
+ "Does she live far off?" said the wolf
+
+ "Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village."
+
+ "Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first."
+
+ The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap.
+
+ "Who's there?"
+
+ "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."
+
+ The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."
+
+ The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap.
+
+ "Who's there?"
+
+ Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."
+
+ The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up."
+
+ Little Red Riding Hood pulled the bobbin, and the door opened.
+
+ The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me."
+
+ Little Red Riding Hood took off her clothes and got into bed. She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!"
+
+ "All the better to hug you with, my dear."
+
+ "Grandmother, what big legs you have!"
+
+ "All the better to run with, my child."
+
+ "Grandmother, what big ears you have!"
+
+ "All the better to hear with, my child."
+
+ "Grandmother, what big eyes you have!"
+
+ "All the better to see with, my child."
+
+ "Grandmother, what big teeth you have got!"
+
+ "All the better to eat you up with."
+
+ And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.
ruth_tts_transformer/data/seal_copypasta.txt ADDED
@@ -0,0 +1 @@
+ What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al kayda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire U S armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the U S A and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.
ruth_tts_transformer/data/tokenizer.json ADDED
@@ -0,0 +1 @@
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
ruth_tts_transformer/models/__init__.py ADDED
File without changes
ruth_tts_transformer/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes). View file
 
ruth_tts_transformer/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (145 Bytes). View file
 
ruth_tts_transformer/models/__pycache__/arch_util.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
ruth_tts_transformer/models/__pycache__/arch_util.cpython-38.pyc ADDED
Binary file (11.4 kB). View file
 
ruth_tts_transformer/models/__pycache__/autoregressive.cpython-310.pyc ADDED
Binary file (17.8 kB). View file
 
ruth_tts_transformer/models/__pycache__/autoregressive.cpython-38.pyc ADDED
Binary file (17.8 kB). View file
 
ruth_tts_transformer/models/__pycache__/clvp.cpython-310.pyc ADDED
Binary file (4.13 kB). View file
 
ruth_tts_transformer/models/__pycache__/clvp.cpython-38.pyc ADDED
Binary file (4.11 kB). View file
 
ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
ruth_tts_transformer/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (7.85 kB). View file
 
ruth_tts_transformer/models/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (7.84 kB). View file
 
ruth_tts_transformer/models/__pycache__/vocoder.cpython-310.pyc ADDED
Binary file (9.17 kB). View file
 
ruth_tts_transformer/models/__pycache__/vocoder.cpython-38.pyc ADDED
Binary file (9.13 kB). View file
 
ruth_tts_transformer/models/__pycache__/xtransformers.cpython-310.pyc ADDED
Binary file (34.7 kB). View file
 
ruth_tts_transformer/models/__pycache__/xtransformers.cpython-38.pyc ADDED
Binary file (35.3 kB). View file
 
ruth_tts_transformer/models/arch_util.py ADDED
@@ -0,0 +1,371 @@
1
+ import os
2
+ import functools
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from ruth_tts_transformer.models.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
10
+
11
+
12
+ def zero_module(module):
13
+ """
14
+ Zero out the parameters of a module and return it.
15
+ """
16
+ for p in module.parameters():
17
+ p.detach().zero_()
18
+ return module
19
+
20
+
21
+ class GroupNorm32(nn.GroupNorm):
22
+ def forward(self, x):
23
+ return super().forward(x.float()).type(x.dtype)
24
+
25
+
26
+ def normalization(channels):
27
+ """
28
+ Make a standard normalization layer.
29
+
30
+ :param channels: number of input channels.
31
+ :return: an nn.Module for normalization.
32
+ """
33
+ groups = 32
34
+ if channels <= 16:
35
+ groups = 8
36
+ elif channels <= 64:
37
+ groups = 16
38
+ while channels % groups != 0:
39
+ groups = int(groups / 2)
40
+ assert groups > 2
41
+ return GroupNorm32(groups, channels)
42
+
43
+
44
+ class QKVAttentionLegacy(nn.Module):
45
+ """
46
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
47
+ """
48
+
49
+ def __init__(self, n_heads):
50
+ super().__init__()
51
+ self.n_heads = n_heads
52
+
53
+ def forward(self, qkv, mask=None, rel_pos=None):
54
+ """
55
+ Apply QKV attention.
56
+
57
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
58
+ :return: an [N x (H * C) x T] tensor after attention.
59
+ """
60
+ bs, width, length = qkv.shape
61
+ assert width % (3 * self.n_heads) == 0
62
+ ch = width // (3 * self.n_heads)
63
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
64
+ scale = 1 / math.sqrt(math.sqrt(ch))
65
+ weight = torch.einsum(
66
+ "bct,bcs->bts", q * scale, k * scale
67
+ ) # More stable with f16 than dividing afterwards
68
+ if rel_pos is not None:
69
+ weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ if mask is not None:
72
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
73
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
74
+ weight = weight * mask
75
+ a = torch.einsum("bts,bcs->bct", weight, v)
76
+
77
+ return a.reshape(bs, -1, length)
78
+
79
+
80
+ class AttentionBlock(nn.Module):
81
+ """
82
+ An attention block that allows spatial positions to attend to each other.
83
+
84
+ Originally ported from here, but adapted to the N-d case.
85
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ channels,
91
+ num_heads=1,
92
+ num_head_channels=-1,
93
+ do_checkpoint=True,
94
+ relative_pos_embeddings=False,
95
+ ):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.do_checkpoint = do_checkpoint
99
+ if num_head_channels == -1:
100
+ self.num_heads = num_heads
101
+ else:
102
+ assert (
103
+ channels % num_head_channels == 0
104
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
105
+ self.num_heads = channels // num_head_channels
106
+ self.norm = normalization(channels)
107
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
108
+ # split heads before split qkv
109
+ self.attention = QKVAttentionLegacy(self.num_heads)
110
+
111
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
112
+ if relative_pos_embeddings:
113
+ self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
114
+ else:
115
+ self.relative_pos_embeddings = None
116
+
117
+ def forward(self, x, mask=None):
118
+ b, c, *spatial = x.shape
119
+ x = x.reshape(b, c, -1)
120
+ qkv = self.qkv(self.norm(x))
121
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
122
+ h = self.proj_out(h)
123
+ return (x + h).reshape(b, c, *spatial)
124
+
125
+
126
+ class Upsample(nn.Module):
127
+ """
128
+ An upsampling layer with an optional convolution.
129
+
130
+ :param channels: channels in the inputs and outputs.
131
+ :param use_conv: a bool determining if a convolution is applied.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv, out_channels=None, factor=4):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.factor = factor
140
+ if use_conv:
141
+ ksize = 5
142
+ pad = 2
143
+ self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
144
+
145
+ def forward(self, x):
146
+ assert x.shape[1] == self.channels
147
+ x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
148
+ if self.use_conv:
149
+ x = self.conv(x)
150
+ return x
151
+
152
+
153
+ class Downsample(nn.Module):
154
+ """
155
+ A downsampling layer with an optional convolution.
156
+
157
+ :param channels: channels in the inputs and outputs.
158
+ :param use_conv: a bool determining if a convolution is applied.
159
+ """
160
+
161
+ def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
162
+ super().__init__()
163
+ self.channels = channels
164
+ self.out_channels = out_channels or channels
165
+ self.use_conv = use_conv
166
+
167
+ stride = factor
168
+ if use_conv:
169
+ self.op = nn.Conv1d(
170
+ self.channels, self.out_channels, ksize, stride=stride, padding=pad
171
+ )
172
+ else:
173
+ assert self.channels == self.out_channels
174
+ self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
175
+
176
+ def forward(self, x):
177
+ assert x.shape[1] == self.channels
178
+ return self.op(x)
179
+
180
+
181
+ class ResBlock(nn.Module):
182
+ def __init__(
183
+ self,
184
+ channels,
185
+ dropout,
186
+ out_channels=None,
187
+ use_conv=False,
188
+ use_scale_shift_norm=False,
189
+ up=False,
190
+ down=False,
191
+ kernel_size=3,
192
+ ):
193
+ super().__init__()
194
+ self.channels = channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_scale_shift_norm = use_scale_shift_norm
199
+ padding = 1 if kernel_size == 3 else 2
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False)
211
+ self.x_upd = Upsample(channels, False)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False)
214
+ self.x_upd = Downsample(channels, False)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.out_layers = nn.Sequential(
219
+ normalization(self.out_channels),
220
+ nn.SiLU(),
221
+ nn.Dropout(p=dropout),
222
+ zero_module(
223
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
224
+ ),
225
+ )
226
+
227
+ if self.out_channels == channels:
228
+ self.skip_connection = nn.Identity()
229
+ elif use_conv:
230
+ self.skip_connection = nn.Conv1d(
231
+ channels, self.out_channels, kernel_size, padding=padding
232
+ )
233
+ else:
234
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
235
+
236
+ def forward(self, x):
237
+ if self.updown:
238
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
239
+ h = in_rest(x)
240
+ h = self.h_upd(h)
241
+ x = self.x_upd(x)
242
+ h = in_conv(h)
243
+ else:
244
+ h = self.in_layers(x)
245
+ h = self.out_layers(h)
246
+ return self.skip_connection(x) + h
247
+
248
+
249
+ class AudioMiniEncoder(nn.Module):
250
+ def __init__(self,
251
+ spec_dim,
252
+ embedding_dim,
253
+ base_channels=128,
254
+ depth=2,
255
+ resnet_blocks=2,
256
+ attn_blocks=4,
257
+ num_attn_heads=4,
258
+ dropout=0,
259
+ downsample_factor=2,
260
+ kernel_size=3):
261
+ super().__init__()
262
+ self.init = nn.Sequential(
263
+ nn.Conv1d(spec_dim, base_channels, 3, padding=1)
264
+ )
265
+ ch = base_channels
266
+ res = []
267
+ for l in range(depth):
268
+ for r in range(resnet_blocks):
269
+ res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
270
+ res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
271
+ ch *= 2
272
+ self.res = nn.Sequential(*res)
273
+ self.final = nn.Sequential(
274
+ normalization(ch),
275
+ nn.SiLU(),
276
+ nn.Conv1d(ch, embedding_dim, 1)
277
+ )
278
+ attn = []
279
+ for a in range(attn_blocks):
280
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
281
+ self.attn = nn.Sequential(*attn)
282
+ self.dim = embedding_dim
283
+
284
+ def forward(self, x):
285
+ h = self.init(x)
286
+ h = self.res(h)
287
+ h = self.final(h)
288
+ h = self.attn(h)
289
+ return h[:, :, 0]
290
+
291
+
292
+ DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
293
+
294
+
295
+ class TorchMelSpectrogram(nn.Module):
296
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
297
+ sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
298
+ super().__init__()
299
+ # These are the default tacotron values for the MEL spectrogram.
300
+ self.filter_length = filter_length
301
+ self.hop_length = hop_length
302
+ self.win_length = win_length
303
+ self.n_mel_channels = n_mel_channels
304
+ self.mel_fmin = mel_fmin
305
+ self.mel_fmax = mel_fmax
306
+ self.sampling_rate = sampling_rate
307
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
308
+ win_length=self.win_length, power=2, normalized=normalize,
309
+ sample_rate=self.sampling_rate, f_min=self.mel_fmin,
310
+ f_max=self.mel_fmax, n_mels=self.n_mel_channels,
311
+ norm="slaney")
312
+ self.mel_norm_file = mel_norm_file
313
+ if self.mel_norm_file is not None:
314
+ self.mel_norms = torch.load(self.mel_norm_file)
315
+ else:
316
+ self.mel_norms = None
317
+
318
+ def forward(self, inp):
319
+ if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
320
+ inp = inp.squeeze(1)
321
+ assert len(inp.shape) == 2
322
+ self.mel_stft = self.mel_stft.to(inp.device)
323
+ mel = self.mel_stft(inp)
324
+ # Perform dynamic range compression
325
+ mel = torch.log(torch.clamp(mel, min=1e-5))
326
+ if self.mel_norms is not None:
327
+ self.mel_norms = self.mel_norms.to(mel.device)
328
+ mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
329
+ return mel
330
+
331
+
332
+ class CheckpointedLayer(nn.Module):
333
+ """
334
+ Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
335
+ checkpoint for all other args.
336
+ """
337
+ def __init__(self, wrap):
338
+ super().__init__()
339
+ self.wrap = wrap
340
+
341
+ def forward(self, x, *args, **kwargs):
342
+ for k, v in kwargs.items():
343
+ assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
344
+ partial = functools.partial(self.wrap, **kwargs)
345
+ return partial(x, *args)
346
+
347
+
348
+ class CheckpointedXTransformerEncoder(nn.Module):
349
+ """
350
+ Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
351
+ to channels-last that XTransformer expects.
352
+ """
353
+ def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
354
+ super().__init__()
355
+ self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
356
+ self.needs_permute = needs_permute
357
+ self.exit_permute = exit_permute
358
+
359
+ if not checkpoint:
360
+ return
361
+ for i in range(len(self.transformer.attn_layers.layers)):
362
+ n, b, r = self.transformer.attn_layers.layers[i]
363
+ self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
364
+
365
+ def forward(self, x, **kwargs):
366
+ if self.needs_permute:
367
+ x = x.permute(0,2,1)
368
+ h = self.transformer(x, **kwargs)
369
+ if self.exit_permute:
370
+ h = h.permute(0,2,1)
371
+ return h
ruth_tts_transformer/models/autoregressive.py ADDED
@@ -0,0 +1,528 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
9
+ from ruth_tts_transformer.models.arch_util import AttentionBlock
10
+ from ruth_tts_transformer.utils.typical_sampling import TypicalLogitsWarper
11
+
12
+
13
+ def null_position_embeddings(range, dim):
14
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
15
+
16
+
17
+ class ResidualConvolutionBlock(nn.Module):
18
+
19
+ def __init__(self, chan):
20
+ super().__init__()
21
+ self.neural_network = nn.Sequential(
22
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
23
+ nn.GroupNorm(chan // 8, chan),
24
+ nn.ReLU(),
25
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
26
+ nn.GroupNorm(chan // 8, chan)
27
+ )
28
+
29
+ def forward(self, x):
30
+ return F.relu(self.neural_network(x) + x)
31
+
32
+
33
+ class GPT2InferenceModel(GPT2PreTrainedModel):
34
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
35
+ super().__init__(config)
36
+ self.transformer = gpt
37
+ self.text_pos_embedding = text_pos_emb
38
+ self.embeddings = embeddings
39
+ self.lm_head = nn.Sequential(norm, linear)
40
+
41
+ # Model parallel
42
+ self.model_parallel = False
43
+ self.device_map = None
44
+ self.cached_mel_emb = None
45
+
46
+ def parallelize(self, device_map=None):
47
+ self.device_map = (
48
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
49
+ if device_map is None
50
+ else device_map
51
+ )
52
+ assert_device_map(self.device_map, len(self.transformer.h))
53
+ self.transformer.parallelize(self.device_map)
54
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
55
+ self.model_parallel = True
56
+
57
+ def deparallelize(self):
58
+ self.transformer.deparallelize()
59
+ self.transformer = self.transformer.to("cpu")
60
+ self.lm_head = self.lm_head.to("cpu")
61
+ self.model_parallel = False
62
+ torch.cuda.empty_cache()
63
+
64
+ def get_output_embeddings(self):
65
+ return self.lm_head
66
+
67
+ def set_output_embeddings(self, new_embeddings):
68
+ self.lm_head = new_embeddings
69
+
70
+ def store_mel_emb(self, mel_emb):
71
+ self.cached_mel_emb = mel_emb
72
+
73
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
74
+
75
+ token_type_ids = kwargs.get("token_type_ids", None)
76
+ # only last token for inputs_ids if past is defined in kwargs
77
+ if past:
78
+ input_ids = input_ids[:, -1].unsqueeze(-1)
79
+ if token_type_ids is not None:
80
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
81
+
82
+ attention_mask = kwargs.get("attention_mask", None)
83
+ position_ids = kwargs.get("position_ids", None)
84
+
85
+ if attention_mask is not None and position_ids is None:
86
+ # create position_ids on the fly for batch generation
87
+ position_ids = attention_mask.long().cumsum(-1) - 1
88
+ position_ids.masked_fill_(attention_mask == 0, 1)
89
+ if past:
90
+ position_ids = position_ids[:, -1].unsqueeze(-1)
91
+ else:
92
+ position_ids = None
93
+ return {
94
+ "input_ids": input_ids,
95
+ "past_key_values": past,
96
+ "use_cache": kwargs.get("use_cache"),
97
+ "position_ids": position_ids,
98
+ "attention_mask": attention_mask,
99
+ "token_type_ids": token_type_ids,
100
+ }
101
+
102
+ def forward(
103
+ self,
104
+ input_ids=None,
105
+ past_key_values=None,
106
+ attention_mask=None,
107
+ token_type_ids=None,
108
+ position_ids=None,
109
+ head_mask=None,
110
+ inputs_embeds=None,
111
+ encoder_hidden_states=None,
112
+ encoder_attention_mask=None,
113
+ labels=None,
114
+ use_cache=None,
115
+ output_attentions=None,
116
+ output_hidden_states=None,
117
+ return_dict=None,
118
+ ):
119
+ assert self.cached_mel_emb is not None
120
+ assert inputs_embeds is None # Not supported by this inference model.
121
+ assert labels is None # Training not supported by this inference model.
122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
123
+
124
+ # Create embedding
125
+ mel_len = self.cached_mel_emb.shape[1]
126
+ if input_ids.shape[1] != 1:
127
+ text_inputs = input_ids[:, mel_len:]
128
+ text_emb = self.embeddings(text_inputs)
129
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
130
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
131
+ mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
132
+ else:
133
+ mel_emb = self.cached_mel_emb
134
+ emb = torch.cat([mel_emb, text_emb], dim=1)
135
+ else:
136
+ emb = self.embeddings(input_ids)
137
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - mel_len,
138
+ attention_mask.device)
139
+
140
+ transformer_outputs = self.transformer(
141
+ inputs_embeds=emb,
142
+ past_key_values=past_key_values,
143
+ attention_mask=attention_mask,
144
+ token_type_ids=token_type_ids,
145
+ position_ids=position_ids,
146
+ head_mask=head_mask,
147
+ encoder_hidden_states=encoder_hidden_states,
148
+ encoder_attention_mask=encoder_attention_mask,
149
+ use_cache=use_cache,
150
+ output_attentions=output_attentions,
151
+ output_hidden_states=output_hidden_states,
152
+ return_dict=return_dict,
153
+ )
154
+ hidden_states = transformer_outputs[0]
155
+
156
+ # Set device for model parallelism
157
+ if self.model_parallel:
158
+ torch.cuda.set_device(self.transformer.first_device)
159
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
160
+
161
+ lm_logits = self.lm_head(hidden_states)
162
+
163
+ if not return_dict:
164
+ return (lm_logits,) + transformer_outputs[1:]
165
+
166
+ return CausalLMOutputWithCrossAttentions(
167
+ loss=None,
168
+ logits=lm_logits,
169
+ past_key_values=transformer_outputs.past_key_values,
170
+ hidden_states=transformer_outputs.hidden_states,
171
+ attentions=transformer_outputs.attentions,
172
+ cross_attentions=transformer_outputs.cross_attentions,
173
+ )
174
+
175
+ @staticmethod
176
+ def _reorder_cache(past, beam_idx):
177
+ """
178
+ This function is used to re-order the :obj:`past_key_values` cache if
179
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
180
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
181
+ """
182
+ return tuple(
183
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
184
+ for layer_past in past
185
+ )
186
+
187
+
188
+ class ConditioningEncoder(nn.Module):
189
+ def __init__(self,
190
+ spec_dim,
191
+ embedding_dim,
192
+ attn_blocks=6,
193
+ num_attn_heads=4,
194
+ do_checkpointing=False,
195
+ mean=False):
196
+ super().__init__()
197
+ attn = []
198
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
199
+ for a in range(attn_blocks):
200
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
201
+ self.attn = nn.Sequential(*attn)
202
+ self.dim = embedding_dim
203
+ self.do_checkpointing = do_checkpointing
204
+ self.mean = mean
205
+
206
+ def forward(self, x):
207
+ h = self.init(x)
208
+ h = self.attn(h)
209
+ if self.mean:
210
+ return h.mean(dim=2)
211
+ else:
212
+ return h[:, :, 0]
213
+
214
+
215
+ class LearnedPositionEmbeddings(nn.Module):
216
+ def __init__(self, seq_len, model_dim, init=.02):
217
+ super().__init__()
218
+ self.emb = nn.Embedding(seq_len, model_dim)
219
+ # Initializing this way is standard for GPT-2
220
+ self.emb.weight.data.normal_(mean=0.0, std=init)
221
+
222
+ def forward(self, x):
223
+ sl = x.shape[1]
224
+ return self.emb(torch.arange(0, sl, device=x.device))
225
+
226
+ def get_fixed_embedding(self, ind, dev):
227
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
228
+
229
+
230
+ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
231
+ """
232
+ GPT-2 implemented by the HuggingFace library.
233
+ """
234
+ from transformers import GPT2Config, GPT2Model
235
+ gpt_config = GPT2Config(vocab_size=256, # Unused.
236
+ n_positions=max_mel_seq_len + max_text_seq_len,
237
+ n_ctx=max_mel_seq_len + max_text_seq_len,
238
+ n_embd=model_dim,
239
+ n_layer=layers,
240
+ n_head=heads,
241
+ gradient_checkpointing=checkpointing,
242
+ use_cache=not checkpointing)
243
+ gpt = GPT2Model(gpt_config)
244
+ # Override the built in positional embeddings
245
+ del gpt.wpe
246
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
247
+ # Built-in token embeddings are unused.
248
+ del gpt.wte
249
+ return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len,
250
+ model_dim), \
251
+ None, None
252
+
253
+
254
+ class MelEncoder(nn.Module):
255
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
256
+ super().__init__()
257
+ self.channels = channels
258
+ self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
259
+ nn.Sequential(*[ResidualConvolutionBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
260
+ nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
261
+ nn.GroupNorm(channels // 16, channels // 2),
262
+ nn.ReLU(),
263
+ nn.Sequential(*[ResidualConvolutionBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
264
+ nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
265
+ nn.GroupNorm(channels // 8, channels),
266
+ nn.ReLU(),
267
+ nn.Sequential(*[ResidualConvolutionBlock(channels) for _ in range(resblocks_per_reduction)]),
268
+ )
269
+ self.reduction = 4
270
+
271
+ def forward(self, x):
272
+ for e in self.encoder:
273
+ x = e(x)
274
+ return x.permute(0, 2, 1)
275
+
276
+
277
+ class UnifiedVoice(nn.Module):
278
+ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250,
279
+ max_conditioning_inputs=1,
280
+ mel_length_compression=1024, number_text_tokens=256,
281
+ start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
282
+ stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
283
+ checkpointing=True, types=1):
284
+ """
285
+ Args:
286
+ layers: Number of layers in transformer stack.
287
+ model_dim: Operating dimensions of the transformer
288
+ heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim//64.
289
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
290
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
291
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
292
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
293
+ number_text_tokens:
294
+ start_text_token:
295
+ stop_text_token:
296
+ number_mel_codes:
297
+ start_mel_token:
298
+ stop_mel_token:
299
+ train_solo_embeddings:
300
+ use_mel_codes_as_input:
301
+ checkpointing:
302
+ """
303
+ super().__init__()
304
+
305
+ self.number_text_tokens = number_text_tokens
306
+ self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
307
+ self.stop_text_token = 0
308
+ self.number_mel_codes = number_mel_codes
309
+ self.start_mel_token = start_mel_token
310
+ self.stop_mel_token = stop_mel_token
311
+ self.layers = layers
312
+ self.heads = heads
313
+ self.max_mel_tokens = max_mel_tokens
314
+ self.max_text_tokens = max_text_tokens
315
+ self.model_dim = model_dim
316
+ self.max_conditioning_inputs = max_conditioning_inputs
317
+ self.mel_length_compression = mel_length_compression
318
+ self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
319
+ self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
320
+ if use_mel_codes_as_input:
321
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
322
+ else:
323
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
324
+ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
325
+ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
326
+ self.max_text_tokens + 2, checkpointing)
327
+ if train_solo_embeddings:
328
+ self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
329
+ self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
330
+ else:
331
+ self.mel_solo_embedding = 0
332
+ self.text_solo_embedding = 0
333
+
334
+ self.final_norm = nn.LayerNorm(model_dim)
335
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
336
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
337
+
338
+ # Initialize the embeddings per the GPT-2 scheme
339
+ embeddings = [self.text_embedding]
340
+ if use_mel_codes_as_input:
341
+ embeddings.append(self.mel_embedding)
342
+ for module in embeddings:
343
+ module.weight.data.normal_(mean=0.0, std=.02)
344
+
345
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
346
+ inp = F.pad(input, (1, 0), value=start_token)
347
+ tar = F.pad(input, (0, 1), value=stop_token)
348
+ return inp, tar
349
+
350
+ def set_mel_padding(self, mel_input_tokens, wav_lengths):
351
+ """
352
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
353
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
354
+ preformatting to create a working TTS model.
355
+ """
356
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
357
+ mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
358
+ for b in range(len(mel_lengths)):
359
+ actual_end = mel_lengths[
360
+ b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
361
+ if actual_end < mel_input_tokens.shape[-1]:
362
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
363
+ return mel_input_tokens
364
+
365
+ def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None,
366
+ get_attns=False, return_latent=False):
367
+ if second_inputs is not None:
368
+ emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
369
+ else:
370
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
371
+
372
+ gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
373
+ if get_attns:
374
+ return gpt_out.attentions
375
+
376
+ enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
377
+ enc = self.final_norm(enc)
378
+
379
+ if return_latent:
380
+ return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + first_inputs.shape[
381
+ 1]], enc[:, -second_inputs.shape[1]:]
382
+
383
+ first_logits = enc[:, :first_inputs.shape[1]]
384
+ first_logits = first_head(first_logits)
385
+ first_logits = first_logits.permute(0, 2, 1)
386
+ if second_inputs is not None:
387
+ second_logits = enc[:, -second_inputs.shape[1]:]
388
+ second_logits = second_head(second_logits)
389
+ second_logits = second_logits.permute(0, 2, 1)
390
+ return first_logits, second_logits
391
+ else:
392
+ return first_logits
393
+
394
+ def get_conditioning(self, speech_conditioning_input):
395
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
396
+ speech_conditioning_input.shape) == 3 else speech_conditioning_input
397
+ conds = []
398
+ for j in range(speech_conditioning_input.shape[1]):
399
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
400
+ conds = torch.stack(conds, dim=1)
401
+ conds = conds.mean(dim=1)
402
+ return conds
403
+
404
+ def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None,
405
+ text_first=True, raw_mels=None, return_attentions=False,
406
+ return_latent=False, clip_inputs=True):
407
+ """
408
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
409
+ (actuated by `text_first`).
410
+
411
+ speech_conditioning_latent: conditioning latent float tensor, (b, model_dim)
412
+ text_inputs: long tensor, (b,t)
413
+ text_lengths: long tensor, (b,)
414
+ mel_codes: long tensor, (b,m)
415
+ wav_lengths: long tensor, (b,)
416
+ raw_mels: MEL float tensor (b,80,s)
417
+
418
+ If return_attentions is specified, the transformer's attention maps are returned instead of logits or losses.
419
+ If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
420
+ If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
421
+ """
422
+ # Types are expressed by expanding the text embedding space.
423
+ if types is not None:
424
+ text_inputs = text_inputs * (1 + types).unsqueeze(-1)
425
+
426
+ if clip_inputs:
427
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
428
+ # chopping the inputs by the maximum actual length.
429
+ max_text_len = text_lengths.max()
430
+ text_inputs = text_inputs[:, :max_text_len]
431
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
432
+ mel_codes = mel_codes[:, :max_mel_len]
433
+ if raw_mels is not None:
434
+ raw_mels = raw_mels[:, :, :max_mel_len * 4]
435
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
436
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
437
+ mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
438
+
439
+ conds = speech_conditioning_latent.unsqueeze(1)
440
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token,
441
+ self.stop_text_token)
442
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
443
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token,
444
+ self.stop_mel_token)
445
+ if raw_mels is not None:
446
+ mel_inp = F.pad(raw_mels, (0, 8))
447
+ else:
448
+ mel_inp = mel_codes
449
+ mel_emb = self.mel_embedding(mel_inp)
450
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
451
+
452
+ if text_first:
453
+ text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
454
+ get_attns=return_attentions, return_latent=return_latent)
455
+ if return_latent:
456
+ return mel_logits[:,
457
+ :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
458
+ else:
459
+ mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head,
460
+ get_attns=return_attentions, return_latent=return_latent)
461
+ if return_latent:
462
+ return text_logits[:,
463
+ :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
464
+
465
+ if return_attentions:
466
+ return mel_logits
467
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
468
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
469
+ return loss_text.mean(), loss_mel.mean(), mel_logits
470
+
471
+ def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
472
+ max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
473
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
474
+ if not hasattr(self, 'inference_model'):
475
+ # TODO: Decouple gpt_config from this inference model.
476
+ gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
477
+ n_positions=seq_length,
478
+ n_ctx=seq_length,
479
+ n_embd=self.model_dim,
480
+ n_layer=self.layers,
481
+ n_head=self.heads,
482
+ gradient_checkpointing=False,
483
+ use_cache=True)
484
+ self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding,
485
+ self.final_norm, self.mel_head)
486
+ self.gpt.wte = self.mel_embedding
487
+
488
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
489
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token,
490
+ self.stop_text_token)
491
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
492
+
493
+ conds = speech_conditioning_latent.unsqueeze(1)
494
+ emb = torch.cat([conds, text_emb], dim=1)
495
+ self.inference_model.store_mel_emb(emb)
496
+
497
+ fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
498
+ device=text_inputs.device)
499
+ fake_inputs[:, -1] = self.start_mel_token
500
+ trunc_index = fake_inputs.shape[1]
501
+ if input_tokens is None:
502
+ inputs = fake_inputs
503
+ else:
504
+ assert num_return_sequences % input_tokens.shape[
505
+ 0] == 0, "The number of return sequences must be divisible by the number of input sequences"
506
+ fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
507
+ input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
508
+ inputs = torch.cat([fake_inputs, input_tokens], dim=1)
509
+
510
+ logits_processor = LogitsProcessorList(
511
+ [TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
512
+ max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
513
+ gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
514
+ eos_token_id=self.stop_mel_token,
515
+ max_length=max_length, logits_processor=logits_processor,
516
+ num_return_sequences=num_return_sequences, **hf_generate_kwargs)
517
+ return gen[:, trunc_index:]
518
+
519
+
520
+ if __name__ == '__main__':
521
+ gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True,
522
+ max_conditioning_inputs=4)
523
+ l = gpt(torch.randn(2, 256),
524
+ torch.randint(high=120, size=(2, 120)),
525
+ torch.tensor([32, 120]),
526
+ torch.randint(high=8192, size=(2, 250)),
527
+ torch.tensor([250 * 256, 195 * 256]))
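The block below is a minimal, hypothetical sketch of how the autoregressive model added above is driven at inference time; the model configuration, the precomputed conditioning latent, and the do_sample/top_p generation settings are illustrative assumptions, not values taken from this commit.

import torch
from ruth_tts_transformer.models.autoregressive import UnifiedVoice

gpt = UnifiedVoice(model_dim=512, heads=8).eval()      # hypothetical configuration
cond_latent = torch.randn(1, 512)                      # normally produced by get_conditioning()
text_tokens = torch.randint(0, 255, (1, 40))           # tokenizer output (assumed)

with torch.no_grad():
    mel_codes = gpt.inference_speech(cond_latent, text_tokens,
                                     do_sample=True, top_p=0.8,
                                     num_return_sequences=4,
                                     max_generate_length=250)
print(mel_codes.shape)   # four candidate MEL code sequences, each up to 250 tokens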
ruth_tts_transformer/models/clvp.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+
6
+ from ruth_tts_transformer.models.arch_util import CheckpointedXTransformerEncoder
7
+ from ruth_tts_transformer.models.transformer import Transformer
8
+ from ruth_tts_transformer.models.xtransformers import Encoder
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def masked_mean(t, mask, dim=1):
16
+ t = t.masked_fill(~mask[:, :, None], 0.)
17
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
18
+
19
+ class CLVP(nn.Module):
20
+ """
21
+ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
22
+ transcribed text.
23
+
24
+ Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ *,
30
+ dim_text=512,
31
+ dim_speech=512,
32
+ dim_latent=512,
33
+ num_text_tokens=256,
34
+ text_enc_depth=6,
35
+ text_seq_len=120,
36
+ text_heads=8,
37
+ num_speech_tokens=8192,
38
+ speech_enc_depth=6,
39
+ speech_heads=8,
40
+ speech_seq_len=250,
41
+ text_mask_percentage=0,
42
+ voice_mask_percentage=0,
43
+ wav_token_compression=1024,
44
+ use_xformers=False,
45
+ ):
46
+ super().__init__()
47
+ self.text_emb = nn.Embedding(num_text_tokens, dim_text)
48
+ self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
49
+
50
+ self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
51
+ self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
52
+
53
+ if use_xformers:
54
+ self.text_transformer = CheckpointedXTransformerEncoder(
55
+ needs_permute=False,
56
+ exit_permute=False,
57
+ max_seq_len=-1,
58
+ attn_layers=Encoder(
59
+ dim=dim_text,
60
+ depth=text_enc_depth,
61
+ heads=text_heads,
62
+ ff_dropout=.1,
63
+ ff_mult=2,
64
+ attn_dropout=.1,
65
+ use_rmsnorm=True,
66
+ ff_glu=True,
67
+ rotary_pos_emb=True,
68
+ ))
69
+ self.speech_transformer = CheckpointedXTransformerEncoder(
70
+ needs_permute=False,
71
+ exit_permute=False,
72
+ max_seq_len=-1,
73
+ attn_layers=Encoder(
74
+ dim=dim_speech,
75
+ depth=speech_enc_depth,
76
+ heads=speech_heads,
77
+ ff_dropout=.1,
78
+ ff_mult=2,
79
+ attn_dropout=.1,
80
+ use_rmsnorm=True,
81
+ ff_glu=True,
82
+ rotary_pos_emb=True,
83
+ ))
84
+ else:
85
+ self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
86
+ heads=text_heads)
87
+ self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
88
+ depth=speech_enc_depth, heads=speech_heads)
89
+
90
+ self.temperature = nn.Parameter(torch.tensor(1.))
91
+ self.text_mask_percentage = text_mask_percentage
92
+ self.voice_mask_percentage = voice_mask_percentage
93
+ self.wav_token_compression = wav_token_compression
94
+ self.xformers = use_xformers
95
+ if not use_xformers:
96
+ self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
97
+ self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
98
+
99
+ def forward(
100
+ self,
101
+ text,
102
+ speech_tokens,
103
+ return_loss=False
104
+ ):
105
+ b, device = text.shape[0], text.device
106
+ if self.training:
107
+ text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
108
+ voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
109
+ else:
110
+ text_mask = torch.ones_like(text.float()).bool()
111
+ voice_mask = torch.ones_like(speech_tokens.float()).bool()
112
+
113
+ text_emb = self.text_emb(text)
114
+ speech_emb = self.speech_emb(speech_tokens)
115
+
116
+ if not self.xformers:
117
+ text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
118
+ speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
119
+
120
+ enc_text = self.text_transformer(text_emb, mask=text_mask)
121
+ enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
122
+
123
+ text_latents = masked_mean(enc_text, text_mask, dim=1)
124
+ speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
125
+
126
+ text_latents = self.to_text_latent(text_latents)
127
+ speech_latents = self.to_speech_latent(speech_latents)
128
+
129
+ text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
130
+
131
+ temp = self.temperature.exp()
132
+
133
+ if not return_loss:
134
+ sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
135
+ return sim
136
+
137
+ sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
138
+ labels = torch.arange(b, device=device)
139
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
140
+ return loss
141
+
142
+
143
+ if __name__ == '__main__':
144
+ clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
145
+ clip(torch.randint(0,256,(2,120)),
146
+ torch.tensor([50,100]),
147
+ torch.randint(0,8192,(2,250)),
148
+ torch.tensor([101,102]),
149
+ return_loss=True)
150
+ nonloss = clip(torch.randint(0, 256, (2, 120)),
151
+ torch.randint(0, 8192, (2, 250)),
152
+ return_loss=False)
155
+ print(nonloss.shape)
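For context, here is a small, hypothetical sketch of how a CLVP like the one above is commonly used downstream: scoring candidate speech-token sequences from the autoregressive stage against the input text and keeping the highest-scoring one. The shapes and the number of candidates are illustrative.

import torch
from ruth_tts_transformer.models.clvp import CLVP

clvp = CLVP(use_xformers=False).eval()
text = torch.randint(0, 256, (1, 120))             # tokenized transcript (assumed)
candidates = torch.randint(0, 8192, (4, 250))      # e.g. outputs of UnifiedVoice.inference_speech

with torch.no_grad():
    scores = clvp(text.repeat(4, 1), candidates, return_loss=False)  # one similarity per pair
best = candidates[scores.argmax()]
print(scores.shape, best.shape)                    # torch.Size([4]) torch.Size([250])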
ruth_tts_transformer/models/diffusion_decoder.py ADDED
@@ -0,0 +1,349 @@
1
+ import math
2
+ import random
3
+ from abc import abstractmethod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import autocast
9
+
10
+ from ruth_tts_transformer.models.arch_util import normalization, AttentionBlock
11
+
12
+
13
+ def is_latent(t):
14
+ return t.dtype == torch.float
15
+
16
+
17
+ def is_sequence(t):
18
+ return t.dtype == torch.long
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34
+ ).to(device=timesteps.device)
35
+ args = timesteps[:, None].float() * freqs[None]
36
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37
+ if dim % 2:
38
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39
+ return embedding
40
+
41
+
42
+ class TimestepBlock(nn.Module):
43
+ @abstractmethod
44
+ def forward(self, x, emb):
45
+ """
46
+ Apply the module to `x` given `emb` timestep embeddings.
47
+ """
48
+
49
+
50
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
51
+ def forward(self, x, emb):
52
+ for layer in self:
53
+ if isinstance(layer, TimestepBlock):
54
+ x = layer(x, emb)
55
+ else:
56
+ x = layer(x)
57
+ return x
58
+
59
+
60
+ class ResBlock(TimestepBlock):
61
+ def __init__(
62
+ self,
63
+ channels,
64
+ emb_channels,
65
+ dropout,
66
+ out_channels=None,
67
+ dims=2,
68
+ kernel_size=3,
69
+ efficient_config=True,
70
+ use_scale_shift_norm=False,
71
+ ):
72
+ super().__init__()
73
+ self.channels = channels
74
+ self.emb_channels = emb_channels
75
+ self.dropout = dropout
76
+ self.out_channels = out_channels or channels
77
+ self.use_scale_shift_norm = use_scale_shift_norm
78
+ padding = {1: 0, 3: 1, 5: 2}[kernel_size]
79
+ eff_kernel = 1 if efficient_config else 3
80
+ eff_padding = 0 if efficient_config else 1
81
+
82
+ self.in_layers = nn.Sequential(
83
+ normalization(channels),
84
+ nn.SiLU(),
85
+ nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
86
+ )
87
+
88
+ self.emb_layers = nn.Sequential(
89
+ nn.SiLU(),
90
+ nn.Linear(
91
+ emb_channels,
92
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
93
+ ),
94
+ )
95
+ self.out_layers = nn.Sequential(
96
+ normalization(self.out_channels),
97
+ nn.SiLU(),
98
+ nn.Dropout(p=dropout),
99
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
100
+ )
101
+
102
+ if self.out_channels == channels:
103
+ self.skip_connection = nn.Identity()
104
+ else:
105
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
106
+
107
+ def forward(self, x, emb):
108
+ h = self.in_layers(x)
109
+ emb_out = self.emb_layers(emb).type(h.dtype)
110
+ while len(emb_out.shape) < len(h.shape):
111
+ emb_out = emb_out[..., None]
112
+ if self.use_scale_shift_norm:
113
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
114
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
115
+ h = out_norm(h) * (1 + scale) + shift
116
+ h = out_rest(h)
117
+ else:
118
+ h = h + emb_out
119
+ h = self.out_layers(h)
120
+ return self.skip_connection(x) + h
121
+
122
+
123
+ class DiffusionLayer(TimestepBlock):
124
+ def __init__(self, model_channels, dropout, num_heads):
125
+ super().__init__()
126
+ self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1,
127
+ use_scale_shift_norm=True)
128
+ self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
129
+
130
+ def forward(self, x, time_emb):
131
+ y = self.resblk(x, time_emb)
132
+ return self.attn(y)
133
+
134
+
135
+ class DiffusionTts(nn.Module):
136
+ def __init__(
137
+ self,
138
+ model_channels=512,
139
+ num_layers=8,
140
+ in_channels=100,
141
+ in_latent_channels=512,
142
+ in_tokens=8193,
143
+ out_channels=200, # mean and variance
144
+ dropout=0,
145
+ use_fp16=False,
146
+ num_heads=16,
147
+ # Parameters for regularization.
148
+ layer_drop=.1,
149
+ unconditioned_percentage=.1,
150
+ # This implements a mechanism similar to what is used in classifier-free training.
151
+ ):
152
+ super().__init__()
153
+
154
+ self.in_channels = in_channels
155
+ self.model_channels = model_channels
156
+ self.out_channels = out_channels
157
+ self.dropout = dropout
158
+ self.num_heads = num_heads
159
+ self.unconditioned_percentage = unconditioned_percentage
160
+ self.enable_fp16 = use_fp16
161
+ self.layer_drop = layer_drop
162
+
163
+ self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
164
+ self.time_embed = nn.Sequential(
165
+ nn.Linear(model_channels, model_channels),
166
+ nn.SiLU(),
167
+ nn.Linear(model_channels, model_channels),
168
+ )
169
+
170
+ # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
171
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
172
+ # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
173
+ # transformer network.
174
+ self.code_embedding = nn.Embedding(in_tokens, model_channels)
175
+ self.code_converter = nn.Sequential(
176
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
177
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
178
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
179
+ )
180
+ self.code_norm = normalization(model_channels)
181
+ self.latent_conditioner = nn.Sequential(
182
+ nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
183
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
184
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
185
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
186
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
187
+ )
188
+ self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
189
+ nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
190
+ AttentionBlock(model_channels * 2, num_heads,
191
+ relative_pos_embeddings=True, do_checkpoint=False),
192
+ AttentionBlock(model_channels * 2, num_heads,
193
+ relative_pos_embeddings=True, do_checkpoint=False),
194
+ AttentionBlock(model_channels * 2, num_heads,
195
+ relative_pos_embeddings=True, do_checkpoint=False),
196
+ AttentionBlock(model_channels * 2, num_heads,
197
+ relative_pos_embeddings=True, do_checkpoint=False),
198
+ AttentionBlock(model_channels * 2, num_heads,
199
+ relative_pos_embeddings=True, do_checkpoint=False))
200
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
201
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
202
+ DiffusionLayer(model_channels, dropout, num_heads),
203
+ DiffusionLayer(model_channels, dropout, num_heads),
204
+ DiffusionLayer(model_channels, dropout, num_heads),
205
+ )
206
+
207
+ self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
208
+ self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
209
+
210
+ self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
211
+ [ResBlock(model_channels, model_channels, dropout, dims=1,
212
+ use_scale_shift_norm=True) for _ in range(3)])
213
+
214
+ self.out = nn.Sequential(
215
+ normalization(model_channels),
216
+ nn.SiLU(),
217
+ nn.Conv1d(model_channels, out_channels, 3, padding=1),
218
+ )
219
+
220
+ def get_grad_norm_parameter_groups(self):
221
+ groups = {
222
+ 'minicoder': list(self.contextual_embedder.parameters()),
223
+ 'layers': list(self.layers.parameters()),
224
+ 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(
225
+ self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
226
+ 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(
227
+ self.integrating_conv.parameters()),
228
+ 'time_embed': list(self.time_embed.parameters()),
229
+ }
230
+ return groups
231
+
232
+ def get_conditioning(self, conditioning_input):
233
+ speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
234
+ conditioning_input.shape) == 3 else conditioning_input
235
+ conds = []
236
+ for j in range(speech_conditioning_input.shape[1]):
237
+ conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
238
+ conds = torch.cat(conds, dim=-1)
239
+ conds = conds.mean(dim=-1)
240
+ return conds
241
+
242
+ def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
243
+ # Shuffle aligned_latent to BxCxS format
244
+ if is_latent(aligned_conditioning):
245
+ aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
246
+
247
+ cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
248
+ if is_latent(aligned_conditioning):
249
+ code_emb = self.latent_conditioner(aligned_conditioning)
250
+ else:
251
+ code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
252
+ code_emb = self.code_converter(code_emb)
253
+ code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
254
+
255
+ unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
256
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
257
+ if self.training and self.unconditioned_percentage > 0:
258
+ unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
259
+ device=code_emb.device) < self.unconditioned_percentage
260
+ code_emb = torch.where(unconditioned_batches,
261
+ self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
262
+ code_emb)
263
+ expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
264
+
265
+ if not return_code_pred:
266
+ return expanded_code_emb
267
+ else:
268
+ mel_pred = self.mel_head(expanded_code_emb)
269
+ # Multiply mel_pred by the logical NOT of unconditioned_batches, which zeroes the gradient on the unconditioned batch elements. We don't want that gradient training parameters through the code embedder, as it would unbalance the contributions to that network from the MSE loss.
270
+ mel_pred = mel_pred * unconditioned_batches.logical_not()
271
+ return expanded_code_emb, mel_pred
272
+
273
+ def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None,
274
+ precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
275
+ """
276
+ Apply the model to an input batch.
277
+
278
+ :param x: an [N x C x ...] Tensor of inputs.
279
+ :param timesteps: a 1-D batch of timesteps.
280
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
281
+ :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
282
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
283
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
284
+ :return: an [N x C x ...] Tensor of outputs.
285
+ """
286
+ assert precomputed_aligned_embeddings is not None or (
287
+ aligned_conditioning is not None and conditioning_latent is not None)
288
+ assert not (
289
+ return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
290
+
291
+ unused_params = []
292
+ if conditioning_free:
293
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
294
+ unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
295
+ unused_params.extend(list(self.latent_conditioner.parameters()))
296
+ else:
297
+ if precomputed_aligned_embeddings is not None:
298
+ code_emb = precomputed_aligned_embeddings
299
+ else:
300
+ code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1],
301
+ True)
302
+ if is_latent(aligned_conditioning):
303
+ unused_params.extend(
304
+ list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
305
+ else:
306
+ unused_params.extend(list(self.latent_conditioner.parameters()))
307
+
308
+ unused_params.append(self.unconditioned_embedding)
309
+
310
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
311
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
312
+ x = self.inp_block(x)
313
+ x = torch.cat([x, code_emb], dim=1)
314
+ x = self.integrating_conv(x)
315
+ for i, lyr in enumerate(self.layers):
316
+ # Do layer drop where applicable. Do not drop first and last layers.
317
+ if self.training and self.layer_drop > 0 and i != 0 and i != (
318
+ len(self.layers) - 1) and random.random() < self.layer_drop:
319
+ unused_params.extend(list(lyr.parameters()))
320
+ else:
321
+ # First and last blocks will have autocast disabled for improved precision.
322
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
323
+ x = lyr(x, time_emb)
324
+
325
+ x = x.float()
326
+ out = self.out(x)
327
+
328
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
329
+ extraneous_addition = 0
330
+ for p in unused_params:
331
+ extraneous_addition = extraneous_addition + p.mean()
332
+ out = out + extraneous_addition * 0
333
+
334
+ if return_code_pred:
335
+ return out, mel_pred
336
+ return out
337
+
338
+
339
+ if __name__ == '__main__':
340
+ clip = torch.randn(2, 100, 400)
341
+ aligned_latent = torch.randn(2, 388, 512)
342
+ aligned_sequence = torch.randint(0, 8192, (2, 100))
343
+ cond = torch.randn(2, 1024)  # conditioning latent, (b, 2 * model_channels); see get_conditioning()
344
+ ts = torch.LongTensor([600, 600])
345
+ model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
346
+ # Test with latent aligned conditioning
347
+ # o = model(clip, ts, aligned_latent, cond)
348
+ # Test with sequence aligned conditioning
349
+ o = model(clip, ts, aligned_sequence, cond)
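A rough sketch, under assumed shapes, of how the decoder above is exercised at sampling time: the conditioning latent and the timestep-independent code embedding are computed once, and the network is then called once per diffusion step by the sampler (which lives elsewhere in this repository).

import torch
from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts

model = DiffusionTts(model_channels=512).eval()
mel_codes = torch.randint(0, 8192, (1, 100))    # tokens from the autoregressive stage (assumed)
cond_mels = torch.randn(1, 100, 400)            # reference-audio mels used for conditioning

with torch.no_grad():
    cond_latent = model.get_conditioning(cond_mels)                    # (1, 2 * model_channels)
    code_emb = model.timestep_independent(mel_codes, cond_latent,
                                          expected_seq_len=400,
                                          return_code_pred=False)
    x_t = torch.randn(1, 100, 400)                                     # noisy mel at one step
    out = model(x_t, torch.LongTensor([500]),
                precomputed_aligned_embeddings=code_emb)               # (1, 200, 400): mean + variance
print(out.shape)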
ruth_tts_transformer/models/transformer.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+
7
+ # helpers
8
+
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+
14
+ def default(val, d):
15
+ return val if exists(val) else d
16
+
17
+
18
+ def cast_tuple(val, depth=1):
19
+ if isinstance(val, list):
20
+ val = tuple(val)
21
+ return val if isinstance(val, tuple) else (val,) * depth
22
+
23
+
24
+ def max_neg_value(t):
25
+ return -torch.finfo(t.dtype).max
26
+
27
+
28
+ def stable_softmax(t, dim=-1, alpha=32 ** 2):
29
+ t = t / alpha
30
+ t = t - torch.amax(t, dim=dim, keepdim=True).detach()
31
+ return (t * alpha).softmax(dim=dim)
32
+
33
+
34
+ def route_args(router, args, depth):
35
+ routed_args = [(dict(), dict()) for _ in range(depth)]
36
+ matched_keys = [key for key in args.keys() if key in router]
37
+
38
+ for key in matched_keys:
39
+ val = args[key]
40
+ for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
41
+ new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
42
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
43
+ return routed_args
44
+
45
+
46
+ # classes
47
+ class SequentialSequence(nn.Module):
48
+ def __init__(self, layers, args_route={}, layer_dropout=0.):
49
+ super().__init__()
50
+ assert all(len(route) == len(layers) for route in
51
+ args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
52
+ self.layers = layers
53
+ self.args_route = args_route
54
+ self.layer_dropout = layer_dropout
55
+
56
+ def forward(self, x, **kwargs):
57
+ args = route_args(self.args_route, kwargs, len(self.layers))
58
+ layers_and_args = list(zip(self.layers, args))
59
+
60
+ for (f, g), (f_args, g_args) in layers_and_args:
61
+ x = x + f(x, **f_args)
62
+ x = x + g(x, **g_args)
63
+ return x
64
+
65
+
66
+ class DivideMax(nn.Module):
67
+ def __init__(self, dim):
68
+ super().__init__()
69
+ self.dim = dim
70
+
71
+ def forward(self, x):
72
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
73
+ return x / maxes
74
+
75
+
76
+ # https://arxiv.org/abs/2103.17239
77
+ class LayerScale(nn.Module):
78
+ def __init__(self, dim, depth, fn):
79
+ super().__init__()
80
+ if depth <= 18:
81
+ init_eps = 0.1
82
+ elif depth > 18 and depth <= 24:
83
+ init_eps = 1e-5
84
+ else:
85
+ init_eps = 1e-6
86
+
87
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
88
+ self.scale = nn.Parameter(scale)
89
+ self.fn = fn
90
+
91
+ def forward(self, x, **kwargs):
92
+ return self.fn(x, **kwargs) * self.scale
93
+
94
+
95
+ # layer norm
96
+
97
+
98
+ class PreNorm(nn.Module):
99
+ def __init__(self, dim, fn, sandwich=False):
100
+ super().__init__()
101
+ self.norm = nn.LayerNorm(dim)
102
+ self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
103
+ self.fn = fn
104
+
105
+ def forward(self, x, **kwargs):
106
+ x = self.norm(x)
107
+ x = self.fn(x, **kwargs)
108
+ return self.norm_out(x)
109
+
110
+
111
+ # feed forward
112
+
113
+
114
+ class GEGLU(nn.Module):
115
+ def forward(self, x):
116
+ x, gates = x.chunk(2, dim=-1)
117
+ return x * F.gelu(gates)
118
+
119
+
120
+ class FeedForward(nn.Module):
121
+ def __init__(self, dim, dropout=0., mult=4.):
122
+ super().__init__()
123
+ self.net = nn.Sequential(
124
+ nn.Linear(dim, dim * mult * 2),
125
+ GEGLU(),
126
+ nn.Dropout(dropout),
127
+ nn.Linear(dim * mult, dim)
128
+ )
129
+
130
+ def forward(self, x):
131
+ return self.net(x)
132
+
133
+
134
+ # Attention
135
+
136
+
137
+ class Attention(nn.Module):
138
+ def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.):
139
+ super().__init__()
140
+ inner_dim = dim_head * heads
141
+ self.heads = heads
142
+ self.seq_len = seq_len
143
+ self.scale = dim_head ** -0.5
144
+
145
+ self.causal = causal
146
+
147
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
148
+ self.to_out = nn.Sequential(
149
+ nn.Linear(inner_dim, dim),
150
+ nn.Dropout(dropout)
151
+ )
152
+
153
+ def forward(self, x, mask=None):
154
+ b, n, _, h, device = *x.shape, self.heads, x.device
155
+ softmax = torch.softmax
156
+
157
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
158
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
159
+
160
+ q = q * self.scale
161
+
162
+ dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
163
+ mask_value = max_neg_value(dots)
164
+
165
+ if exists(mask):
166
+ mask = rearrange(mask, 'b j -> b () () j')
167
+ dots.masked_fill_(~mask, mask_value)
168
+ del mask
169
+
170
+ if self.causal:
171
+ i, j = dots.shape[-2:]
172
+ mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
173
+ dots.masked_fill_(mask, mask_value)
174
+
175
+ attn = softmax(dots, dim=-1)
176
+
177
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
178
+ out = rearrange(out, 'b h n d -> b n (h d)')
179
+ out = self.to_out(out)
180
+ return out
181
+
182
+
183
+ # main transformer class
184
+ class Transformer(nn.Module):
185
+ def __init__(
186
+ self,
187
+ *,
188
+ dim,
189
+ depth,
190
+ seq_len,
191
+ causal=True,
192
+ heads=8,
193
+ dim_head=64,
194
+ ff_mult=4,
195
+ attn_dropout=0.,
196
+ ff_dropout=0.,
197
+ sparse_attn=False,
198
+ sandwich_norm=False,
199
+ ):
200
+ super().__init__()
201
+ layers = nn.ModuleList([])
202
+ sparse_layer = cast_tuple(sparse_attn, depth)
203
+
204
+ for ind, sparse_attn in zip(range(depth), sparse_layer):
205
+ attn = Attention(dim, causal=causal, seq_len=seq_len, heads=heads, dim_head=dim_head, dropout=attn_dropout)
206
+
207
+ ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
208
+
209
+ layers.append(nn.ModuleList([
210
+ LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
211
+ LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm))
212
+ ]))
213
+
214
+ execute_type = SequentialSequence
215
+ route_attn = ((True, False),) * depth
216
+ attn_route_map = {'mask': route_attn}
217
+
218
+ self.layers = execute_type(layers, args_route=attn_route_map)
219
+
220
+ def forward(self, x, **kwargs):
221
+ return self.layers(x, **kwargs)
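A short sketch of the non-xformers Transformer defined above, as CLVP consumes it: a batch of token embeddings plus a boolean padding mask routed to the attention layers. Dimensions are illustrative.

import torch
from ruth_tts_transformer.models.transformer import Transformer

enc = Transformer(dim=512, depth=6, seq_len=120, causal=False, heads=8)
x = torch.randn(2, 120, 512)                  # embeddings, e.g. from an nn.Embedding (assumed)
mask = torch.ones(2, 120, dtype=torch.bool)   # True for real tokens, False for padding

out = enc(x, mask=mask)
print(out.shape)                              # torch.Size([2, 120, 512])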
ruth_tts_transformer/models/vocoder.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+ class KernelPredictor(torch.nn.Module):
8
+ ''' Kernel predictor for the location-variable convolutions'''
9
+
10
+ def __init__(
11
+ self,
12
+ cond_channels,
13
+ conv_in_channels,
14
+ conv_out_channels,
15
+ conv_layers,
16
+ conv_kernel_size=3,
17
+ kpnet_hidden_channels=64,
18
+ kpnet_conv_size=3,
19
+ kpnet_dropout=0.0,
20
+ kpnet_nonlinear_activation="LeakyReLU",
21
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
22
+ ):
23
+ '''
24
+ Args:
25
+ cond_channels (int): number of channels for the conditioning sequence,
27
+ conv_in_channels (int): number of channels for the input sequence,
28
+ conv_out_channels (int): number of channels for the output sequence,
29
+ conv_layers (int): number of layers
29
+ '''
30
+ super().__init__()
31
+
32
+ self.conv_in_channels = conv_in_channels
33
+ self.conv_out_channels = conv_out_channels
34
+ self.conv_kernel_size = conv_kernel_size
35
+ self.conv_layers = conv_layers
36
+
37
+ kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
38
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
39
+
40
+ self.input_conv = nn.Sequential(
41
+ nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
42
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
43
+ )
44
+
45
+ self.residual_convs = nn.ModuleList()
46
+ padding = (kpnet_conv_size - 1) // 2
47
+ for _ in range(3):
48
+ self.residual_convs.append(
49
+ nn.Sequential(
50
+ nn.Dropout(kpnet_dropout),
51
+ nn.utils.weight_norm(
52
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
53
+ bias=True)),
54
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
55
+ nn.utils.weight_norm(
56
+ nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
57
+ bias=True)),
58
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
59
+ )
60
+ )
61
+ self.kernel_conv = nn.utils.weight_norm(
62
+ nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
63
+ self.bias_conv = nn.utils.weight_norm(
64
+ nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
65
+
66
+ def forward(self, c):
67
+ '''
68
+ Args:
69
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
70
+ '''
71
+ batch, _, cond_length = c.shape
72
+ c = self.input_conv(c)
73
+ for residual_conv in self.residual_convs:
74
+ residual_conv.to(c.device)
75
+ c = c + residual_conv(c)
76
+ k = self.kernel_conv(c)
77
+ b = self.bias_conv(c)
78
+ kernels = k.contiguous().view(
79
+ batch,
80
+ self.conv_layers,
81
+ self.conv_in_channels,
82
+ self.conv_out_channels,
83
+ self.conv_kernel_size,
84
+ cond_length,
85
+ )
86
+ bias = b.contiguous().view(
87
+ batch,
88
+ self.conv_layers,
89
+ self.conv_out_channels,
90
+ cond_length,
91
+ )
92
+
93
+ return kernels, bias
94
+
95
+ def remove_weight_norm(self):
96
+ nn.utils.remove_weight_norm(self.input_conv[0])
97
+ nn.utils.remove_weight_norm(self.kernel_conv)
98
+ nn.utils.remove_weight_norm(self.bias_conv)
99
+ for block in self.residual_convs:
100
+ nn.utils.remove_weight_norm(block[1])
101
+ nn.utils.remove_weight_norm(block[3])
102
+
103
+
104
+ class LVCBlock(torch.nn.Module):
105
+ '''the location-variable convolutions'''
106
+
107
+ def __init__(
108
+ self,
109
+ in_channels,
110
+ cond_channels,
111
+ stride,
112
+ dilations=[1, 3, 9, 27],
113
+ lReLU_slope=0.2,
114
+ conv_kernel_size=3,
115
+ cond_hop_length=256,
116
+ kpnet_hidden_channels=64,
117
+ kpnet_conv_size=3,
118
+ kpnet_dropout=0.0,
119
+ ):
120
+ super().__init__()
121
+
122
+ self.cond_hop_length = cond_hop_length
123
+ self.conv_layers = len(dilations)
124
+ self.conv_kernel_size = conv_kernel_size
125
+
126
+ self.kernel_predictor = KernelPredictor(
127
+ cond_channels=cond_channels,
128
+ conv_in_channels=in_channels,
129
+ conv_out_channels=2 * in_channels,
130
+ conv_layers=len(dilations),
131
+ conv_kernel_size=conv_kernel_size,
132
+ kpnet_hidden_channels=kpnet_hidden_channels,
133
+ kpnet_conv_size=kpnet_conv_size,
134
+ kpnet_dropout=kpnet_dropout,
135
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
136
+ )
137
+
138
+ self.convt_pre = nn.Sequential(
139
+ nn.LeakyReLU(lReLU_slope),
140
+ nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
141
+ padding=stride // 2 + stride % 2, output_padding=stride % 2)),
142
+ )
143
+
144
+ self.conv_blocks = nn.ModuleList()
145
+ for dilation in dilations:
146
+ self.conv_blocks.append(
147
+ nn.Sequential(
148
+ nn.LeakyReLU(lReLU_slope),
149
+ nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
150
+ padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
151
+ nn.LeakyReLU(lReLU_slope),
152
+ )
153
+ )
154
+
155
+ def forward(self, x, c):
156
+ ''' forward propagation of the location-variable convolutions.
157
+ Args:
158
+ x (Tensor): the input sequence (batch, in_channels, in_length)
159
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
160
+
161
+ Returns:
162
+ Tensor: the output sequence (batch, in_channels, in_length)
163
+ '''
164
+ _, in_channels, _ = x.shape # (B, c_g, L')
165
+
166
+ x = self.convt_pre(x) # (B, c_g, stride * L')
167
+ kernels, bias = self.kernel_predictor(c)
168
+
169
+ for i, conv in enumerate(self.conv_blocks):
170
+ output = conv(x) # (B, c_g, stride * L')
171
+
172
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
173
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
174
+
175
+ output = self.location_variable_convolution(output, k, b,
176
+ hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
177
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
178
+ output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
179
+
180
+ return x
181
+
182
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
183
+ ''' Perform the location-variable convolution operation on the input sequence (x) using the local convolution kernel.
184
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
185
+ Args:
186
+ x (Tensor): the input sequence (batch, in_channels, in_length).
187
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
188
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
189
+ dilation (int): the dilation of convolution.
190
+ hop_size (int): the hop_size of the conditioning sequence.
191
+ Returns:
192
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
193
+ '''
194
+ batch, _, in_length = x.shape
195
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
196
+ assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
197
+
198
+ padding = dilation * int((kernel_size - 1) / 2)
199
+ x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
200
+ x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
201
+
202
+ if hop_size < dilation:
203
+ x = F.pad(x, (0, dilation), 'constant', 0)
204
+ x = x.unfold(3, dilation,
205
+ dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
206
+ x = x[:, :, :, :, :hop_size]
207
+ x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
208
+ x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
209
+
210
+ o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
211
+ o = o.to(memory_format=torch.channels_last_3d)
212
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
213
+ o = o + bias
214
+ o = o.contiguous().view(batch, out_channels, -1)
215
+
216
+ return o
217
+
218
+ def remove_weight_norm(self):
219
+ self.kernel_predictor.remove_weight_norm()
220
+ nn.utils.remove_weight_norm(self.convt_pre[1])
221
+ for block in self.conv_blocks:
222
+ nn.utils.remove_weight_norm(block[1])
223
+
224
+
225
+ class UnivNetGenerator(nn.Module):
226
+ """UnivNet Generator"""
227
+
228
+ def __init__(self, noise_dim=64, channel_size=16, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
229
+ # Below are MEL configurations options that this generator requires.
230
+ hop_length=256, n_mel_channels=100):
231
+ super(UnivNetGenerator, self).__init__()
232
+ self.mel_channel = n_mel_channels
233
+ self.noise_dim = noise_dim
234
+ self.hop_length = hop_length
235
+ channel_size = channel_size
236
+ kpnet_conv_size = kpnet_conv_size
237
+
238
+ self.res_stack = nn.ModuleList()
239
+ hop_length = 1
240
+ for stride in strides:
241
+ hop_length = stride * hop_length
242
+ self.res_stack.append(
243
+ LVCBlock(
244
+ channel_size,
245
+ n_mel_channels,
246
+ stride=stride,
247
+ dilations=dilations,
248
+ lReLU_slope=lReLU_slope,
249
+ cond_hop_length=hop_length,
250
+ kpnet_conv_size=kpnet_conv_size
251
+ )
252
+ )
253
+
254
+ self.conv_pre = \
255
+ nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
256
+
257
+ self.conv_post = nn.Sequential(
258
+ nn.LeakyReLU(lReLU_slope),
259
+ nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
260
+ nn.Tanh(),
261
+ )
262
+
263
+ def forward(self, c, z):
264
+ '''
265
+ Args:
266
+ c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
267
+ z (Tensor): the noise sequence (batch, noise_dim, in_length)
268
+
269
+ '''
270
+ z = self.conv_pre(z) # (B, c_g, L)
271
+
272
+ for res_block in self.res_stack:
273
+ res_block.to(z.device)
274
+ z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
275
+
276
+ z = self.conv_post(z) # (B, 1, L * 256)
277
+
278
+ return z
279
+
280
+ def eval(self, inference=False):
281
+ super(UnivNetGenerator, self).eval()
282
+ # don't remove weight norm while validation in training loop
283
+ if inference:
284
+ self.remove_weight_norm()
285
+
286
+ def remove_weight_norm(self):
287
+ nn.utils.remove_weight_norm(self.conv_pre)
288
+
289
+ for layer in self.conv_post:
290
+ if len(layer.state_dict()) != 0:
291
+ nn.utils.remove_weight_norm(layer)
292
+
293
+ for res_block in self.res_stack:
294
+ res_block.remove_weight_norm()
295
+
296
+ def inference(self, c, z=None):
297
+ # pad input mel with zeros to cut artifact
298
+ # see https://github.com/seungwonpark/melgan/issues/8
299
+ zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
300
+ mel = torch.cat((c, zero), dim=2)
301
+
302
+ if z is None:
303
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
304
+
305
+ audio = self.forward(mel, z)
306
+ audio = audio[:, :, :-(self.hop_length * 10)]
307
+ audio = audio.clamp(min=-1, max=1)
308
+ return audio
309
+
310
+
311
+ if __name__ == '__main__':
312
+ model = UnivNetGenerator()
313
+
314
+ c = torch.randn(3, 100, 10)
315
+ z = torch.randn(3, 64, 10)
316
+ print(c.shape)
317
+
318
+ y = model(c, z)
319
+ print(y.shape)
320
+ assert y.shape == torch.Size([3, 1, 2560])
321
+
322
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
323
+ print(pytorch_total_params)
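A brief sketch of the vocoder stage, assuming a trained UnivNetGenerator checkpoint is available (the path below is hypothetical): the decoded mel spectrogram is converted to a waveform with `inference`, which pads the mel by 10 frames and trims the corresponding tail from the audio.

import torch
from ruth_tts_transformer.models.vocoder import UnivNetGenerator

vocoder = UnivNetGenerator()
# state = torch.load("vocoder.pth", map_location="cpu")   # hypothetical checkpoint path
# vocoder.load_state_dict(state)
vocoder.eval(inference=True)       # also strips weight norm for inference

mel = torch.randn(1, 100, 200)     # (batch, n_mel_channels, frames); values are illustrative
with torch.no_grad():
    audio = vocoder.inference(mel) # (1, 1, 200 * 256) samples at hop_length=256
print(audio.shape)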
ruth_tts_transformer/models/xtransformers.py ADDED
@@ -0,0 +1,1248 @@
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import nn, einsum
10
+
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple('Intermediates', [
14
+ 'pre_softmax_attn',
15
+ 'post_softmax_attn'
16
+ ])
17
+
18
+ LayerIntermediates = namedtuple('LayerIntermediates', [
19
+ 'hiddens',
20
+ 'attn_intermediates',
21
+ 'past_key_values',
22
+ ])
23
+
24
+
25
+ # helpers
26
+
27
+ def exists(val):
28
+ return val is not None
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def cast_tuple(val, depth):
38
+ return val if isinstance(val, tuple) else (val,) * depth
39
+
40
+
41
+ class always():
42
+ def __init__(self, val):
43
+ self.val = val
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ return self.val
47
+
48
+
49
+ class not_equals():
50
+ def __init__(self, val):
51
+ self.val = val
52
+
53
+ def __call__(self, x, *args, **kwargs):
54
+ return x != self.val
55
+
56
+
57
+ class equals():
58
+ def __init__(self, val):
59
+ self.val = val
60
+
61
+ def __call__(self, x, *args, **kwargs):
62
+ return x == self.val
63
+
64
+
65
+ def max_neg_value(tensor):
66
+ return -torch.finfo(tensor.dtype).max
67
+
68
+
69
+ def l2norm(t):
70
+ return F.normalize(t, p=2, dim=-1)
71
+
72
+
73
+ # init helpers
74
+
75
+ def init_zero_(layer):
76
+ nn.init.constant_(layer.weight, 0.)
77
+ if exists(layer.bias):
78
+ nn.init.constant_(layer.bias, 0.)
79
+
80
+
81
+ # keyword argument helpers
82
+
83
+ def pick_and_pop(keys, d):
84
+ values = list(map(lambda key: d.pop(key), keys))
85
+ return dict(zip(keys, values))
86
+
87
+
88
+ def group_dict_by_key(cond, d):
89
+ return_val = [dict(), dict()]
90
+ for key in d.keys():
91
+ match = bool(cond(key))
92
+ ind = int(not match)
93
+ return_val[ind][key] = d[key]
94
+ return (*return_val,)
95
+
96
+
97
+ def string_begins_with(prefix, str):
98
+ return str.startswith(prefix)
99
+
100
+
101
+ def group_by_key_prefix(prefix, d):
102
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
103
+
104
+
105
+ def groupby_prefix_and_trim(prefix, d):
106
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
107
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
108
+ return kwargs_without_prefix, kwargs
109
+
110
+
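To make the keyword-routing helpers above concrete, here is a small hedged example of how `groupby_prefix_and_trim` splits prefixed kwargs (this is how `AttentionLayers` later separates `ff_*` and `attn_*` arguments); the values are made up:

    ff_kwargs, rest = groupby_prefix_and_trim('ff_', {'ff_mult': 2, 'attn_dropout': 0.1, 'heads': 8})
    # ff_kwargs == {'mult': 2}; rest == {'attn_dropout': 0.1, 'heads': 8}
    attn_kwargs, rest = groupby_prefix_and_trim('attn_', rest)
    # attn_kwargs == {'dropout': 0.1}; rest == {'heads': 8}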
111
+ # activations
112
+
113
+ class ReluSquared(nn.Module):
114
+ def forward(self, x):
115
+ return F.relu(x) ** 2
116
+
117
+
118
+ # positional embeddings
119
+
120
+ class AbsolutePositionalEmbedding(nn.Module):
121
+ def __init__(self, dim, max_seq_len):
122
+ super().__init__()
123
+ self.scale = dim ** -0.5
124
+ self.emb = nn.Embedding(max_seq_len, dim)
125
+
126
+ def forward(self, x):
127
+ n = torch.arange(x.shape[1], device=x.device)
128
+ pos_emb = self.emb(n)
129
+ pos_emb = rearrange(pos_emb, 'n d -> () n d')
130
+ return pos_emb * self.scale
131
+
132
+
133
+ class FixedPositionalEmbedding(nn.Module):
134
+ def __init__(self, dim):
135
+ super().__init__()
136
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
137
+ self.register_buffer('inv_freq', inv_freq)
138
+
139
+ def forward(self, x, seq_dim=1, offset=0):
140
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
141
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
142
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
143
+ return rearrange(emb, 'n d -> () n d')
144
+
145
+
146
+ class RelativePositionBias(nn.Module):
147
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
148
+ super().__init__()
149
+ self.scale = scale
150
+ self.causal = causal
151
+ self.num_buckets = num_buckets
152
+ self.max_distance = max_distance
153
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
154
+
155
+ @staticmethod
156
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
157
+ ret = 0
158
+ n = -relative_position
159
+ if not causal:
160
+ num_buckets //= 2
161
+ ret += (n < 0).long() * num_buckets
162
+ n = torch.abs(n)
163
+ else:
164
+ n = torch.max(n, torch.zeros_like(n))
165
+
166
+ max_exact = num_buckets // 2
167
+ is_small = n < max_exact
168
+
169
+ val_if_large = max_exact + (
170
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
171
+ ).long()
172
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
173
+
174
+ ret += torch.where(is_small, n, val_if_large)
175
+ return ret
176
+
177
+ def forward(self, qk_dots):
178
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
179
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
180
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
181
+ rel_pos = k_pos[None, :] - q_pos[:, None]
182
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
183
+ max_distance=self.max_distance)
184
+ values = self.relative_attention_bias(rp_bucket)
185
+ bias = rearrange(values, 'i j h -> () h i j')
186
+ return qk_dots + (bias * self.scale)
187
+
188
+
189
+ class AlibiPositionalBias(nn.Module):
190
+ def __init__(self, heads, **kwargs):
191
+ super().__init__()
192
+ self.heads = heads
193
+ slopes = torch.Tensor(self._get_slopes(heads))
194
+ slopes = rearrange(slopes, 'h -> () h () ()')
195
+ self.register_buffer('slopes', slopes, persistent=False)
196
+ self.register_buffer('bias', None, persistent=False)
197
+
198
+ @staticmethod
199
+ def _get_slopes(heads):
200
+ def get_slopes_power_of_2(n):
201
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
202
+ ratio = start
203
+ return [start * ratio ** i for i in range(n)]
204
+
205
+ if math.log2(heads).is_integer():
206
+ return get_slopes_power_of_2(heads)
207
+
208
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
209
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
210
+ :heads - closest_power_of_2]
211
+
212
+ def forward(self, qk_dots):
213
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
214
+
215
+ if exists(self.bias) and self.bias.shape[-1] >= j:
216
+ return qk_dots + self.bias[..., :j]
217
+
218
+ bias = torch.arange(j, device=device)
219
+ bias = rearrange(bias, 'j -> () () () j')
220
+ bias = bias * self.slopes
221
+
222
+ num_heads_unalibied = h - bias.shape[1]
223
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
224
+
225
+ self.register_buffer('bias', bias, persistent=False)
226
+ return qk_dots + self.bias
227
+
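For intuition about the bias above, the per-head slopes produced by `_get_slopes` decay geometrically; a worked example for eight heads (not part of the module itself):

    slopes = AlibiPositionalBias._get_slopes(8)
    # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
    # i.e. 2**-1 ... 2**-8, so each head penalises distant positions at a different rate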
228
+
229
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
230
+ def __init__(self, heads, bidirectional=False):
231
+ super().__init__(heads)
232
+ log_slopes = torch.log(self.slopes)
233
+ self.learned_logslopes = nn.Parameter(log_slopes)
234
+
235
+ self.bidirectional = bidirectional
236
+ if self.bidirectional:
237
+ self.learned_logslopes_future = nn.Parameter(log_slopes)
238
+
239
+ def forward(self, qk_dots):
240
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
241
+
242
+ def get_slopes(param):
243
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
244
+
245
+ if exists(self.bias) and self.bias.shape[-1] >= j:
246
+ bias = self.bias[..., :i, :j]
247
+ else:
248
+ i_arange = torch.arange(i, device=device)
249
+ j_arange = torch.arange(j, device=device)
250
+ bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
251
+ self.register_buffer('bias', bias, persistent=False)
252
+
253
+ if self.bidirectional:
254
+ past_slopes = get_slopes(self.learned_logslopes)
255
+ future_slopes = get_slopes(self.learned_logslopes_future)
256
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
257
+ else:
258
+ slopes = get_slopes(self.learned_logslopes)
259
+ bias = bias * slopes
260
+
261
+ return qk_dots + bias
262
+
263
+
264
+ class RotaryEmbedding(nn.Module):
265
+ def __init__(self, dim):
266
+ super().__init__()
267
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
268
+ self.register_buffer('inv_freq', inv_freq)
269
+
270
+ def forward(self, max_seq_len, device):
271
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
272
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
273
+ emb = torch.cat((freqs, freqs), dim=-1)
274
+ return rearrange(emb, 'n d -> () () n d')
275
+
276
+
277
+ def rotate_half(x):
278
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
279
+ x1, x2 = x.unbind(dim=-2)
280
+ return torch.cat((-x2, x1), dim=-1)
281
+
282
+
283
+ def apply_rotary_pos_emb(t, freqs):
284
+ seq_len = t.shape[-2]
285
+ freqs = freqs[:, :, -seq_len:]
286
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
287
+
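A hedged sketch of how the rotary helpers above fit together; the head count and rotary dimension are illustrative:

    import torch
    from ruth_tts_transformer.models.xtransformers import RotaryEmbedding, apply_rotary_pos_emb

    rot = RotaryEmbedding(dim=32)
    freqs = rot(16, 'cpu')                    # (1, 1, 16, 32) sinusoid table for 16 positions
    q = torch.randn(2, 8, 16, 32)             # (batch, heads, seq, rotary slice of the head dim)
    q_rot = apply_rotary_pos_emb(q, freqs)    # same shape, positions encoded as rotations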
288
+
289
+ # norms
290
+
291
+ class Scale(nn.Module):
292
+ def __init__(self, value, fn):
293
+ super().__init__()
294
+ self.value = value
295
+ self.fn = fn
296
+
297
+ def forward(self, x, **kwargs):
298
+ out = self.fn(x, **kwargs)
299
+ scale_fn = lambda t: t * self.value
300
+
301
+ if not isinstance(out, tuple):
302
+ return scale_fn(out)
303
+
304
+ return (scale_fn(out[0]), *out[1:])
305
+
306
+
307
+ class Rezero(nn.Module):
308
+ def __init__(self, fn):
309
+ super().__init__()
310
+ self.fn = fn
311
+ self.g = nn.Parameter(torch.zeros(1))
312
+
313
+ def forward(self, x, **kwargs):
314
+ out = self.fn(x, **kwargs)
315
+ rezero_fn = lambda t: t * self.g
316
+
317
+ if not isinstance(out, tuple):
318
+ return rezero_fn(out)
319
+
320
+ return (rezero_fn(out[0]), *out[1:])
321
+
322
+
323
+ class ScaleNorm(nn.Module):
324
+ def __init__(self, dim, eps=1e-5):
325
+ super().__init__()
326
+ self.scale = dim ** -0.5
327
+ self.eps = eps
328
+ self.g = nn.Parameter(torch.ones(1))
329
+
330
+ def forward(self, x):
331
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
332
+ return x / norm.clamp(min=self.eps) * self.g
333
+
334
+
335
+ class RMSNorm(nn.Module):
336
+ def __init__(self, dim, eps=1e-8):
337
+ super().__init__()
338
+ self.scale = dim ** -0.5
339
+ self.eps = eps
340
+ self.g = nn.Parameter(torch.ones(dim))
341
+
342
+ def forward(self, x):
343
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
344
+ return x / norm.clamp(min=self.eps) * self.g
345
+
346
+
347
+ class RMSScaleShiftNorm(nn.Module):
348
+ def __init__(self, dim, eps=1e-8):
349
+ super().__init__()
350
+ self.scale = dim ** -0.5
351
+ self.eps = eps
352
+ self.g = nn.Parameter(torch.ones(dim))
353
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
354
+
355
+ def forward(self, x, norm_scale_shift_inp):
356
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
357
+ norm = x / norm.clamp(min=self.eps) * self.g
358
+
359
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
360
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
361
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
362
+ return h
363
+
364
+
365
+ # residual and residual gates
366
+
367
+ class Residual(nn.Module):
368
+ def __init__(self, dim, scale_residual=False):
369
+ super().__init__()
370
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
371
+
372
+ def forward(self, x, residual):
373
+ if exists(self.residual_scale):
374
+ residual = residual * self.residual_scale
375
+
376
+ return x + residual
377
+
378
+
379
+ class GRUGating(nn.Module):
380
+ def __init__(self, dim, scale_residual=False):
381
+ super().__init__()
382
+ self.gru = nn.GRUCell(dim, dim)
383
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
384
+
385
+ def forward(self, x, residual):
386
+ if exists(self.residual_scale):
387
+ residual = residual * self.residual_scale
388
+
389
+ gated_output = self.gru(
390
+ rearrange(x, 'b n d -> (b n) d'),
391
+ rearrange(residual, 'b n d -> (b n) d')
392
+ )
393
+
394
+ return gated_output.reshape_as(x)
395
+
396
+
397
+ # token shifting
398
+
399
+ def shift(t, amount, mask=None):
400
+ if amount == 0:
401
+ return t
402
+
403
+ if exists(mask):
404
+ t = t.masked_fill(~mask[..., None], 0.)
405
+
406
+ return F.pad(t, (0, 0, amount, -amount), value=0.)
407
+
408
+
409
+ class ShiftTokens(nn.Module):
410
+ def __init__(self, shifts, fn):
411
+ super().__init__()
412
+ self.fn = fn
413
+ self.shifts = tuple(shifts)
414
+
415
+ def forward(self, x, **kwargs):
416
+ mask = kwargs.get('mask', None)
417
+ shifts = self.shifts
418
+ segments = len(shifts)
419
+ feats_per_shift = x.shape[-1] // segments
420
+ splitted = x.split(feats_per_shift, dim=-1)
421
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
422
+ segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
423
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
424
+ return self.fn(x, **kwargs)
425
+
426
+
427
+ # feedforward
428
+
429
+ class GLU(nn.Module):
430
+ def __init__(self, dim_in, dim_out, activation):
431
+ super().__init__()
432
+ self.act = activation
433
+ self.proj = nn.Linear(dim_in, dim_out * 2)
434
+
435
+ def forward(self, x):
436
+ x, gate = self.proj(x).chunk(2, dim=-1)
437
+ return x * self.act(gate)
438
+
439
+
440
+ class FeedForward(nn.Module):
441
+ def __init__(
442
+ self,
443
+ dim,
444
+ dim_out=None,
445
+ mult=4,
446
+ glu=False,
447
+ relu_squared=False,
448
+ post_act_ln=False,
449
+ dropout=0.,
450
+ zero_init_output=False
451
+ ):
452
+ super().__init__()
453
+ inner_dim = int(dim * mult)
454
+ dim_out = default(dim_out, dim)
455
+ activation = ReluSquared() if relu_squared else nn.GELU()
456
+
457
+ project_in = nn.Sequential(
458
+ nn.Linear(dim, inner_dim),
459
+ activation
460
+ ) if not glu else GLU(dim, inner_dim, activation)
461
+
462
+ self.net = nn.Sequential(
463
+ project_in,
464
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
465
+ nn.Dropout(dropout),
466
+ nn.Linear(inner_dim, dim_out)
467
+ )
468
+
469
+ # init last linear layer to 0
470
+ if zero_init_output:
471
+ init_zero_(self.net[-1])
472
+
473
+ def forward(self, x):
474
+ return self.net(x)
475
+
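A short, hedged usage sketch of the feed-forward block above, using the GLU variant; the sizes are illustrative:

    import torch
    from ruth_tts_transformer.models.xtransformers import FeedForward

    ff = FeedForward(dim=512, mult=4, glu=True, dropout=0.1)
    x = torch.randn(2, 16, 512)    # (batch, seq, dim)
    y = ff(x)                      # (2, 16, 512): gated projection to 2048, dropout, projection back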
476
+
477
+ # attention.
478
+
479
+ class Attention(nn.Module):
480
+ def __init__(
481
+ self,
482
+ dim,
483
+ dim_head=DEFAULT_DIM_HEAD,
484
+ heads=8,
485
+ causal=False,
486
+ talking_heads=False,
487
+ head_scale=False,
488
+ collab_heads=False,
489
+ collab_compression=.3,
490
+ sparse_topk=None,
491
+ use_entmax15=False,
492
+ num_mem_kv=0,
493
+ dropout=0.,
494
+ on_attn=False,
495
+ gate_values=False,
496
+ zero_init_output=False,
497
+ max_attend_past=None,
498
+ qk_norm=False,
499
+ scale_init_value=None,
500
+ rel_pos_bias=False,
501
+ rel_pos_num_buckets=32,
502
+ rel_pos_max_distance=128,
503
+ ):
504
+ super().__init__()
505
+ self.scale = dim_head ** -0.5
506
+
507
+ self.heads = heads
508
+ self.causal = causal
509
+ self.max_attend_past = max_attend_past
510
+
511
+ qk_dim = v_dim = dim_head * heads
512
+
513
+ # collaborative heads
514
+ self.collab_heads = collab_heads
515
+ if self.collab_heads:
516
+ qk_dim = int(collab_compression * qk_dim)
517
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
518
+
519
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
520
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
521
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
522
+
523
+ self.dropout = nn.Dropout(dropout)
524
+
525
+ # add GLU gating for aggregated values, from alphafold2
526
+ self.to_v_gate = None
527
+ if gate_values:
528
+ self.to_v_gate = nn.Linear(dim, v_dim)
529
+ nn.init.constant_(self.to_v_gate.weight, 0)
530
+ nn.init.constant_(self.to_v_gate.bias, 1)
531
+
532
+ # cosine sim attention
533
+ self.qk_norm = qk_norm
534
+ if qk_norm:
535
+ scale_init_value = default(scale_init_value,
536
+ -3) # if not provided, initialize as though it were sequence length of 1024
537
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
538
+
539
+ # talking heads
540
+ self.talking_heads = talking_heads
541
+ if talking_heads:
542
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
543
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
544
+
545
+ # head scaling
546
+ self.head_scale = head_scale
547
+ if head_scale:
548
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
549
+
550
+ # explicit topk sparse attention
551
+ self.sparse_topk = sparse_topk
552
+
553
+ # entmax
554
+ self.attn_fn = F.softmax
555
+
556
+ # add memory key / values
557
+ self.num_mem_kv = num_mem_kv
558
+ if num_mem_kv > 0:
559
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
560
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
561
+
562
+ # attention on attention
563
+ self.attn_on_attn = on_attn
564
+ self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
565
+
566
+ self.rel_pos_bias = rel_pos_bias
567
+ if rel_pos_bias:
568
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
569
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
570
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
571
+
572
+ # init output projection 0
573
+ if zero_init_output:
574
+ init_zero_(self.to_out)
575
+
576
+ def forward(
577
+ self,
578
+ x,
579
+ context=None,
580
+ mask=None,
581
+ context_mask=None,
582
+ attn_mask=None,
583
+ sinusoidal_emb=None,
584
+ rotary_pos_emb=None,
585
+ prev_attn=None,
586
+ mem=None,
587
+ layer_past=None,
588
+ ):
589
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
590
+ context)
591
+ kv_input = default(context, x)
592
+
593
+ q_input = x
594
+ k_input = kv_input
595
+ v_input = kv_input
596
+
597
+ if exists(mem):
598
+ k_input = torch.cat((mem, k_input), dim=-2)
599
+ v_input = torch.cat((mem, v_input), dim=-2)
600
+
601
+ if exists(sinusoidal_emb):
602
+ # in shortformer, the query would start at a position offset depending on the past cached memory
603
+ offset = k_input.shape[-2] - q_input.shape[-2]
604
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
605
+ k_input = k_input + sinusoidal_emb(k_input)
606
+
607
+ q = self.to_q(q_input)
608
+ k = self.to_k(k_input)
609
+ v = self.to_v(v_input)
610
+
611
+ if not collab_heads:
612
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
613
+ else:
614
+ q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
615
+ k = rearrange(k, 'b n d -> b () n d')
616
+ v = rearrange(v, 'b n (h d) -> b h n d', h=h)
617
+
618
+ if layer_past is not None:
619
+ past_key, past_value = layer_past
620
+ k = torch.cat([past_key, k], dim=-2)
621
+ v = torch.cat([past_value, v], dim=-2)
622
+ k_cache = k
623
+ v_cache = v
624
+
625
+ if exists(rotary_pos_emb) and not has_context:
626
+ l = rotary_pos_emb.shape[-1]
627
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
628
+ ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
629
+ q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
630
+
631
+ input_mask = None
632
+ if any(map(exists, (mask, context_mask))):
633
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
634
+ k_mask = q_mask if not exists(context) else context_mask
635
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
636
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
637
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
638
+ input_mask = q_mask * k_mask
639
+
640
+ if self.num_mem_kv > 0:
641
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
642
+ k = torch.cat((mem_k, k), dim=-2)
643
+ v = torch.cat((mem_v, v), dim=-2)
644
+ if exists(input_mask):
645
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
646
+
647
+ if collab_heads:
648
+ k = k.expand(-1, h, -1, -1)
649
+
650
+ if self.qk_norm:
651
+ q, k = map(l2norm, (q, k))
652
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
653
+
654
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
655
+ mask_value = max_neg_value(dots)
656
+
657
+ if exists(prev_attn):
658
+ dots = dots + prev_attn
659
+
660
+ pre_softmax_attn = dots.clone()
661
+
662
+ if talking_heads:
663
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
664
+
665
+ if self.rel_pos_bias:
666
+ dots = self.rel_pos(dots)
667
+
668
+ if exists(input_mask):
669
+ dots.masked_fill_(~input_mask, mask_value)
670
+ del input_mask
671
+
672
+ if exists(attn_mask):
673
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
674
+ if attn_mask.ndim == 2:
675
+ attn_mask = rearrange(attn_mask, 'i j -> () () i j')
676
+ elif attn_mask.ndim == 3:
677
+ attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
678
+ dots.masked_fill_(~attn_mask, mask_value)
679
+
680
+ if exists(self.max_attend_past):
681
+ i, j = dots.shape[-2:]
682
+ range_q = torch.arange(j - i, j, device=device)
683
+ range_k = torch.arange(j, device=device)
684
+ dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
685
+ mask = dist > self.max_attend_past
686
+ dots.masked_fill_(mask, mask_value)
687
+ del mask
688
+
689
+ if self.causal:
690
+ i, j = dots.shape[-2:]
691
+ r = torch.arange(i, device=device)
692
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
693
+ mask = F.pad(mask, (j - i, 0), value=False)
694
+ dots.masked_fill_(mask, mask_value)
695
+ del mask
696
+
697
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
698
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
699
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
700
+ mask = dots < vk
701
+ dots.masked_fill_(mask, mask_value)
702
+ del mask
703
+
704
+ attn = self.attn_fn(dots, dim=-1)
705
+ post_softmax_attn = attn.clone()
706
+
707
+ attn = self.dropout(attn)
708
+
709
+ if talking_heads:
710
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
711
+
712
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
713
+
714
+ if head_scale:
715
+ out = out * self.head_scale_params
716
+
717
+ out = rearrange(out, 'b h n d -> b n (h d)')
718
+
719
+ if exists(self.to_v_gate):
720
+ gates = self.to_v_gate(x)
721
+ out = out * gates.sigmoid()
722
+
723
+ intermediates = Intermediates(
724
+ pre_softmax_attn=pre_softmax_attn,
725
+ post_softmax_attn=post_softmax_attn
726
+ )
727
+
728
+ return self.to_out(out), intermediates, k_cache, v_cache
729
+
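A hedged standalone sketch of the attention block above. The forward pass returns four values: the projected output, the attention intermediates, and the key/value tensors that callers can feed back in as `layer_past` for incremental decoding; the sizes are illustrative:

    import torch
    from ruth_tts_transformer.models.xtransformers import Attention

    attn = Attention(dim=512, heads=8, causal=True)
    x = torch.randn(2, 16, 512)
    out, inter, k, v = attn(x)                         # out: (2, 16, 512); k, v: (2, 8, 16, 64)
    step = torch.randn(2, 1, 512)                      # one new position
    out2, _, k2, v2 = attn(step, layer_past=(k, v))    # attends over the 16 cached positions plus itself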
730
+
731
+ class AttentionLayers(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim,
735
+ depth,
736
+ heads=8,
737
+ causal=False,
738
+ cross_attend=False,
739
+ only_cross=False,
740
+ use_scalenorm=False,
741
+ use_rms_scaleshift_norm=False,
742
+ use_rmsnorm=False,
743
+ use_rezero=False,
744
+ alibi_pos_bias=False,
745
+ alibi_num_heads=None,
746
+ alibi_learned=False,
747
+ position_infused_attn=False,
748
+ rotary_pos_emb=False,
749
+ rotary_emb_dim=None,
750
+ custom_layers=None,
751
+ sandwich_coef=None,
752
+ par_ratio=None,
753
+ residual_attn=False,
754
+ cross_residual_attn=False,
755
+ macaron=False,
756
+ pre_norm=True,
757
+ gate_residual=False,
758
+ scale_residual=False,
759
+ shift_tokens=0,
760
+ sandwich_norm=False,
761
+ use_qk_norm_attn=False,
762
+ qk_norm_attn_seq_len=None,
763
+ zero_init_branch_output=False,
764
+ **kwargs
765
+ ):
766
+ super().__init__()
767
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
768
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
769
+
770
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
771
+
772
+ self.dim = dim
773
+ self.depth = depth
774
+ self.layers = nn.ModuleList([])
775
+ self.causal = causal
776
+
777
+ rel_pos_bias = 'rel_pos_bias' in attn_kwargs
778
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
779
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
780
+
781
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
782
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
783
+
784
+ assert not (
785
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
786
+
787
+ if alibi_pos_bias:
788
+ alibi_num_heads = default(alibi_num_heads, heads)
789
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must not exceed the total number of heads'
790
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
791
+ self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
792
+ else:
793
+ self.rel_pos = None
794
+
795
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
796
+ self.pre_norm = pre_norm
797
+ self.sandwich_norm = sandwich_norm
798
+
799
+ self.residual_attn = residual_attn
800
+ self.cross_residual_attn = cross_residual_attn
801
+ self.cross_attend = cross_attend
802
+
803
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
804
+ norm_class = RMSNorm if use_rmsnorm else norm_class
805
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
806
+ norm_fn = partial(norm_class, dim)
807
+
808
+ norm_fn = nn.Identity if use_rezero else norm_fn
809
+ branch_fn = Rezero if use_rezero else None
810
+
811
+ if cross_attend and not only_cross:
812
+ default_block = ('a', 'c', 'f')
813
+ elif cross_attend and only_cross:
814
+ default_block = ('c', 'f')
815
+ else:
816
+ default_block = ('a', 'f')
817
+
818
+ if macaron:
819
+ default_block = ('f',) + default_block
820
+
821
+ # qk normalization
822
+
823
+ if use_qk_norm_attn:
824
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
825
+ qk_norm_attn_seq_len) else None
826
+ attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
827
+
828
+ # zero init
829
+
830
+ if zero_init_branch_output:
831
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
832
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
833
+
834
+ # calculate layer block order
835
+
836
+ if exists(custom_layers):
837
+ layer_types = custom_layers
838
+ elif exists(par_ratio):
839
+ par_depth = depth * len(default_block)
840
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
841
+ default_block = tuple(filter(not_equals('f'), default_block))
842
+ par_attn = par_depth // par_ratio
843
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
844
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
845
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
846
+ par_block = default_block + ('f',) * (par_width - len(default_block))
847
+ par_head = par_block * par_attn
848
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
849
+ elif exists(sandwich_coef):
850
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
851
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
852
+ else:
853
+ layer_types = default_block * depth
854
+
855
+ self.layer_types = layer_types
856
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
857
+
858
+ # calculate token shifting
859
+
860
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
861
+
862
+ # iterate and construct layers
863
+
864
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
865
+ is_last_layer = ind == (len(self.layer_types) - 1)
866
+
867
+ if layer_type == 'a':
868
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
869
+ elif layer_type == 'c':
870
+ layer = Attention(dim, heads=heads, **attn_kwargs)
871
+ elif layer_type == 'f':
872
+ layer = FeedForward(dim, **ff_kwargs)
873
+ layer = layer if not macaron else Scale(0.5, layer)
874
+ else:
875
+ raise Exception(f'invalid layer type {layer_type}')
876
+
877
+ if layer_shift_tokens > 0:
878
+ shift_range_upper = layer_shift_tokens + 1
879
+ shift_range_lower = -layer_shift_tokens if not causal else 0
880
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
881
+
882
+ if exists(branch_fn):
883
+ layer = branch_fn(layer)
884
+
885
+ residual_fn = GRUGating if gate_residual else Residual
886
+ residual = residual_fn(dim, scale_residual=scale_residual)
887
+
888
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
889
+
890
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
891
+ post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
892
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
893
+
894
+ norms = nn.ModuleList([
895
+ pre_branch_norm,
896
+ post_branch_norm,
897
+ post_main_norm
898
+ ])
899
+
900
+ self.layers.append(nn.ModuleList([
901
+ norms,
902
+ layer,
903
+ residual
904
+ ]))
905
+
906
+ def forward(
907
+ self,
908
+ x,
909
+ context=None,
910
+ full_context=None, # for passing a list of hidden states from an encoder
911
+ mask=None,
912
+ context_mask=None,
913
+ attn_mask=None,
914
+ mems=None,
915
+ return_hiddens=False,
916
+ norm_scale_shift_inp=None,
917
+ past_key_values=None,
918
+ expected_seq_len=None,
919
+ ):
920
+
921
+ assert not (self.cross_attend ^ (exists(context) or exists(
922
+ full_context))), 'context must be passed in if cross_attend is set to True'
923
+ assert context is None or full_context is None, 'only one of full_context or context can be provided'
924
+
925
+ hiddens = []
926
+ intermediates = []
927
+ prev_attn = None
928
+ prev_cross_attn = None
929
+
930
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
931
+ norm_args = {}
932
+ if exists(norm_scale_shift_inp):
933
+ norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
934
+
935
+ rotary_pos_emb = None
936
+ if exists(self.rotary_pos_emb):
937
+ if not self.training and self.causal:
938
+ assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
939
+ elif expected_seq_len is None:
940
+ expected_seq_len = 0
941
+ seq_len = x.shape[1]
942
+ if past_key_values is not None:
943
+ seq_len += past_key_values[0][0].shape[-2]
944
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
945
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
946
+
947
+ present_key_values = []
948
+ cross_attn_count = 0
949
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
950
+ if layer_type == 'a':
951
+ layer_mem = mems.pop(0) if mems else None
952
+
953
+ residual = x
954
+
955
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
956
+
957
+ if exists(pre_branch_norm):
958
+ x = pre_branch_norm(x, **norm_args)
959
+
960
+ if layer_type == 'a' or layer_type == 'c':
961
+ if past_key_values is not None:
962
+ layer_kv = past_key_values.pop(0)
963
+ layer_past = tuple(s.to(x.device) for s in layer_kv)
964
+ else:
965
+ layer_past = None
966
+
967
+ if layer_type == 'a':
968
+ out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
969
+ prev_attn, layer_mem, layer_past)
970
+ elif layer_type == 'c':
971
+ if exists(full_context):
972
+ out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
973
+ None, prev_attn, None, layer_past)
974
+ else:
975
+ out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
976
+ elif layer_type == 'f':
977
+ out = block(x)
978
+
979
+ if layer_type in ('a', 'c') and present_key_values is not None:
980
+ present_key_values.append((k.detach(), v.detach()))
981
+
982
+ if exists(post_branch_norm):
983
+ out = post_branch_norm(out, **norm_args)
984
+
985
+ x = residual_fn(out, residual)
986
+
987
+ if layer_type in ('a', 'c'):
988
+ intermediates.append(inter)
989
+
990
+ if layer_type == 'a' and self.residual_attn:
991
+ prev_attn = inter.pre_softmax_attn
992
+ elif layer_type == 'c' and self.cross_residual_attn:
993
+ prev_cross_attn = inter.pre_softmax_attn
994
+
995
+ if exists(post_main_norm):
996
+ x = post_main_norm(x, **norm_args)
997
+
998
+ if layer_type == 'c':
999
+ cross_attn_count += 1
1000
+
1001
+ if layer_type == 'f':
1002
+ hiddens.append(x)
1003
+
1004
+ if return_hiddens:
1005
+ intermediates = LayerIntermediates(
1006
+ hiddens=hiddens,
1007
+ attn_intermediates=intermediates,
1008
+ past_key_values=present_key_values
1009
+ )
1010
+
1011
+ return x, intermediates
1012
+
1013
+ return x
1014
+
1015
+
1016
+ class Encoder(AttentionLayers):
1017
+ def __init__(self, **kwargs):
1018
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
1019
+ super().__init__(causal=False, **kwargs)
1020
+
1021
+
1022
+ class Decoder(AttentionLayers):
1023
+ def __init__(self, **kwargs):
1024
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1025
+ super().__init__(causal=True, **kwargs)
1026
+
1027
+
1028
+ class CrossAttender(AttentionLayers):
1029
+ def __init__(self, **kwargs):
1030
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1031
+
1032
+
1033
+ class ViTransformerWrapper(nn.Module):
1034
+ def __init__(
1035
+ self,
1036
+ *,
1037
+ image_size,
1038
+ patch_size,
1039
+ attn_layers,
1040
+ num_classes=None,
1041
+ dropout=0.,
1042
+ emb_dropout=0.
1043
+ ):
1044
+ super().__init__()
1045
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1046
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1047
+ dim = attn_layers.dim
1048
+ num_patches = (image_size // patch_size) ** 2
1049
+ patch_dim = 3 * patch_size ** 2
1050
+
1051
+ self.patch_size = patch_size
1052
+
1053
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1054
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1055
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1056
+ self.dropout = nn.Dropout(emb_dropout)
1057
+
1058
+ self.attn_layers = attn_layers
1059
+ self.norm = nn.LayerNorm(dim)
1060
+ self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
1061
+
1062
+ def forward(
1063
+ self,
1064
+ img,
1065
+ return_embeddings=False
1066
+ ):
1067
+ p = self.patch_size
1068
+
1069
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
1070
+ x = self.patch_to_embedding(x)
1071
+ b, n, _ = x.shape
1072
+
1073
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
1074
+ x = torch.cat((cls_tokens, x), dim=1)
1075
+ x = x + self.pos_embedding[:, :(n + 1)]
1076
+ x = self.dropout(x)
1077
+
1078
+ x = self.attn_layers(x)
1079
+ x = self.norm(x)
1080
+
1081
+ if not exists(self.mlp_head) or return_embeddings:
1082
+ return x
1083
+
1084
+ return self.mlp_head(x[:, 0])
1085
+
1086
+
1087
+ class TransformerWrapper(nn.Module):
1088
+ def __init__(
1089
+ self,
1090
+ *,
1091
+ num_tokens,
1092
+ max_seq_len,
1093
+ attn_layers,
1094
+ emb_dim=None,
1095
+ max_mem_len=0.,
1096
+ shift_mem_down=0,
1097
+ emb_dropout=0.,
1098
+ num_memory_tokens=None,
1099
+ tie_embedding=False,
1100
+ use_pos_emb=True
1101
+ ):
1102
+ super().__init__()
1103
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1104
+
1105
+ dim = attn_layers.dim
1106
+ emb_dim = default(emb_dim, dim)
1107
+
1108
+ self.max_seq_len = max_seq_len
1109
+ self.max_mem_len = max_mem_len
1110
+ self.shift_mem_down = shift_mem_down
1111
+
1112
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1113
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
1114
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1115
+ self.emb_dropout = nn.Dropout(emb_dropout)
1116
+
1117
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1118
+ self.attn_layers = attn_layers
1119
+ self.norm = nn.LayerNorm(dim)
1120
+
1121
+ self.init_()
1122
+
1123
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1124
+
1125
+ # memory tokens (like [cls]) from Memory Transformers paper
1126
+ num_memory_tokens = default(num_memory_tokens, 0)
1127
+ self.num_memory_tokens = num_memory_tokens
1128
+ if num_memory_tokens > 0:
1129
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1130
+
1131
+ def init_(self):
1132
+ nn.init.kaiming_normal_(self.token_emb.weight)
1133
+
1134
+ def forward(
1135
+ self,
1136
+ x,
1137
+ return_embeddings=False,
1138
+ mask=None,
1139
+ return_hiddens=False,
1140
+ return_attn=False,
1141
+ mems=None,
1142
+ use_cache=False,
1143
+ **kwargs
1144
+ ):
1145
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1146
+ x = self.token_emb(x)
1147
+ x = x + self.pos_emb(x)
1148
+ x = self.emb_dropout(x)
1149
+
1150
+ x = self.project_emb(x)
1151
+
1152
+ if num_mem > 0:
1153
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
1154
+ x = torch.cat((mem, x), dim=1)
1155
+
1156
+ # auto-handle masking after appending memory tokens
1157
+ if exists(mask):
1158
+ mask = F.pad(mask, (num_mem, 0), value=True)
1159
+
1160
+ if self.shift_mem_down and exists(mems):
1161
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1162
+ mems = [*mems_r, *mems_l]
1163
+
1164
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1165
+ x = self.norm(x)
1166
+
1167
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1168
+
1169
+ out = self.to_logits(x) if not return_embeddings else x
1170
+
1171
+ if return_hiddens:
1172
+ hiddens = intermediates.hiddens
1173
+ return out, hiddens
1174
+
1175
+ res = [out]
1176
+ if return_attn:
1177
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1178
+ res.append(attn_maps)
1179
+ if use_cache:
1180
+ res.append(intermediates.past_key_values)
1181
+
1182
+ if len(res) > 1:
1183
+ return tuple(res)
1184
+ return res[0]
1185
+
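A hedged end-to-end sketch of the wrapper above combined with a `Decoder`; the vocabulary size and model dimensions here are illustrative, not the repository's actual configuration:

    import torch
    from ruth_tts_transformer.models.xtransformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens=256,
        max_seq_len=1024,
        attn_layers=Decoder(dim=512, depth=4, heads=8),
    )
    tokens = torch.randint(0, 256, (1, 128))
    logits = model(tokens)                                   # (1, 128, 256)
    logits, past_key_values = model(tokens, use_cache=True)  # also returns per-layer (k, v) caches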
1186
+
1187
+ class ContinuousTransformerWrapper(nn.Module):
1188
+ def __init__(
1189
+ self,
1190
+ *,
1191
+ max_seq_len,
1192
+ attn_layers,
1193
+ dim_in=None,
1194
+ dim_out=None,
1195
+ emb_dim=None,
1196
+ emb_dropout=0.,
1197
+ use_pos_emb=True
1198
+ ):
1199
+ super().__init__()
1200
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1201
+
1202
+ dim = attn_layers.dim
1203
+
1204
+ self.max_seq_len = max_seq_len
1205
+
1206
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
1207
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1208
+ self.emb_dropout = nn.Dropout(emb_dropout)
1209
+
1210
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1211
+
1212
+ self.attn_layers = attn_layers
1213
+ self.norm = nn.LayerNorm(dim)
1214
+
1215
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ x,
1220
+ return_embeddings=False,
1221
+ mask=None,
1222
+ return_attn=False,
1223
+ mems=None,
1224
+ use_cache=False,
1225
+ **kwargs
1226
+ ):
1227
+ b, n, _, device = *x.shape, x.device
1228
+
1229
+ x = self.project_in(x)
1230
+ x = x + self.pos_emb(x)
1231
+ x = self.emb_dropout(x)
1232
+
1233
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1234
+ x = self.norm(x)
1235
+
1236
+ out = self.project_out(x) if not return_embeddings else x
1237
+
1238
+ res = [out]
1239
+ if return_attn:
1240
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1241
+ res.append(attn_maps)
1242
+ if use_cache:
1243
+ res.append(intermediates.past_key_values)
1244
+
1245
+ if len(res) > 1:
1246
+ return tuple(res)
1247
+ return res[0]
1248
+
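Finally, a hedged sketch of the continuous variant above, which replaces the token embedding with linear projections so real-valued frames (for example mel or latent frames) can be fed in directly; the feature sizes are illustrative:

    import torch
    from ruth_tts_transformer.models.xtransformers import ContinuousTransformerWrapper, Encoder

    model = ContinuousTransformerWrapper(
        max_seq_len=400,
        attn_layers=Encoder(dim=256, depth=2, heads=4),
        dim_in=80,
        dim_out=80,
    )
    frames = torch.randn(2, 100, 80)    # (batch, frames, feature_dim)
    out = model(frames)                 # (2, 100, 80)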
ruth_tts_transformer/utils/__init__.py ADDED
File without changes