prakashr7d
commited on
Commit
•
b0bf39f
1
Parent(s):
9d1ae0d
written handler
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- __pycache__/constants.cpython-310.pyc +0 -0
- __pycache__/constants.cpython-38.pyc +0 -0
- __pycache__/handler.cpython-38.pyc +0 -0
- __pycache__/serve.cpython-310.pyc +0 -0
- __pycache__/serve.cpython-38.pyc +0 -0
- __pycache__/server.cpython-38.pyc +0 -0
- __pycache__/try.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- config-model.yaml +12 -0
- constants.py +12 -0
- handler.py +479 -0
- requirements.txt +18 -0
- ruth_tts_transformer/.gitignore +1 -0
- ruth_tts_transformer/__init__.py +2 -0
- ruth_tts_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- ruth_tts_transformer/__pycache__/__init__.cpython-37.pyc +0 -0
- ruth_tts_transformer/__pycache__/__init__.cpython-38.pyc +0 -0
- ruth_tts_transformer/__pycache__/__init__.cpython-39.pyc +0 -0
- ruth_tts_transformer/data/latents.pkl +0 -0
- ruth_tts_transformer/data/layman.txt +0 -0
- ruth_tts_transformer/data/mel_norms.pth +0 -0
- ruth_tts_transformer/data/riding_hood.txt +54 -0
- ruth_tts_transformer/data/seal_copypasta.txt +1 -0
- ruth_tts_transformer/data/tokenizer.json +1 -0
- ruth_tts_transformer/models/__init__.py +0 -0
- ruth_tts_transformer/models/__pycache__/__init__.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/__init__.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/arch_util.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/arch_util.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/autoregressive.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/autoregressive.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/clvp.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/clvp.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/transformer.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/transformer.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/vocoder.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/vocoder.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/xtransformers.cpython-310.pyc +0 -0
- ruth_tts_transformer/models/__pycache__/xtransformers.cpython-38.pyc +0 -0
- ruth_tts_transformer/models/arch_util.py +371 -0
- ruth_tts_transformer/models/autoregressive.py +528 -0
- ruth_tts_transformer/models/clvp.py +155 -0
- ruth_tts_transformer/models/diffusion_decoder.py +349 -0
- ruth_tts_transformer/models/transformer.py +221 -0
- ruth_tts_transformer/models/vocoder.py +323 -0
- ruth_tts_transformer/models/xtransformers.py +1248 -0
- ruth_tts_transformer/utils/__init__.py +0 -0
__pycache__/constants.cpython-310.pyc
ADDED
Binary file (538 Bytes). View file
|
|
__pycache__/constants.cpython-38.pyc
ADDED
Binary file (530 Bytes). View file
|
|
__pycache__/handler.cpython-38.pyc
ADDED
Binary file (13.2 kB). View file
|
|
__pycache__/serve.cpython-310.pyc
ADDED
Binary file (14.5 kB). View file
|
|
__pycache__/serve.cpython-38.pyc
ADDED
Binary file (13.6 kB). View file
|
|
__pycache__/server.cpython-38.pyc
ADDED
Binary file (13.1 kB). View file
|
|
__pycache__/try.cpython-310.pyc
ADDED
Binary file (1.17 kB). View file
|
|
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.7 kB). View file
|
|
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.66 kB). View file
|
|
config-model.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gpt:
|
2 |
+
num_autoregressive_samples: 16
|
3 |
+
top_p: 0.8
|
4 |
+
temperature: 0.8
|
5 |
+
length_penalty: 1
|
6 |
+
max_mel_tokens: 500
|
7 |
+
repetition_penalty: 2.0
|
8 |
+
autoregressive_batch_size: 16
|
9 |
+
clvp:
|
10 |
+
k: 1
|
11 |
+
diffusion:
|
12 |
+
diffusion_temperature: 1.0
|
constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NUM_AUTOREGRESSIVE_SAMPLES = "num_autoregressive_samples"
|
2 |
+
TOP_P = "top_p"
|
3 |
+
TEMPERATURE = "temperature"
|
4 |
+
LENGTH_PENALTY = "length_penalty"
|
5 |
+
REPETITION_PENALTY = "repetition_penalty"
|
6 |
+
MAX_MEL_TOKENS = "max_mel_tokens"
|
7 |
+
AUTO_REGRESSIVE_BATCH_SIZE = "autoregressive_batch_size"
|
8 |
+
DIFFUSION_TEMPERATURE = "diffusion_temperature"
|
9 |
+
# MODELS
|
10 |
+
GPT = "gpt"
|
11 |
+
CLVP_const = "clvp"
|
12 |
+
DIFFUSION = "diffusion"
|
handler.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import hashlib
|
3 |
+
from io import BytesIO
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
from copy import copy
|
9 |
+
from datetime import datetime
|
10 |
+
from fastapi import FastAPI
|
11 |
+
from fastapi.responses import FileResponse
|
12 |
+
from pathlib import Path
|
13 |
+
from pydantic import BaseModel
|
14 |
+
|
15 |
+
from time import time
|
16 |
+
from typing import Any, Dict, List, Text, Tuple
|
17 |
+
|
18 |
+
from constants import (
|
19 |
+
AUTO_REGRESSIVE_BATCH_SIZE,
|
20 |
+
DIFFUSION,
|
21 |
+
DIFFUSION_TEMPERATURE,
|
22 |
+
GPT,
|
23 |
+
LENGTH_PENALTY,
|
24 |
+
MAX_MEL_TOKENS,
|
25 |
+
NUM_AUTOREGRESSIVE_SAMPLES,
|
26 |
+
REPETITION_PENALTY,
|
27 |
+
TEMPERATURE,
|
28 |
+
TOP_P,
|
29 |
+
CLVP_const,
|
30 |
+
)
|
31 |
+
from ruth_tts_transformer.models.autoregressive import UnifiedVoice
|
32 |
+
from ruth_tts_transformer.models.clvp import CLVP
|
33 |
+
from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts
|
34 |
+
from ruth_tts_transformer.models.vocoder import UnivNetGenerator
|
35 |
+
from ruth_tts_transformer.utils.audio import load_voice
|
36 |
+
from ruth_tts_transformer.utils.tokenizer import VoiceBpeTokenizer
|
37 |
+
from ruth_tts_transformer.utils.wav2vec_alignment import Wav2VecAlignment
|
38 |
+
from utils import (
|
39 |
+
MODELS_DIR,
|
40 |
+
get_config_file,
|
41 |
+
get_model_path,
|
42 |
+
load_discrete_vocoder_diffuser,
|
43 |
+
)
|
44 |
+
|
45 |
+
app = FastAPI()
|
46 |
+
|
47 |
+
|
48 |
+
class Item(BaseModel):
|
49 |
+
text: str
|
50 |
+
voice: str
|
51 |
+
seed: int = 3
|
52 |
+
|
53 |
+
|
54 |
+
class Gpt:
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
num_autoregressive_samples: int,
|
58 |
+
top_p: float,
|
59 |
+
temperature: float,
|
60 |
+
length_penalty: int,
|
61 |
+
repetition_penalty: float,
|
62 |
+
max_mel_tokens: int,
|
63 |
+
autoregressive_batch_size: int,
|
64 |
+
):
|
65 |
+
self.num_autoregressive_samples = num_autoregressive_samples
|
66 |
+
self.top_p = top_p
|
67 |
+
self.temperature = temperature
|
68 |
+
self.length_penalty = length_penalty
|
69 |
+
self.repetition_penalty = repetition_penalty
|
70 |
+
self.max_mel_tokens = max_mel_tokens
|
71 |
+
self.autoregressive_batch_size = autoregressive_batch_size
|
72 |
+
self.gpt = (
|
73 |
+
UnifiedVoice(
|
74 |
+
max_mel_tokens=604,
|
75 |
+
max_text_tokens=402,
|
76 |
+
max_conditioning_inputs=2,
|
77 |
+
layers=30,
|
78 |
+
model_dim=1024,
|
79 |
+
heads=16,
|
80 |
+
number_text_tokens=255,
|
81 |
+
start_text_token=255,
|
82 |
+
checkpointing=False,
|
83 |
+
train_solo_embeddings=False,
|
84 |
+
)
|
85 |
+
.cpu()
|
86 |
+
.eval()
|
87 |
+
)
|
88 |
+
self.gpt.load_state_dict(
|
89 |
+
torch.load(get_model_path("autoregressive.pth", MODELS_DIR))
|
90 |
+
)
|
91 |
+
self.gpt = self.gpt.to("cuda")
|
92 |
+
|
93 |
+
def __num_batches(self):
|
94 |
+
return self.num_autoregressive_samples // self.autoregressive_batch_size
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def deterministic_state(seed=None):
|
98 |
+
seed = int(time()) if seed is None else seed
|
99 |
+
torch.manual_seed(seed)
|
100 |
+
random.seed(seed)
|
101 |
+
return seed
|
102 |
+
|
103 |
+
def parse(self, auto_conditioning, text_tokens, best_results, seed, k=1):
|
104 |
+
self.deterministic_state(seed=seed)
|
105 |
+
auto_conditioning = copy(auto_conditioning).to("cuda")
|
106 |
+
text_tokens = copy(text_tokens).to("cuda")
|
107 |
+
best_results = copy(best_results).to("cuda")
|
108 |
+
best_latents = self.gpt(
|
109 |
+
auto_conditioning.repeat(k, 1),
|
110 |
+
text_tokens.repeat(k, 1),
|
111 |
+
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
|
112 |
+
best_results,
|
113 |
+
torch.tensor(
|
114 |
+
[best_results.shape[-1] * self.gpt.mel_length_compression],
|
115 |
+
device=text_tokens.device,
|
116 |
+
),
|
117 |
+
return_latent=True,
|
118 |
+
clip_inputs=False,
|
119 |
+
)
|
120 |
+
# return best_latents.cpu().detach().numpy()
|
121 |
+
return best_latents
|
122 |
+
|
123 |
+
def parse_inference(
|
124 |
+
self, auto_conditioning: torch.Tensor, text_tokens: torch.Tensor, seed
|
125 |
+
) -> Tuple[List[torch.Tensor], int]:
|
126 |
+
self.deterministic_state(seed=seed)
|
127 |
+
auto_conditioning = copy(auto_conditioning).to("cuda")
|
128 |
+
text_tokens = copy(text_tokens).to("cuda")
|
129 |
+
with torch.no_grad():
|
130 |
+
samples = []
|
131 |
+
num_batches = self.__num_batches()
|
132 |
+
for b in range(num_batches):
|
133 |
+
codes = self.gpt.inference_speech(
|
134 |
+
auto_conditioning,
|
135 |
+
text_tokens,
|
136 |
+
do_sample=True,
|
137 |
+
top_p=self.top_p,
|
138 |
+
temperature=self.temperature,
|
139 |
+
num_return_sequences=self.autoregressive_batch_size,
|
140 |
+
length_penalty=self.length_penalty,
|
141 |
+
repetition_penalty=self.repetition_penalty,
|
142 |
+
max_generate_length=self.max_mel_tokens,
|
143 |
+
)
|
144 |
+
padding_needed = self.max_mel_tokens - codes.shape[1]
|
145 |
+
codes = F.pad(codes, (0, padding_needed), value=self.gpt.stop_mel_token)
|
146 |
+
# samples.append(codes.cpu().detach().numpy())
|
147 |
+
samples.append(codes)
|
148 |
+
|
149 |
+
return samples, self.gpt.stop_mel_token
|
150 |
+
|
151 |
+
|
152 |
+
class clvp:
|
153 |
+
def __init__(self, K):
|
154 |
+
|
155 |
+
self.clvp = (
|
156 |
+
CLVP(
|
157 |
+
dim_text=768,
|
158 |
+
dim_speech=768,
|
159 |
+
dim_latent=768,
|
160 |
+
num_text_tokens=256,
|
161 |
+
text_enc_depth=20,
|
162 |
+
text_seq_len=350,
|
163 |
+
text_heads=12,
|
164 |
+
num_speech_tokens=8192,
|
165 |
+
speech_enc_depth=20,
|
166 |
+
speech_heads=12,
|
167 |
+
speech_seq_len=430,
|
168 |
+
use_xformers=True,
|
169 |
+
)
|
170 |
+
.cpu()
|
171 |
+
.eval()
|
172 |
+
)
|
173 |
+
self.clvp.load_state_dict(torch.load(get_model_path("clvp2.pth", MODELS_DIR)))
|
174 |
+
self.clvp.to("cuda")
|
175 |
+
self.K = K
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def fix_gpt_output(codes, stop_token, complain=True):
|
179 |
+
stop_token_indices = (codes == stop_token).nonzero()
|
180 |
+
if len(stop_token_indices) == 0:
|
181 |
+
if complain:
|
182 |
+
print(
|
183 |
+
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio "
|
184 |
+
"is "
|
185 |
+
"too long. In some cases, the output will still be good, though. Listen to it and if it is "
|
186 |
+
"missing words, "
|
187 |
+
"try breaking up your input text."
|
188 |
+
)
|
189 |
+
return codes
|
190 |
+
else:
|
191 |
+
codes[stop_token_indices] = 83
|
192 |
+
stm = stop_token_indices.min().item()
|
193 |
+
codes[stm:] = 83
|
194 |
+
if stm - 3 < codes.shape[0]:
|
195 |
+
codes[-3] = 45
|
196 |
+
codes[-2] = 45
|
197 |
+
codes[-1] = 248
|
198 |
+
|
199 |
+
return codes
|
200 |
+
|
201 |
+
def parse(
|
202 |
+
self,
|
203 |
+
text_tokens: torch.Tensor,
|
204 |
+
samples: List[torch.Tensor],
|
205 |
+
stop_mel_token: int,
|
206 |
+
seed: int,
|
207 |
+
) -> torch.Tensor:
|
208 |
+
self.deterministic_state(seed=seed)
|
209 |
+
clip_results = []
|
210 |
+
text_tokens = copy(text_tokens).to("cuda")
|
211 |
+
samples = [copy(batch).to("cuda") for batch in samples]
|
212 |
+
for batch in samples:
|
213 |
+
for i in range(batch.shape[0]):
|
214 |
+
batch[i] = self.fix_gpt_output(batch[i], stop_mel_token)
|
215 |
+
|
216 |
+
clvp = self.clvp(
|
217 |
+
text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
|
218 |
+
)
|
219 |
+
clip_results.append(clvp)
|
220 |
+
|
221 |
+
clip_results = torch.cat(clip_results, dim=0)
|
222 |
+
samples = torch.cat(samples, dim=0)
|
223 |
+
# return samples[torch.topk(clip_results, self.K).indices].cpu().detach().numpy()
|
224 |
+
return samples[torch.topk(clip_results, self.K).indices]
|
225 |
+
|
226 |
+
@staticmethod
|
227 |
+
def deterministic_state(seed=None):
|
228 |
+
seed = int(time()) if seed is None else seed
|
229 |
+
torch.manual_seed(seed)
|
230 |
+
random.seed(seed)
|
231 |
+
return seed
|
232 |
+
|
233 |
+
|
234 |
+
class Diffusion:
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
diffusion_temperature,
|
238 |
+
diffusion_iterations=30,
|
239 |
+
cond_free=True,
|
240 |
+
cond_free_k=2,
|
241 |
+
):
|
242 |
+
self.diffusion_temperature = diffusion_temperature
|
243 |
+
self.diffusion = (
|
244 |
+
DiffusionTts(
|
245 |
+
model_channels=1024,
|
246 |
+
num_layers=10,
|
247 |
+
in_channels=100,
|
248 |
+
out_channels=200,
|
249 |
+
in_latent_channels=1024,
|
250 |
+
in_tokens=8193,
|
251 |
+
dropout=0,
|
252 |
+
use_fp16=False,
|
253 |
+
num_heads=16,
|
254 |
+
layer_drop=0,
|
255 |
+
unconditioned_percentage=0,
|
256 |
+
)
|
257 |
+
.cpu()
|
258 |
+
.eval()
|
259 |
+
)
|
260 |
+
self.diffusion.load_state_dict(
|
261 |
+
torch.load(get_model_path("diffusion_decoder.pth", MODELS_DIR))
|
262 |
+
)
|
263 |
+
self.diffuser = load_discrete_vocoder_diffuser(
|
264 |
+
desired_diffusion_steps=diffusion_iterations,
|
265 |
+
cond_free=cond_free,
|
266 |
+
cond_free_k=cond_free_k,
|
267 |
+
)
|
268 |
+
|
269 |
+
self.vocoder = UnivNetGenerator().cpu()
|
270 |
+
self.vocoder.load_state_dict(
|
271 |
+
torch.load(
|
272 |
+
get_model_path("vocoder.pth", MODELS_DIR),
|
273 |
+
map_location=torch.device("cpu"),
|
274 |
+
)["model_g"]
|
275 |
+
)
|
276 |
+
self.vocoder.eval(inference=True)
|
277 |
+
self.diffusion.to("cuda")
|
278 |
+
self.vocoder.to("cuda")
|
279 |
+
self.aligner = Wav2VecAlignment()
|
280 |
+
# state = self.deterministic_state(seed=0) #Remove after testing
|
281 |
+
self.TACOTRON_MEL_MAX = 2.3143386840820312
|
282 |
+
self.TACOTRON_MEL_MIN = -11.512925148010254
|
283 |
+
|
284 |
+
def denormalize_tacotron_mel(self, norm_mel):
|
285 |
+
return ((norm_mel + 1) / 2) * (
|
286 |
+
self.TACOTRON_MEL_MAX - self.TACOTRON_MEL_MIN
|
287 |
+
) + self.TACOTRON_MEL_MIN
|
288 |
+
|
289 |
+
def potentially_redact(self, clip, text):
|
290 |
+
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
|
291 |
+
|
292 |
+
@staticmethod
|
293 |
+
def deterministic_state(seed=None):
|
294 |
+
seed = int(time()) if seed is None else seed
|
295 |
+
torch.manual_seed(seed)
|
296 |
+
random.seed(seed)
|
297 |
+
return seed
|
298 |
+
|
299 |
+
def do_spectrogram_diffusion(
|
300 |
+
self,
|
301 |
+
diffusion_model,
|
302 |
+
diffuser,
|
303 |
+
latents,
|
304 |
+
conditioning_latents,
|
305 |
+
seed,
|
306 |
+
temperature=1,
|
307 |
+
verbose=False,
|
308 |
+
):
|
309 |
+
self.deterministic_state(seed=seed)
|
310 |
+
with torch.no_grad():
|
311 |
+
output_seq_len = (
|
312 |
+
latents.shape[1] * 4 * 24000 // 22050
|
313 |
+
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
314 |
+
output_shape = (latents.shape[0], 100, output_seq_len)
|
315 |
+
precomputed_embeddings = diffusion_model.timestep_independent(
|
316 |
+
latents, conditioning_latents, output_seq_len, False
|
317 |
+
)
|
318 |
+
|
319 |
+
noise = torch.randn(output_shape, device=latents.device) * temperature
|
320 |
+
mel = diffuser.p_sample_loop(
|
321 |
+
diffusion_model,
|
322 |
+
output_shape,
|
323 |
+
noise=noise,
|
324 |
+
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
|
325 |
+
progress=verbose,
|
326 |
+
)
|
327 |
+
return self.denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
|
328 |
+
|
329 |
+
def parse(
|
330 |
+
self, best_results, best_latents, calm_token, diffusion_conditioning, text, seed
|
331 |
+
):
|
332 |
+
self.deterministic_state(seed=seed)
|
333 |
+
best_results = copy(best_results).to("cuda")
|
334 |
+
best_latents = copy(best_latents).to("cuda")
|
335 |
+
diffusion_conditioning = copy(diffusion_conditioning).to("cuda")
|
336 |
+
wav_candidates = []
|
337 |
+
for b in range(best_results.shape[0]):
|
338 |
+
|
339 |
+
codes = best_results[b].unsqueeze(0)
|
340 |
+
latents = best_latents[b].unsqueeze(0)
|
341 |
+
|
342 |
+
ctokens = 0
|
343 |
+
for k in range(codes.shape[-1]):
|
344 |
+
if codes[0, k] == calm_token:
|
345 |
+
ctokens += 1
|
346 |
+
else:
|
347 |
+
ctokens = 0
|
348 |
+
if ctokens > 8:
|
349 |
+
latents = latents[:, :k]
|
350 |
+
break
|
351 |
+
|
352 |
+
mel = self.do_spectrogram_diffusion(
|
353 |
+
self.diffusion,
|
354 |
+
self.diffuser,
|
355 |
+
latents,
|
356 |
+
diffusion_conditioning,
|
357 |
+
seed,
|
358 |
+
temperature=self.diffusion_temperature,
|
359 |
+
verbose=False,
|
360 |
+
)
|
361 |
+
wav = self.vocoder.inference(mel)
|
362 |
+
wav_candidates.append(wav)
|
363 |
+
# wav_candidates = [self.potentially_redact(wav_candidate, text).cpu().detach().numpy() for wav_candidate in
|
364 |
+
# wav_candidates]
|
365 |
+
# TODO: Check whether wav candidates should be in numpy
|
366 |
+
wav_candidates = [
|
367 |
+
self.potentially_redact(wav_candidate, text)
|
368 |
+
for wav_candidate in wav_candidates
|
369 |
+
]
|
370 |
+
return wav_candidates
|
371 |
+
|
372 |
+
class EndpointHandler():
|
373 |
+
def __init__(self, path="config-model.yaml"):
|
374 |
+
config = get_config_file(Path(path))
|
375 |
+
self.calm_token = 83
|
376 |
+
self.tokenizer = VoiceBpeTokenizer()
|
377 |
+
_, conditioning_latent_1 = load_voice("gabby_reading", map_location="cpu")
|
378 |
+
_, conditioning_latent_2 = load_voice("gabby_conversation", map_location="cpu")
|
379 |
+
|
380 |
+
# self.conditioning_latents1 = (latent.cpu().detach().numpy() for latent in conditioning_latent_1)
|
381 |
+
# self.conditioning_latents2 = (latent.cpu().detach().numpy() for latent in conditioning_latent_2)
|
382 |
+
self.conditioning_latents1 = (latent for latent in conditioning_latent_1)
|
383 |
+
self.conditioning_latents2 = (latent for latent in conditioning_latent_2)
|
384 |
+
(
|
385 |
+
self.auto_conditioning1,
|
386 |
+
self.diffusion_conditioning1,
|
387 |
+
) = self.conditioning_latents1
|
388 |
+
(
|
389 |
+
self.auto_conditioning2,
|
390 |
+
self.diffusion_conditioning2,
|
391 |
+
) = self.conditioning_latents2
|
392 |
+
|
393 |
+
self.auto_conditioning = None
|
394 |
+
self.diffusion_conditioning = None
|
395 |
+
self.gpt = Gpt(
|
396 |
+
config[GPT][NUM_AUTOREGRESSIVE_SAMPLES],
|
397 |
+
config[GPT][TOP_P],
|
398 |
+
config[GPT][TEMPERATURE],
|
399 |
+
config[GPT][LENGTH_PENALTY],
|
400 |
+
config[GPT][REPETITION_PENALTY],
|
401 |
+
config[GPT][MAX_MEL_TOKENS],
|
402 |
+
config[GPT][AUTO_REGRESSIVE_BATCH_SIZE],
|
403 |
+
)
|
404 |
+
self.clvp = clvp(config[CLVP_const]["k"])
|
405 |
+
self.diffusion = Diffusion(config[DIFFUSION][DIFFUSION_TEMPERATURE])
|
406 |
+
self.calm_token = 83
|
407 |
+
print("orchestrator setup completed")
|
408 |
+
|
409 |
+
@staticmethod
|
410 |
+
def __check_for_long_sentence(text_tokens):
|
411 |
+
assert (
|
412 |
+
text_tokens.shape[-1] < 400
|
413 |
+
), "Too much text provided. Break the text up into separate segments and re-try inference."
|
414 |
+
# TODO: split the text into several pieces and do the generation and combine them last
|
415 |
+
|
416 |
+
@staticmethod
|
417 |
+
def deterministic_state(seed=None):
|
418 |
+
seed = int(time()) if seed is None else seed
|
419 |
+
torch.manual_seed(seed)
|
420 |
+
random.seed(seed)
|
421 |
+
return seed
|
422 |
+
|
423 |
+
def preprocess_text(self, text: Text):
|
424 |
+
torch_tensor = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
|
425 |
+
return torch_tensor
|
426 |
+
|
427 |
+
def parse(self, res):
|
428 |
+
print("parsing")
|
429 |
+
file_name = hashlib.sha1(str(datetime.now()).encode("UTF-8"))
|
430 |
+
res = [torch.Tensor(copy(split)).squeeze(0).cpu() for split in res]
|
431 |
+
res = [torch.flatten(split) for split in res]
|
432 |
+
merged_audio_tensor = torch.cat(res).reshape(1, -1)
|
433 |
+
torchaudio.save(f"./{file_name.hexdigest()}.wav", merged_audio_tensor, 24000)
|
434 |
+
# torchaudio.save(f"./{file_name.hexdigest()}.wav", torch.Tensor(copy(res)).squeeze(0).cpu(), 24000)
|
435 |
+
return file_name.hexdigest()
|
436 |
+
|
437 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
438 |
+
voice = data["voice"]
|
439 |
+
text = data["text"]
|
440 |
+
seed = data["seed"]
|
441 |
+
if voice == "gabby_reading":
|
442 |
+
self.auto_conditioning = self.auto_conditioning1
|
443 |
+
self.diffusion_conditioning = self.diffusion_conditioning1
|
444 |
+
elif voice == "gabby_conversation":
|
445 |
+
self.auto_conditioning = self.auto_conditioning2
|
446 |
+
self.diffusion_conditioning = self.diffusion_conditioning2
|
447 |
+
|
448 |
+
self.deterministic_state(seed=seed)
|
449 |
+
text_tokens = self.preprocess_text(
|
450 |
+
text
|
451 |
+
) # preprocess the in-coming text into tokens
|
452 |
+
self.__check_for_long_sentence(text_tokens)
|
453 |
+
# text_tokens = text_tokens.cpu().detach().numpy()
|
454 |
+
samples, stop_mel_token = self.gpt.parse_inference(
|
455 |
+
self.auto_conditioning, text_tokens, seed
|
456 |
+
)
|
457 |
+
best_sample = self.clvp.parse(text_tokens, samples, stop_mel_token, seed)
|
458 |
+
best_latent = self.gpt.parse(
|
459 |
+
self.auto_conditioning, text_tokens, best_sample, seed
|
460 |
+
)
|
461 |
+
wav_candidates = self.diffusion.parse(
|
462 |
+
best_sample,
|
463 |
+
best_latent,
|
464 |
+
self.calm_token,
|
465 |
+
self.diffusion_conditioning,
|
466 |
+
text,
|
467 |
+
seed,
|
468 |
+
)
|
469 |
+
if len(wav_candidates) > 1:
|
470 |
+
res = wav_candidates
|
471 |
+
else:
|
472 |
+
res = wav_candidates[0]
|
473 |
+
|
474 |
+
buffered = BytesIO()
|
475 |
+
self.parse(res)
|
476 |
+
img_str = base64.b64encode(buffered.getvalue())
|
477 |
+
|
478 |
+
# postprocess the prediction
|
479 |
+
return {"audio": img_str.decode()}
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tqdm~=4.64.0
|
2 |
+
rotary_embedding_torch
|
3 |
+
transformers~=4.21.2
|
4 |
+
tokenizers~=0.12.1
|
5 |
+
inflect~=6.0.0
|
6 |
+
progressbar~=2.5
|
7 |
+
einops~=0.4.1
|
8 |
+
unidecode~=1.3.4
|
9 |
+
scipy~=1.9.1
|
10 |
+
librosa~=0.9.2
|
11 |
+
numba==0.48.0
|
12 |
+
ffmpeg
|
13 |
+
fastapi~=0.81.0
|
14 |
+
ray[serve]~=2.0.0
|
15 |
+
PyYAML~=6.0
|
16 |
+
starlette~=0.19.1
|
17 |
+
numpy~=1.23.2
|
18 |
+
setuptools~=60.2.0
|
ruth_tts_transformer/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.idea
|
ruth_tts_transformer/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
VERSION = "0.0.27"
|
2 |
+
|
ruth_tts_transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (164 Bytes). View file
|
|
ruth_tts_transformer/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (178 Bytes). View file
|
|
ruth_tts_transformer/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (156 Bytes). View file
|
|
ruth_tts_transformer/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (182 Bytes). View file
|
|
ruth_tts_transformer/data/latents.pkl
ADDED
Binary file (510 kB). View file
|
|
ruth_tts_transformer/data/layman.txt
ADDED
File without changes
|
ruth_tts_transformer/data/mel_norms.pth
ADDED
Binary file (1.07 kB). View file
|
|
ruth_tts_transformer/data/riding_hood.txt
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood.
|
2 |
+
One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."
|
3 |
+
|
4 |
+
Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village.
|
5 |
+
|
6 |
+
As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother."
|
7 |
+
|
8 |
+
"Does she live far off?" said the wolf
|
9 |
+
|
10 |
+
"Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village."
|
11 |
+
|
12 |
+
"Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first."
|
13 |
+
|
14 |
+
The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap.
|
15 |
+
|
16 |
+
"Who's there?"
|
17 |
+
|
18 |
+
"Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."
|
19 |
+
|
20 |
+
The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."
|
21 |
+
|
22 |
+
The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap.
|
23 |
+
|
24 |
+
"Who's there?"
|
25 |
+
|
26 |
+
Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."
|
27 |
+
|
28 |
+
The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up."
|
29 |
+
|
30 |
+
Little Red Riding Hood pulled the bobbin, and the door opened.
|
31 |
+
|
32 |
+
The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me."
|
33 |
+
|
34 |
+
Little Red Riding Hood took off her clothes and got into bed. She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!"
|
35 |
+
|
36 |
+
"All the better to hug you with, my dear."
|
37 |
+
|
38 |
+
"Grandmother, what big legs you have!"
|
39 |
+
|
40 |
+
"All the better to run with, my child."
|
41 |
+
|
42 |
+
"Grandmother, what big ears you have!"
|
43 |
+
|
44 |
+
"All the better to hear with, my child."
|
45 |
+
|
46 |
+
"Grandmother, what big eyes you have!"
|
47 |
+
|
48 |
+
"All the better to see with, my child."
|
49 |
+
|
50 |
+
"Grandmother, what big teeth you have got!"
|
51 |
+
|
52 |
+
"All the better to eat you up with."
|
53 |
+
|
54 |
+
And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.
|
ruth_tts_transformer/data/seal_copypasta.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al kayda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire U S armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the U S A and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.
|
ruth_tts_transformer/data/tokenizer.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
|
ruth_tts_transformer/models/__init__.py
ADDED
File without changes
|
ruth_tts_transformer/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (153 Bytes). View file
|
|
ruth_tts_transformer/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (145 Bytes). View file
|
|
ruth_tts_transformer/models/__pycache__/arch_util.cpython-310.pyc
ADDED
Binary file (11.4 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/arch_util.cpython-38.pyc
ADDED
Binary file (11.4 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/autoregressive.cpython-310.pyc
ADDED
Binary file (17.8 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/autoregressive.cpython-38.pyc
ADDED
Binary file (17.8 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/clvp.cpython-310.pyc
ADDED
Binary file (4.13 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/clvp.cpython-38.pyc
ADDED
Binary file (4.11 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-310.pyc
ADDED
Binary file (10.2 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/diffusion_decoder.cpython-38.pyc
ADDED
Binary file (10.1 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (7.85 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (7.84 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/vocoder.cpython-310.pyc
ADDED
Binary file (9.17 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/vocoder.cpython-38.pyc
ADDED
Binary file (9.13 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/xtransformers.cpython-310.pyc
ADDED
Binary file (34.7 kB). View file
|
|
ruth_tts_transformer/models/__pycache__/xtransformers.cpython-38.pyc
ADDED
Binary file (35.3 kB). View file
|
|
ruth_tts_transformer/models/arch_util.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import functools
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
from ruth_tts_transformer.models.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
|
10 |
+
|
11 |
+
|
12 |
+
def zero_module(module):
|
13 |
+
"""
|
14 |
+
Zero out the parameters of a module and return it.
|
15 |
+
"""
|
16 |
+
for p in module.parameters():
|
17 |
+
p.detach().zero_()
|
18 |
+
return module
|
19 |
+
|
20 |
+
|
21 |
+
class GroupNorm32(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
return super().forward(x.float()).type(x.dtype)
|
24 |
+
|
25 |
+
|
26 |
+
def normalization(channels):
|
27 |
+
"""
|
28 |
+
Make a standard normalization layer.
|
29 |
+
|
30 |
+
:param channels: number of input channels.
|
31 |
+
:return: an nn.Module for normalization.
|
32 |
+
"""
|
33 |
+
groups = 32
|
34 |
+
if channels <= 16:
|
35 |
+
groups = 8
|
36 |
+
elif channels <= 64:
|
37 |
+
groups = 16
|
38 |
+
while channels % groups != 0:
|
39 |
+
groups = int(groups / 2)
|
40 |
+
assert groups > 2
|
41 |
+
return GroupNorm32(groups, channels)
|
42 |
+
|
43 |
+
|
44 |
+
class QKVAttentionLegacy(nn.Module):
|
45 |
+
"""
|
46 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, n_heads):
|
50 |
+
super().__init__()
|
51 |
+
self.n_heads = n_heads
|
52 |
+
|
53 |
+
def forward(self, qkv, mask=None, rel_pos=None):
|
54 |
+
"""
|
55 |
+
Apply QKV attention.
|
56 |
+
|
57 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
58 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
59 |
+
"""
|
60 |
+
bs, width, length = qkv.shape
|
61 |
+
assert width % (3 * self.n_heads) == 0
|
62 |
+
ch = width // (3 * self.n_heads)
|
63 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
64 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
65 |
+
weight = torch.einsum(
|
66 |
+
"bct,bcs->bts", q * scale, k * scale
|
67 |
+
) # More stable with f16 than dividing afterwards
|
68 |
+
if rel_pos is not None:
|
69 |
+
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
70 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
71 |
+
if mask is not None:
|
72 |
+
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
73 |
+
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
74 |
+
weight = weight * mask
|
75 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
76 |
+
|
77 |
+
return a.reshape(bs, -1, length)
|
78 |
+
|
79 |
+
|
80 |
+
class AttentionBlock(nn.Module):
|
81 |
+
"""
|
82 |
+
An attention block that allows spatial positions to attend to each other.
|
83 |
+
|
84 |
+
Originally ported from here, but adapted to the N-d case.
|
85 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
channels,
|
91 |
+
num_heads=1,
|
92 |
+
num_head_channels=-1,
|
93 |
+
do_checkpoint=True,
|
94 |
+
relative_pos_embeddings=False,
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.channels = channels
|
98 |
+
self.do_checkpoint = do_checkpoint
|
99 |
+
if num_head_channels == -1:
|
100 |
+
self.num_heads = num_heads
|
101 |
+
else:
|
102 |
+
assert (
|
103 |
+
channels % num_head_channels == 0
|
104 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
105 |
+
self.num_heads = channels // num_head_channels
|
106 |
+
self.norm = normalization(channels)
|
107 |
+
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
108 |
+
# split heads before split qkv
|
109 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
110 |
+
|
111 |
+
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
112 |
+
if relative_pos_embeddings:
|
113 |
+
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
114 |
+
else:
|
115 |
+
self.relative_pos_embeddings = None
|
116 |
+
|
117 |
+
def forward(self, x, mask=None):
|
118 |
+
b, c, *spatial = x.shape
|
119 |
+
x = x.reshape(b, c, -1)
|
120 |
+
qkv = self.qkv(self.norm(x))
|
121 |
+
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
122 |
+
h = self.proj_out(h)
|
123 |
+
return (x + h).reshape(b, c, *spatial)
|
124 |
+
|
125 |
+
|
126 |
+
class Upsample(nn.Module):
|
127 |
+
"""
|
128 |
+
An upsampling layer with an optional convolution.
|
129 |
+
|
130 |
+
:param channels: channels in the inputs and outputs.
|
131 |
+
:param use_conv: a bool determining if a convolution is applied.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, channels, use_conv, out_channels=None, factor=4):
|
135 |
+
super().__init__()
|
136 |
+
self.channels = channels
|
137 |
+
self.out_channels = out_channels or channels
|
138 |
+
self.use_conv = use_conv
|
139 |
+
self.factor = factor
|
140 |
+
if use_conv:
|
141 |
+
ksize = 5
|
142 |
+
pad = 2
|
143 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
assert x.shape[1] == self.channels
|
147 |
+
x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
|
148 |
+
if self.use_conv:
|
149 |
+
x = self.conv(x)
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
class Downsample(nn.Module):
|
154 |
+
"""
|
155 |
+
A downsampling layer with an optional convolution.
|
156 |
+
|
157 |
+
:param channels: channels in the inputs and outputs.
|
158 |
+
:param use_conv: a bool determining if a convolution is applied.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
|
162 |
+
super().__init__()
|
163 |
+
self.channels = channels
|
164 |
+
self.out_channels = out_channels or channels
|
165 |
+
self.use_conv = use_conv
|
166 |
+
|
167 |
+
stride = factor
|
168 |
+
if use_conv:
|
169 |
+
self.op = nn.Conv1d(
|
170 |
+
self.channels, self.out_channels, ksize, stride=stride, padding=pad
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
assert self.channels == self.out_channels
|
174 |
+
self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
assert x.shape[1] == self.channels
|
178 |
+
return self.op(x)
|
179 |
+
|
180 |
+
|
181 |
+
class ResBlock(nn.Module):
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
channels,
|
185 |
+
dropout,
|
186 |
+
out_channels=None,
|
187 |
+
use_conv=False,
|
188 |
+
use_scale_shift_norm=False,
|
189 |
+
up=False,
|
190 |
+
down=False,
|
191 |
+
kernel_size=3,
|
192 |
+
):
|
193 |
+
super().__init__()
|
194 |
+
self.channels = channels
|
195 |
+
self.dropout = dropout
|
196 |
+
self.out_channels = out_channels or channels
|
197 |
+
self.use_conv = use_conv
|
198 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
199 |
+
padding = 1 if kernel_size == 3 else 2
|
200 |
+
|
201 |
+
self.in_layers = nn.Sequential(
|
202 |
+
normalization(channels),
|
203 |
+
nn.SiLU(),
|
204 |
+
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
205 |
+
)
|
206 |
+
|
207 |
+
self.updown = up or down
|
208 |
+
|
209 |
+
if up:
|
210 |
+
self.h_upd = Upsample(channels, False)
|
211 |
+
self.x_upd = Upsample(channels, False)
|
212 |
+
elif down:
|
213 |
+
self.h_upd = Downsample(channels, False)
|
214 |
+
self.x_upd = Downsample(channels, False)
|
215 |
+
else:
|
216 |
+
self.h_upd = self.x_upd = nn.Identity()
|
217 |
+
|
218 |
+
self.out_layers = nn.Sequential(
|
219 |
+
normalization(self.out_channels),
|
220 |
+
nn.SiLU(),
|
221 |
+
nn.Dropout(p=dropout),
|
222 |
+
zero_module(
|
223 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
224 |
+
),
|
225 |
+
)
|
226 |
+
|
227 |
+
if self.out_channels == channels:
|
228 |
+
self.skip_connection = nn.Identity()
|
229 |
+
elif use_conv:
|
230 |
+
self.skip_connection = nn.Conv1d(
|
231 |
+
channels, self.out_channels, kernel_size, padding=padding
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
if self.updown:
|
238 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
239 |
+
h = in_rest(x)
|
240 |
+
h = self.h_upd(h)
|
241 |
+
x = self.x_upd(x)
|
242 |
+
h = in_conv(h)
|
243 |
+
else:
|
244 |
+
h = self.in_layers(x)
|
245 |
+
h = self.out_layers(h)
|
246 |
+
return self.skip_connection(x) + h
|
247 |
+
|
248 |
+
|
249 |
+
class AudioMiniEncoder(nn.Module):
|
250 |
+
def __init__(self,
|
251 |
+
spec_dim,
|
252 |
+
embedding_dim,
|
253 |
+
base_channels=128,
|
254 |
+
depth=2,
|
255 |
+
resnet_blocks=2,
|
256 |
+
attn_blocks=4,
|
257 |
+
num_attn_heads=4,
|
258 |
+
dropout=0,
|
259 |
+
downsample_factor=2,
|
260 |
+
kernel_size=3):
|
261 |
+
super().__init__()
|
262 |
+
self.init = nn.Sequential(
|
263 |
+
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
|
264 |
+
)
|
265 |
+
ch = base_channels
|
266 |
+
res = []
|
267 |
+
for l in range(depth):
|
268 |
+
for r in range(resnet_blocks):
|
269 |
+
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
|
270 |
+
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
|
271 |
+
ch *= 2
|
272 |
+
self.res = nn.Sequential(*res)
|
273 |
+
self.final = nn.Sequential(
|
274 |
+
normalization(ch),
|
275 |
+
nn.SiLU(),
|
276 |
+
nn.Conv1d(ch, embedding_dim, 1)
|
277 |
+
)
|
278 |
+
attn = []
|
279 |
+
for a in range(attn_blocks):
|
280 |
+
attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
|
281 |
+
self.attn = nn.Sequential(*attn)
|
282 |
+
self.dim = embedding_dim
|
283 |
+
|
284 |
+
def forward(self, x):
|
285 |
+
h = self.init(x)
|
286 |
+
h = self.res(h)
|
287 |
+
h = self.final(h)
|
288 |
+
h = self.attn(h)
|
289 |
+
return h[:, :, 0]
|
290 |
+
|
291 |
+
|
292 |
+
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
|
293 |
+
|
294 |
+
|
295 |
+
class TorchMelSpectrogram(nn.Module):
|
296 |
+
def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
|
297 |
+
sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
|
298 |
+
super().__init__()
|
299 |
+
# These are the default tacotron values for the MEL spectrogram.
|
300 |
+
self.filter_length = filter_length
|
301 |
+
self.hop_length = hop_length
|
302 |
+
self.win_length = win_length
|
303 |
+
self.n_mel_channels = n_mel_channels
|
304 |
+
self.mel_fmin = mel_fmin
|
305 |
+
self.mel_fmax = mel_fmax
|
306 |
+
self.sampling_rate = sampling_rate
|
307 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
308 |
+
win_length=self.win_length, power=2, normalized=normalize,
|
309 |
+
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
310 |
+
f_max=self.mel_fmax, n_mels=self.n_mel_channels,
|
311 |
+
norm="slaney")
|
312 |
+
self.mel_norm_file = mel_norm_file
|
313 |
+
if self.mel_norm_file is not None:
|
314 |
+
self.mel_norms = torch.load(self.mel_norm_file)
|
315 |
+
else:
|
316 |
+
self.mel_norms = None
|
317 |
+
|
318 |
+
def forward(self, inp):
|
319 |
+
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
320 |
+
inp = inp.squeeze(1)
|
321 |
+
assert len(inp.shape) == 2
|
322 |
+
self.mel_stft = self.mel_stft.to(inp.device)
|
323 |
+
mel = self.mel_stft(inp)
|
324 |
+
# Perform dynamic range compression
|
325 |
+
mel = torch.log(torch.clamp(mel, min=1e-5))
|
326 |
+
if self.mel_norms is not None:
|
327 |
+
self.mel_norms = self.mel_norms.to(mel.device)
|
328 |
+
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
329 |
+
return mel
|
330 |
+
|
331 |
+
|
332 |
+
class CheckpointedLayer(nn.Module):
|
333 |
+
"""
|
334 |
+
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
335 |
+
checkpoint for all other args.
|
336 |
+
"""
|
337 |
+
def __init__(self, wrap):
|
338 |
+
super().__init__()
|
339 |
+
self.wrap = wrap
|
340 |
+
|
341 |
+
def forward(self, x, *args, **kwargs):
|
342 |
+
for k, v in kwargs.items():
|
343 |
+
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
344 |
+
partial = functools.partial(self.wrap, **kwargs)
|
345 |
+
return partial(x, *args)
|
346 |
+
|
347 |
+
|
348 |
+
class CheckpointedXTransformerEncoder(nn.Module):
|
349 |
+
"""
|
350 |
+
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
351 |
+
to channels-last that XTransformer expects.
|
352 |
+
"""
|
353 |
+
def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
|
354 |
+
super().__init__()
|
355 |
+
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
356 |
+
self.needs_permute = needs_permute
|
357 |
+
self.exit_permute = exit_permute
|
358 |
+
|
359 |
+
if not checkpoint:
|
360 |
+
return
|
361 |
+
for i in range(len(self.transformer.attn_layers.layers)):
|
362 |
+
n, b, r = self.transformer.attn_layers.layers[i]
|
363 |
+
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
364 |
+
|
365 |
+
def forward(self, x, **kwargs):
|
366 |
+
if self.needs_permute:
|
367 |
+
x = x.permute(0,2,1)
|
368 |
+
h = self.transformer(x, **kwargs)
|
369 |
+
if self.exit_permute:
|
370 |
+
h = h.permute(0,2,1)
|
371 |
+
return h
|
ruth_tts_transformer/models/autoregressive.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
7 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
8 |
+
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
9 |
+
from ruth_tts_transformer.models.arch_util import AttentionBlock
|
10 |
+
from ruth_tts_transformer.utils.typical_sampling import TypicalLogitsWarper
|
11 |
+
|
12 |
+
|
13 |
+
def null_position_embeddings(range, dim):
|
14 |
+
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
15 |
+
|
16 |
+
|
17 |
+
class ResidualConvolutionBlock(nn.Module):
|
18 |
+
|
19 |
+
def __init__(self, chan):
|
20 |
+
super().__init__()
|
21 |
+
self.neural_network = nn.Sequential(
|
22 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
23 |
+
nn.GroupNorm(chan // 8, chan),
|
24 |
+
nn.ReLU(),
|
25 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
26 |
+
nn.GroupNorm(chan // 8, chan)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
return F.relu(self.neural_network(x) + x)
|
31 |
+
|
32 |
+
|
33 |
+
class GPT2InferenceModel(GPT2PreTrainedModel):
|
34 |
+
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
|
35 |
+
super().__init__(config)
|
36 |
+
self.transformer = gpt
|
37 |
+
self.text_pos_embedding = text_pos_emb
|
38 |
+
self.embeddings = embeddings
|
39 |
+
self.lm_head = nn.Sequential(norm, linear)
|
40 |
+
|
41 |
+
# Model parallel
|
42 |
+
self.model_parallel = False
|
43 |
+
self.device_map = None
|
44 |
+
self.cached_mel_emb = None
|
45 |
+
|
46 |
+
def parallelize(self, device_map=None):
|
47 |
+
self.device_map = (
|
48 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
49 |
+
if device_map is None
|
50 |
+
else device_map
|
51 |
+
)
|
52 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
53 |
+
self.transformer.parallelize(self.device_map)
|
54 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
55 |
+
self.model_parallel = True
|
56 |
+
|
57 |
+
def deparallelize(self):
|
58 |
+
self.transformer.deparallelize()
|
59 |
+
self.transformer = self.transformer.to("cpu")
|
60 |
+
self.lm_head = self.lm_head.to("cpu")
|
61 |
+
self.model_parallel = False
|
62 |
+
torch.cuda.empty_cache()
|
63 |
+
|
64 |
+
def get_output_embeddings(self):
|
65 |
+
return self.lm_head
|
66 |
+
|
67 |
+
def set_output_embeddings(self, new_embeddings):
|
68 |
+
self.lm_head = new_embeddings
|
69 |
+
|
70 |
+
def store_mel_emb(self, mel_emb):
|
71 |
+
self.cached_mel_emb = mel_emb
|
72 |
+
|
73 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
74 |
+
|
75 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
76 |
+
# only last token for inputs_ids if past is defined in kwargs
|
77 |
+
if past:
|
78 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
79 |
+
if token_type_ids is not None:
|
80 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
81 |
+
|
82 |
+
attention_mask = kwargs.get("attention_mask", None)
|
83 |
+
position_ids = kwargs.get("position_ids", None)
|
84 |
+
|
85 |
+
if attention_mask is not None and position_ids is None:
|
86 |
+
# create position_ids on the fly for batch generation
|
87 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
88 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
89 |
+
if past:
|
90 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
91 |
+
else:
|
92 |
+
position_ids = None
|
93 |
+
return {
|
94 |
+
"input_ids": input_ids,
|
95 |
+
"past_key_values": past,
|
96 |
+
"use_cache": kwargs.get("use_cache"),
|
97 |
+
"position_ids": position_ids,
|
98 |
+
"attention_mask": attention_mask,
|
99 |
+
"token_type_ids": token_type_ids,
|
100 |
+
}
|
101 |
+
|
102 |
+
def forward(
|
103 |
+
self,
|
104 |
+
input_ids=None,
|
105 |
+
past_key_values=None,
|
106 |
+
attention_mask=None,
|
107 |
+
token_type_ids=None,
|
108 |
+
position_ids=None,
|
109 |
+
head_mask=None,
|
110 |
+
inputs_embeds=None,
|
111 |
+
encoder_hidden_states=None,
|
112 |
+
encoder_attention_mask=None,
|
113 |
+
labels=None,
|
114 |
+
use_cache=None,
|
115 |
+
output_attentions=None,
|
116 |
+
output_hidden_states=None,
|
117 |
+
return_dict=None,
|
118 |
+
):
|
119 |
+
assert self.cached_mel_emb is not None
|
120 |
+
assert inputs_embeds is None # Not supported by this inference model.
|
121 |
+
assert labels is None # Training not supported by this inference model.
|
122 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
123 |
+
|
124 |
+
# Create embedding
|
125 |
+
mel_len = self.cached_mel_emb.shape[1]
|
126 |
+
if input_ids.shape[1] != 1:
|
127 |
+
text_inputs = input_ids[:, mel_len:]
|
128 |
+
text_emb = self.embeddings(text_inputs)
|
129 |
+
text_emb = text_emb + self.text_pos_embedding(text_emb)
|
130 |
+
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
131 |
+
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
|
132 |
+
else:
|
133 |
+
mel_emb = self.cached_mel_emb
|
134 |
+
emb = torch.cat([mel_emb, text_emb], dim=1)
|
135 |
+
else:
|
136 |
+
emb = self.embeddings(input_ids)
|
137 |
+
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - mel_len,
|
138 |
+
attention_mask.device)
|
139 |
+
|
140 |
+
transformer_outputs = self.transformer(
|
141 |
+
inputs_embeds=emb,
|
142 |
+
past_key_values=past_key_values,
|
143 |
+
attention_mask=attention_mask,
|
144 |
+
token_type_ids=token_type_ids,
|
145 |
+
position_ids=position_ids,
|
146 |
+
head_mask=head_mask,
|
147 |
+
encoder_hidden_states=encoder_hidden_states,
|
148 |
+
encoder_attention_mask=encoder_attention_mask,
|
149 |
+
use_cache=use_cache,
|
150 |
+
output_attentions=output_attentions,
|
151 |
+
output_hidden_states=output_hidden_states,
|
152 |
+
return_dict=return_dict,
|
153 |
+
)
|
154 |
+
hidden_states = transformer_outputs[0]
|
155 |
+
|
156 |
+
# Set device for model parallelism
|
157 |
+
if self.model_parallel:
|
158 |
+
torch.cuda.set_device(self.transformer.first_device)
|
159 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
160 |
+
|
161 |
+
lm_logits = self.lm_head(hidden_states)
|
162 |
+
|
163 |
+
if not return_dict:
|
164 |
+
return (lm_logits,) + transformer_outputs[1:]
|
165 |
+
|
166 |
+
return CausalLMOutputWithCrossAttentions(
|
167 |
+
loss=None,
|
168 |
+
logits=lm_logits,
|
169 |
+
past_key_values=transformer_outputs.past_key_values,
|
170 |
+
hidden_states=transformer_outputs.hidden_states,
|
171 |
+
attentions=transformer_outputs.attentions,
|
172 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
173 |
+
)
|
174 |
+
|
175 |
+
@staticmethod
|
176 |
+
def _reorder_cache(past, beam_idx):
|
177 |
+
"""
|
178 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
179 |
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
180 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
181 |
+
"""
|
182 |
+
return tuple(
|
183 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
184 |
+
for layer_past in past
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
class ConditioningEncoder(nn.Module):
|
189 |
+
def __init__(self,
|
190 |
+
spec_dim,
|
191 |
+
embedding_dim,
|
192 |
+
attn_blocks=6,
|
193 |
+
num_attn_heads=4,
|
194 |
+
do_checkpointing=False,
|
195 |
+
mean=False):
|
196 |
+
super().__init__()
|
197 |
+
attn = []
|
198 |
+
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
199 |
+
for a in range(attn_blocks):
|
200 |
+
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
201 |
+
self.attn = nn.Sequential(*attn)
|
202 |
+
self.dim = embedding_dim
|
203 |
+
self.do_checkpointing = do_checkpointing
|
204 |
+
self.mean = mean
|
205 |
+
|
206 |
+
def forward(self, x):
|
207 |
+
h = self.init(x)
|
208 |
+
h = self.attn(h)
|
209 |
+
if self.mean:
|
210 |
+
return h.mean(dim=2)
|
211 |
+
else:
|
212 |
+
return h[:, :, 0]
|
213 |
+
|
214 |
+
|
215 |
+
class LearnedPositionEmbeddings(nn.Module):
|
216 |
+
def __init__(self, seq_len, model_dim, init=.02):
|
217 |
+
super().__init__()
|
218 |
+
self.emb = nn.Embedding(seq_len, model_dim)
|
219 |
+
# Initializing this way is standard for GPT-2
|
220 |
+
self.emb.weight.data.normal_(mean=0.0, std=init)
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
sl = x.shape[1]
|
224 |
+
return self.emb(torch.arange(0, sl, device=x.device))
|
225 |
+
|
226 |
+
def get_fixed_embedding(self, ind, dev):
|
227 |
+
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
228 |
+
|
229 |
+
|
230 |
+
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
231 |
+
"""
|
232 |
+
GPT-2 implemented by the HuggingFace library.
|
233 |
+
"""
|
234 |
+
from transformers import GPT2Config, GPT2Model
|
235 |
+
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
236 |
+
n_positions=max_mel_seq_len + max_text_seq_len,
|
237 |
+
n_ctx=max_mel_seq_len + max_text_seq_len,
|
238 |
+
n_embd=model_dim,
|
239 |
+
n_layer=layers,
|
240 |
+
n_head=heads,
|
241 |
+
gradient_checkpointing=checkpointing,
|
242 |
+
use_cache=not checkpointing)
|
243 |
+
gpt = GPT2Model(gpt_config)
|
244 |
+
# Override the built in positional embeddings
|
245 |
+
del gpt.wpe
|
246 |
+
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
247 |
+
# Built-in token embeddings are unused.
|
248 |
+
del gpt.wte
|
249 |
+
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len,
|
250 |
+
model_dim), \
|
251 |
+
None, None
|
252 |
+
|
253 |
+
|
254 |
+
class MelEncoder(nn.Module):
|
255 |
+
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
256 |
+
super().__init__()
|
257 |
+
self.channels = channels
|
258 |
+
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
|
259 |
+
nn.Sequential(*[ResidualConvolutionBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
|
260 |
+
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
|
261 |
+
nn.GroupNorm(channels // 16, channels // 2),
|
262 |
+
nn.ReLU(),
|
263 |
+
nn.Sequential(*[ResidualConvolutionBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
|
264 |
+
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
|
265 |
+
nn.GroupNorm(channels // 8, channels),
|
266 |
+
nn.ReLU(),
|
267 |
+
nn.Sequential(*[ResidualConvolutionBlock(channels) for _ in range(resblocks_per_reduction)]),
|
268 |
+
)
|
269 |
+
self.reduction = 4
|
270 |
+
|
271 |
+
def forward(self, x):
|
272 |
+
for e in self.encoder:
|
273 |
+
x = e(x)
|
274 |
+
return x.permute(0, 2, 1)
|
275 |
+
|
276 |
+
|
277 |
+
class UnifiedVoice(nn.Module):
|
278 |
+
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250,
|
279 |
+
max_conditioning_inputs=1,
|
280 |
+
mel_length_compression=1024, number_text_tokens=256,
|
281 |
+
start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
|
282 |
+
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
283 |
+
checkpointing=True, types=1):
|
284 |
+
"""
|
285 |
+
Args:
|
286 |
+
layers: Number of layers in transformer stack.
|
287 |
+
model_dim: Operating dimensions of the transformer
|
288 |
+
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
289 |
+
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
290 |
+
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
291 |
+
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
292 |
+
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
293 |
+
number_text_tokens:
|
294 |
+
start_text_token:
|
295 |
+
stop_text_token:
|
296 |
+
number_mel_codes:
|
297 |
+
start_mel_token:
|
298 |
+
stop_mel_token:
|
299 |
+
train_solo_embeddings:
|
300 |
+
use_mel_codes_as_input:
|
301 |
+
checkpointing:
|
302 |
+
"""
|
303 |
+
super().__init__()
|
304 |
+
|
305 |
+
self.number_text_tokens = number_text_tokens
|
306 |
+
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
307 |
+
self.stop_text_token = 0
|
308 |
+
self.number_mel_codes = number_mel_codes
|
309 |
+
self.start_mel_token = start_mel_token
|
310 |
+
self.stop_mel_token = stop_mel_token
|
311 |
+
self.layers = layers
|
312 |
+
self.heads = heads
|
313 |
+
self.max_mel_tokens = max_mel_tokens
|
314 |
+
self.max_text_tokens = max_text_tokens
|
315 |
+
self.model_dim = model_dim
|
316 |
+
self.max_conditioning_inputs = max_conditioning_inputs
|
317 |
+
self.mel_length_compression = mel_length_compression
|
318 |
+
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
319 |
+
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
|
320 |
+
if use_mel_codes_as_input:
|
321 |
+
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
322 |
+
else:
|
323 |
+
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
324 |
+
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
325 |
+
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
|
326 |
+
self.max_text_tokens + 2, checkpointing)
|
327 |
+
if train_solo_embeddings:
|
328 |
+
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
329 |
+
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
330 |
+
else:
|
331 |
+
self.mel_solo_embedding = 0
|
332 |
+
self.text_solo_embedding = 0
|
333 |
+
|
334 |
+
self.final_norm = nn.LayerNorm(model_dim)
|
335 |
+
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
|
336 |
+
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
337 |
+
|
338 |
+
# Initialize the embeddings per the GPT-2 scheme
|
339 |
+
embeddings = [self.text_embedding]
|
340 |
+
if use_mel_codes_as_input:
|
341 |
+
embeddings.append(self.mel_embedding)
|
342 |
+
for module in embeddings:
|
343 |
+
module.weight.data.normal_(mean=0.0, std=.02)
|
344 |
+
|
345 |
+
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
346 |
+
inp = F.pad(input, (1, 0), value=start_token)
|
347 |
+
tar = F.pad(input, (0, 1), value=stop_token)
|
348 |
+
return inp, tar
|
349 |
+
|
350 |
+
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
351 |
+
"""
|
352 |
+
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
|
353 |
+
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
|
354 |
+
preformatting to create a working TTS model.
|
355 |
+
"""
|
356 |
+
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
357 |
+
mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
|
358 |
+
for b in range(len(mel_lengths)):
|
359 |
+
actual_end = mel_lengths[
|
360 |
+
b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
361 |
+
if actual_end < mel_input_tokens.shape[-1]:
|
362 |
+
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
363 |
+
return mel_input_tokens
|
364 |
+
|
365 |
+
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None,
|
366 |
+
get_attns=False, return_latent=False):
|
367 |
+
if second_inputs is not None:
|
368 |
+
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
369 |
+
else:
|
370 |
+
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
371 |
+
|
372 |
+
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
373 |
+
if get_attns:
|
374 |
+
return gpt_out.attentions
|
375 |
+
|
376 |
+
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
377 |
+
enc = self.final_norm(enc)
|
378 |
+
|
379 |
+
if return_latent:
|
380 |
+
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + first_inputs.shape[
|
381 |
+
1]], enc[:, -second_inputs.shape[1]:]
|
382 |
+
|
383 |
+
first_logits = enc[:, :first_inputs.shape[1]]
|
384 |
+
first_logits = first_head(first_logits)
|
385 |
+
first_logits = first_logits.permute(0, 2, 1)
|
386 |
+
if second_inputs is not None:
|
387 |
+
second_logits = enc[:, -second_inputs.shape[1]:]
|
388 |
+
second_logits = second_head(second_logits)
|
389 |
+
second_logits = second_logits.permute(0, 2, 1)
|
390 |
+
return first_logits, second_logits
|
391 |
+
else:
|
392 |
+
return first_logits
|
393 |
+
|
394 |
+
def get_conditioning(self, speech_conditioning_input):
|
395 |
+
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
396 |
+
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
397 |
+
conds = []
|
398 |
+
for j in range(speech_conditioning_input.shape[1]):
|
399 |
+
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
400 |
+
conds = torch.stack(conds, dim=1)
|
401 |
+
conds = conds.mean(dim=1)
|
402 |
+
return conds
|
403 |
+
|
404 |
+
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None,
|
405 |
+
text_first=True, raw_mels=None, return_attentions=False,
|
406 |
+
return_latent=False, clip_inputs=True):
|
407 |
+
"""
|
408 |
+
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
409 |
+
(actuated by `text_first`).
|
410 |
+
|
411 |
+
speech_conditioning_input: MEL float tensor, (b,1024)
|
412 |
+
text_inputs: long tensor, (b,t)
|
413 |
+
text_lengths: long tensor, (b,)
|
414 |
+
mel_inputs: long tensor, (b,m)
|
415 |
+
wav_lengths: long tensor, (b,)
|
416 |
+
raw_mels: MEL float tensor (b,80,s)
|
417 |
+
|
418 |
+
If return_attentions is specified, only logits are returned.
|
419 |
+
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
420 |
+
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
|
421 |
+
"""
|
422 |
+
# Types are expressed by expanding the text embedding space.
|
423 |
+
if types is not None:
|
424 |
+
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
|
425 |
+
|
426 |
+
if clip_inputs:
|
427 |
+
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
428 |
+
# chopping the inputs by the maximum actual length.
|
429 |
+
max_text_len = text_lengths.max()
|
430 |
+
text_inputs = text_inputs[:, :max_text_len]
|
431 |
+
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
432 |
+
mel_codes = mel_codes[:, :max_mel_len]
|
433 |
+
if raw_mels is not None:
|
434 |
+
raw_mels = raw_mels[:, :, :max_mel_len * 4]
|
435 |
+
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
436 |
+
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
437 |
+
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
|
438 |
+
|
439 |
+
conds = speech_conditioning_latent.unsqueeze(1)
|
440 |
+
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token,
|
441 |
+
self.stop_text_token)
|
442 |
+
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
443 |
+
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token,
|
444 |
+
self.stop_mel_token)
|
445 |
+
if raw_mels is not None:
|
446 |
+
mel_inp = F.pad(raw_mels, (0, 8))
|
447 |
+
else:
|
448 |
+
mel_inp = mel_codes
|
449 |
+
mel_emb = self.mel_embedding(mel_inp)
|
450 |
+
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
451 |
+
|
452 |
+
if text_first:
|
453 |
+
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
|
454 |
+
get_attns=return_attentions, return_latent=return_latent)
|
455 |
+
if return_latent:
|
456 |
+
return mel_logits[:,
|
457 |
+
:-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
458 |
+
else:
|
459 |
+
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head,
|
460 |
+
get_attns=return_attentions, return_latent=return_latent)
|
461 |
+
if return_latent:
|
462 |
+
return text_logits[:,
|
463 |
+
:-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
464 |
+
|
465 |
+
if return_attentions:
|
466 |
+
return mel_logits
|
467 |
+
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
468 |
+
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
469 |
+
return loss_text.mean(), loss_mel.mean(), mel_logits
|
470 |
+
|
471 |
+
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
|
472 |
+
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
473 |
+
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
474 |
+
if not hasattr(self, 'inference_model'):
|
475 |
+
# TODO: Decouple gpt_config from this inference model.
|
476 |
+
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
477 |
+
n_positions=seq_length,
|
478 |
+
n_ctx=seq_length,
|
479 |
+
n_embd=self.model_dim,
|
480 |
+
n_layer=self.layers,
|
481 |
+
n_head=self.heads,
|
482 |
+
gradient_checkpointing=False,
|
483 |
+
use_cache=True)
|
484 |
+
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding,
|
485 |
+
self.final_norm, self.mel_head)
|
486 |
+
self.gpt.wte = self.mel_embedding
|
487 |
+
|
488 |
+
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
489 |
+
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token,
|
490 |
+
self.stop_text_token)
|
491 |
+
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
492 |
+
|
493 |
+
conds = speech_conditioning_latent.unsqueeze(1)
|
494 |
+
emb = torch.cat([conds, text_emb], dim=1)
|
495 |
+
self.inference_model.store_mel_emb(emb)
|
496 |
+
|
497 |
+
fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
|
498 |
+
device=text_inputs.device)
|
499 |
+
fake_inputs[:, -1] = self.start_mel_token
|
500 |
+
trunc_index = fake_inputs.shape[1]
|
501 |
+
if input_tokens is None:
|
502 |
+
inputs = fake_inputs
|
503 |
+
else:
|
504 |
+
assert num_return_sequences % input_tokens.shape[
|
505 |
+
0] == 0, "The number of return sequences must be divisible by the number of input sequences"
|
506 |
+
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
|
507 |
+
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
|
508 |
+
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
509 |
+
|
510 |
+
logits_processor = LogitsProcessorList(
|
511 |
+
[TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
|
512 |
+
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
|
513 |
+
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
|
514 |
+
eos_token_id=self.stop_mel_token,
|
515 |
+
max_length=max_length, logits_processor=logits_processor,
|
516 |
+
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
517 |
+
return gen[:, trunc_index:]
|
518 |
+
|
519 |
+
|
520 |
+
if __name__ == '__main__':
|
521 |
+
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True,
|
522 |
+
max_conditioning_inputs=4)
|
523 |
+
l = gpt(torch.randn(2, 3, 80, 800),
|
524 |
+
torch.randint(high=120, size=(2, 120)),
|
525 |
+
torch.tensor([32, 120]),
|
526 |
+
torch.randint(high=8192, size=(2, 250)),
|
527 |
+
torch.tensor([250 * 256, 195 * 256]))
|
528 |
+
gpt.text_forward(torch.randn(2, 80, 800), torch.randint(high=50, size=(2, 80)), torch.tensor([32, 80]))
|
ruth_tts_transformer/models/clvp.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import einsum
|
5 |
+
|
6 |
+
from ruth_tts_transformer.models.arch_util import CheckpointedXTransformerEncoder
|
7 |
+
from ruth_tts_transformer.models.transformer import Transformer
|
8 |
+
from ruth_tts_transformer.models.xtransformers import Encoder
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def masked_mean(t, mask, dim = 1):
|
16 |
+
t = t.masked_fill(~mask[:, :, None], 0.)
|
17 |
+
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
18 |
+
|
19 |
+
class CLVP(nn.Module):
|
20 |
+
"""
|
21 |
+
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
|
22 |
+
transcribed text.
|
23 |
+
|
24 |
+
Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
*,
|
30 |
+
dim_text=512,
|
31 |
+
dim_speech=512,
|
32 |
+
dim_latent=512,
|
33 |
+
num_text_tokens=256,
|
34 |
+
text_enc_depth=6,
|
35 |
+
text_seq_len=120,
|
36 |
+
text_heads=8,
|
37 |
+
num_speech_tokens=8192,
|
38 |
+
speech_enc_depth=6,
|
39 |
+
speech_heads=8,
|
40 |
+
speech_seq_len=250,
|
41 |
+
text_mask_percentage=0,
|
42 |
+
voice_mask_percentage=0,
|
43 |
+
wav_token_compression=1024,
|
44 |
+
use_xformers=False,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
48 |
+
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
|
49 |
+
|
50 |
+
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
51 |
+
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
|
52 |
+
|
53 |
+
if use_xformers:
|
54 |
+
self.text_transformer = CheckpointedXTransformerEncoder(
|
55 |
+
needs_permute=False,
|
56 |
+
exit_permute=False,
|
57 |
+
max_seq_len=-1,
|
58 |
+
attn_layers=Encoder(
|
59 |
+
dim=dim_text,
|
60 |
+
depth=text_enc_depth,
|
61 |
+
heads=text_heads,
|
62 |
+
ff_dropout=.1,
|
63 |
+
ff_mult=2,
|
64 |
+
attn_dropout=.1,
|
65 |
+
use_rmsnorm=True,
|
66 |
+
ff_glu=True,
|
67 |
+
rotary_pos_emb=True,
|
68 |
+
))
|
69 |
+
self.speech_transformer = CheckpointedXTransformerEncoder(
|
70 |
+
needs_permute=False,
|
71 |
+
exit_permute=False,
|
72 |
+
max_seq_len=-1,
|
73 |
+
attn_layers=Encoder(
|
74 |
+
dim=dim_speech,
|
75 |
+
depth=speech_enc_depth,
|
76 |
+
heads=speech_heads,
|
77 |
+
ff_dropout=.1,
|
78 |
+
ff_mult=2,
|
79 |
+
attn_dropout=.1,
|
80 |
+
use_rmsnorm=True,
|
81 |
+
ff_glu=True,
|
82 |
+
rotary_pos_emb=True,
|
83 |
+
))
|
84 |
+
else:
|
85 |
+
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
86 |
+
heads=text_heads)
|
87 |
+
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
88 |
+
depth=speech_enc_depth, heads=speech_heads)
|
89 |
+
|
90 |
+
self.temperature = nn.Parameter(torch.tensor(1.))
|
91 |
+
self.text_mask_percentage = text_mask_percentage
|
92 |
+
self.voice_mask_percentage = voice_mask_percentage
|
93 |
+
self.wav_token_compression = wav_token_compression
|
94 |
+
self.xformers = use_xformers
|
95 |
+
if not use_xformers:
|
96 |
+
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
97 |
+
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
98 |
+
|
99 |
+
def forward(
|
100 |
+
self,
|
101 |
+
text,
|
102 |
+
speech_tokens,
|
103 |
+
return_loss=False
|
104 |
+
):
|
105 |
+
b, device = text.shape[0], text.device
|
106 |
+
if self.training:
|
107 |
+
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
108 |
+
voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
|
109 |
+
else:
|
110 |
+
text_mask = torch.ones_like(text.float()).bool()
|
111 |
+
voice_mask = torch.ones_like(speech_tokens.float()).bool()
|
112 |
+
|
113 |
+
text_emb = self.text_emb(text)
|
114 |
+
speech_emb = self.speech_emb(speech_tokens)
|
115 |
+
|
116 |
+
if not self.xformers:
|
117 |
+
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
118 |
+
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
119 |
+
|
120 |
+
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
121 |
+
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
122 |
+
|
123 |
+
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
124 |
+
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
125 |
+
|
126 |
+
text_latents = self.to_text_latent(text_latents)
|
127 |
+
speech_latents = self.to_speech_latent(speech_latents)
|
128 |
+
|
129 |
+
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
130 |
+
|
131 |
+
temp = self.temperature.exp()
|
132 |
+
|
133 |
+
if not return_loss:
|
134 |
+
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
|
135 |
+
return sim
|
136 |
+
|
137 |
+
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
138 |
+
labels = torch.arange(b, device=device)
|
139 |
+
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
140 |
+
return loss
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == '__main__':
|
144 |
+
clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
|
145 |
+
clip(torch.randint(0,256,(2,120)),
|
146 |
+
torch.tensor([50,100]),
|
147 |
+
torch.randint(0,8192,(2,250)),
|
148 |
+
torch.tensor([101,102]),
|
149 |
+
return_loss=True)
|
150 |
+
nonloss = clip(torch.randint(0,256,(2,120)),
|
151 |
+
torch.tensor([50,100]),
|
152 |
+
torch.randint(0,8192,(2,250)),
|
153 |
+
torch.tensor([101,102]),
|
154 |
+
return_loss=False)
|
155 |
+
print(nonloss.shape)
|
ruth_tts_transformer/models/diffusion_decoder.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from abc import abstractmethod
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import autocast
|
9 |
+
|
10 |
+
from ruth_tts_transformer.models.arch_util import normalization, AttentionBlock
|
11 |
+
|
12 |
+
|
13 |
+
def is_latent(t):
|
14 |
+
return t.dtype == torch.float
|
15 |
+
|
16 |
+
|
17 |
+
def is_sequence(t):
|
18 |
+
return t.dtype == torch.long
|
19 |
+
|
20 |
+
|
21 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
22 |
+
"""
|
23 |
+
Create sinusoidal timestep embeddings.
|
24 |
+
|
25 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
26 |
+
These may be fractional.
|
27 |
+
:param dim: the dimension of the output.
|
28 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
29 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
30 |
+
"""
|
31 |
+
half = dim // 2
|
32 |
+
freqs = torch.exp(
|
33 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
34 |
+
).to(device=timesteps.device)
|
35 |
+
args = timesteps[:, None].float() * freqs[None]
|
36 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
37 |
+
if dim % 2:
|
38 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
39 |
+
return embedding
|
40 |
+
|
41 |
+
|
42 |
+
class TimestepBlock(nn.Module):
|
43 |
+
@abstractmethod
|
44 |
+
def forward(self, x, emb):
|
45 |
+
"""
|
46 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
47 |
+
"""
|
48 |
+
|
49 |
+
|
50 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
51 |
+
def forward(self, x, emb):
|
52 |
+
for layer in self:
|
53 |
+
if isinstance(layer, TimestepBlock):
|
54 |
+
x = layer(x, emb)
|
55 |
+
else:
|
56 |
+
x = layer(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class ResBlock(TimestepBlock):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
channels,
|
64 |
+
emb_channels,
|
65 |
+
dropout,
|
66 |
+
out_channels=None,
|
67 |
+
dims=2,
|
68 |
+
kernel_size=3,
|
69 |
+
efficient_config=True,
|
70 |
+
use_scale_shift_norm=False,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
self.channels = channels
|
74 |
+
self.emb_channels = emb_channels
|
75 |
+
self.dropout = dropout
|
76 |
+
self.out_channels = out_channels or channels
|
77 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
78 |
+
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
|
79 |
+
eff_kernel = 1 if efficient_config else 3
|
80 |
+
eff_padding = 0 if efficient_config else 1
|
81 |
+
|
82 |
+
self.in_layers = nn.Sequential(
|
83 |
+
normalization(channels),
|
84 |
+
nn.SiLU(),
|
85 |
+
nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
|
86 |
+
)
|
87 |
+
|
88 |
+
self.emb_layers = nn.Sequential(
|
89 |
+
nn.SiLU(),
|
90 |
+
nn.Linear(
|
91 |
+
emb_channels,
|
92 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
93 |
+
),
|
94 |
+
)
|
95 |
+
self.out_layers = nn.Sequential(
|
96 |
+
normalization(self.out_channels),
|
97 |
+
nn.SiLU(),
|
98 |
+
nn.Dropout(p=dropout),
|
99 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
|
100 |
+
)
|
101 |
+
|
102 |
+
if self.out_channels == channels:
|
103 |
+
self.skip_connection = nn.Identity()
|
104 |
+
else:
|
105 |
+
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
106 |
+
|
107 |
+
def forward(self, x, emb):
|
108 |
+
h = self.in_layers(x)
|
109 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
110 |
+
while len(emb_out.shape) < len(h.shape):
|
111 |
+
emb_out = emb_out[..., None]
|
112 |
+
if self.use_scale_shift_norm:
|
113 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
114 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
115 |
+
h = out_norm(h) * (1 + scale) + shift
|
116 |
+
h = out_rest(h)
|
117 |
+
else:
|
118 |
+
h = h + emb_out
|
119 |
+
h = self.out_layers(h)
|
120 |
+
return self.skip_connection(x) + h
|
121 |
+
|
122 |
+
|
123 |
+
class DiffusionLayer(TimestepBlock):
|
124 |
+
def __init__(self, model_channels, dropout, num_heads):
|
125 |
+
super().__init__()
|
126 |
+
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1,
|
127 |
+
use_scale_shift_norm=True)
|
128 |
+
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
129 |
+
|
130 |
+
def forward(self, x, time_emb):
|
131 |
+
y = self.resblk(x, time_emb)
|
132 |
+
return self.attn(y)
|
133 |
+
|
134 |
+
|
135 |
+
class DiffusionTts(nn.Module):
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
model_channels=512,
|
139 |
+
num_layers=8,
|
140 |
+
in_channels=100,
|
141 |
+
in_latent_channels=512,
|
142 |
+
in_tokens=8193,
|
143 |
+
out_channels=200, # mean and variance
|
144 |
+
dropout=0,
|
145 |
+
use_fp16=False,
|
146 |
+
num_heads=16,
|
147 |
+
# Parameters for regularization.
|
148 |
+
layer_drop=.1,
|
149 |
+
unconditioned_percentage=.1,
|
150 |
+
# This implements a mechanism similar to what is used in classifier-free training.
|
151 |
+
):
|
152 |
+
super().__init__()
|
153 |
+
|
154 |
+
self.in_channels = in_channels
|
155 |
+
self.model_channels = model_channels
|
156 |
+
self.out_channels = out_channels
|
157 |
+
self.dropout = dropout
|
158 |
+
self.num_heads = num_heads
|
159 |
+
self.unconditioned_percentage = unconditioned_percentage
|
160 |
+
self.enable_fp16 = use_fp16
|
161 |
+
self.layer_drop = layer_drop
|
162 |
+
|
163 |
+
self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
|
164 |
+
self.time_embed = nn.Sequential(
|
165 |
+
nn.Linear(model_channels, model_channels),
|
166 |
+
nn.SiLU(),
|
167 |
+
nn.Linear(model_channels, model_channels),
|
168 |
+
)
|
169 |
+
|
170 |
+
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
171 |
+
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
172 |
+
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
173 |
+
# transformer network.
|
174 |
+
self.code_embedding = nn.Embedding(in_tokens, model_channels)
|
175 |
+
self.code_converter = nn.Sequential(
|
176 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
177 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
178 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
179 |
+
)
|
180 |
+
self.code_norm = normalization(model_channels)
|
181 |
+
self.latent_conditioner = nn.Sequential(
|
182 |
+
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
183 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
184 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
185 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
186 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
187 |
+
)
|
188 |
+
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
|
189 |
+
nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
|
190 |
+
AttentionBlock(model_channels * 2, num_heads,
|
191 |
+
relative_pos_embeddings=True, do_checkpoint=False),
|
192 |
+
AttentionBlock(model_channels * 2, num_heads,
|
193 |
+
relative_pos_embeddings=True, do_checkpoint=False),
|
194 |
+
AttentionBlock(model_channels * 2, num_heads,
|
195 |
+
relative_pos_embeddings=True, do_checkpoint=False),
|
196 |
+
AttentionBlock(model_channels * 2, num_heads,
|
197 |
+
relative_pos_embeddings=True, do_checkpoint=False),
|
198 |
+
AttentionBlock(model_channels * 2, num_heads,
|
199 |
+
relative_pos_embeddings=True, do_checkpoint=False))
|
200 |
+
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
|
201 |
+
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
202 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
203 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
204 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
205 |
+
)
|
206 |
+
|
207 |
+
self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
|
208 |
+
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
209 |
+
|
210 |
+
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
211 |
+
[ResBlock(model_channels, model_channels, dropout, dims=1,
|
212 |
+
use_scale_shift_norm=True) for _ in range(3)])
|
213 |
+
|
214 |
+
self.out = nn.Sequential(
|
215 |
+
normalization(model_channels),
|
216 |
+
nn.SiLU(),
|
217 |
+
nn.Conv1d(model_channels, out_channels, 3, padding=1),
|
218 |
+
)
|
219 |
+
|
220 |
+
def get_grad_norm_parameter_groups(self):
|
221 |
+
groups = {
|
222 |
+
'minicoder': list(self.contextual_embedder.parameters()),
|
223 |
+
'layers': list(self.layers.parameters()),
|
224 |
+
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(
|
225 |
+
self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
|
226 |
+
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(
|
227 |
+
self.integrating_conv.parameters()),
|
228 |
+
'time_embed': list(self.time_embed.parameters()),
|
229 |
+
}
|
230 |
+
return groups
|
231 |
+
|
232 |
+
def get_conditioning(self, conditioning_input):
|
233 |
+
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
234 |
+
conditioning_input.shape) == 3 else conditioning_input
|
235 |
+
conds = []
|
236 |
+
for j in range(speech_conditioning_input.shape[1]):
|
237 |
+
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
238 |
+
conds = torch.cat(conds, dim=-1)
|
239 |
+
conds = conds.mean(dim=-1)
|
240 |
+
return conds
|
241 |
+
|
242 |
+
def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
|
243 |
+
# Shuffle aligned_latent to BxCxS format
|
244 |
+
if is_latent(aligned_conditioning):
|
245 |
+
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
246 |
+
|
247 |
+
cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
|
248 |
+
if is_latent(aligned_conditioning):
|
249 |
+
code_emb = self.latent_conditioner(aligned_conditioning)
|
250 |
+
else:
|
251 |
+
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
252 |
+
code_emb = self.code_converter(code_emb)
|
253 |
+
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
254 |
+
|
255 |
+
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
256 |
+
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
257 |
+
if self.training and self.unconditioned_percentage > 0:
|
258 |
+
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
259 |
+
device=code_emb.device) < self.unconditioned_percentage
|
260 |
+
code_emb = torch.where(unconditioned_batches,
|
261 |
+
self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
262 |
+
code_emb)
|
263 |
+
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
264 |
+
|
265 |
+
if not return_code_pred:
|
266 |
+
return expanded_code_emb
|
267 |
+
else:
|
268 |
+
mel_pred = self.mel_head(expanded_code_emb)
|
269 |
+
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
270 |
+
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
271 |
+
return expanded_code_emb, mel_pred
|
272 |
+
|
273 |
+
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None,
|
274 |
+
precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
275 |
+
"""
|
276 |
+
Apply the model to an input batch.
|
277 |
+
|
278 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
279 |
+
:param timesteps: a 1-D batch of timesteps.
|
280 |
+
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
281 |
+
:param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
|
282 |
+
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
283 |
+
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
284 |
+
:return: an [N x C x ...] Tensor of outputs.
|
285 |
+
"""
|
286 |
+
assert precomputed_aligned_embeddings is not None or (
|
287 |
+
aligned_conditioning is not None and conditioning_latent is not None)
|
288 |
+
assert not (
|
289 |
+
return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
290 |
+
|
291 |
+
unused_params = []
|
292 |
+
if conditioning_free:
|
293 |
+
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
294 |
+
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
295 |
+
unused_params.extend(list(self.latent_conditioner.parameters()))
|
296 |
+
else:
|
297 |
+
if precomputed_aligned_embeddings is not None:
|
298 |
+
code_emb = precomputed_aligned_embeddings
|
299 |
+
else:
|
300 |
+
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1],
|
301 |
+
True)
|
302 |
+
if is_latent(aligned_conditioning):
|
303 |
+
unused_params.extend(
|
304 |
+
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
305 |
+
else:
|
306 |
+
unused_params.extend(list(self.latent_conditioner.parameters()))
|
307 |
+
|
308 |
+
unused_params.append(self.unconditioned_embedding)
|
309 |
+
|
310 |
+
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
311 |
+
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
312 |
+
x = self.inp_block(x)
|
313 |
+
x = torch.cat([x, code_emb], dim=1)
|
314 |
+
x = self.integrating_conv(x)
|
315 |
+
for i, lyr in enumerate(self.layers):
|
316 |
+
# Do layer drop where applicable. Do not drop first and last layers.
|
317 |
+
if self.training and self.layer_drop > 0 and i != 0 and i != (
|
318 |
+
len(self.layers) - 1) and random.random() < self.layer_drop:
|
319 |
+
unused_params.extend(list(lyr.parameters()))
|
320 |
+
else:
|
321 |
+
# First and last blocks will have autocast disabled for improved precision.
|
322 |
+
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
|
323 |
+
x = lyr(x, time_emb)
|
324 |
+
|
325 |
+
x = x.float()
|
326 |
+
out = self.out(x)
|
327 |
+
|
328 |
+
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
329 |
+
extraneous_addition = 0
|
330 |
+
for p in unused_params:
|
331 |
+
extraneous_addition = extraneous_addition + p.mean()
|
332 |
+
out = out + extraneous_addition * 0
|
333 |
+
|
334 |
+
if return_code_pred:
|
335 |
+
return out, mel_pred
|
336 |
+
return out
|
337 |
+
|
338 |
+
|
339 |
+
if __name__ == '__main__':
|
340 |
+
clip = torch.randn(2, 100, 400)
|
341 |
+
aligned_latent = torch.randn(2, 388, 512)
|
342 |
+
aligned_sequence = torch.randint(0, 8192, (2, 100))
|
343 |
+
cond = torch.randn(2, 100, 400)
|
344 |
+
ts = torch.LongTensor([600, 600])
|
345 |
+
model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
|
346 |
+
# Test with latent aligned conditioning
|
347 |
+
# o = model(clip, ts, aligned_latent, cond)
|
348 |
+
# Test with sequence aligned conditioning
|
349 |
+
o = model(clip, ts, aligned_sequence, cond)
|
ruth_tts_transformer/models/transformer.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from einops import rearrange
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
# helpers
|
8 |
+
|
9 |
+
|
10 |
+
def exists(val):
|
11 |
+
return val is not None
|
12 |
+
|
13 |
+
|
14 |
+
def default(val, d):
|
15 |
+
return val if exists(val) else d
|
16 |
+
|
17 |
+
|
18 |
+
def cast_tuple(val, depth=1):
|
19 |
+
if isinstance(val, list):
|
20 |
+
val = tuple(val)
|
21 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
22 |
+
|
23 |
+
|
24 |
+
def max_neg_value(t):
|
25 |
+
return -torch.finfo(t.dtype).max
|
26 |
+
|
27 |
+
|
28 |
+
def stable_softmax(t, dim=-1, alpha=32 ** 2):
|
29 |
+
t = t / alpha
|
30 |
+
t = t - torch.amax(t, dim=dim, keepdim=True).detach()
|
31 |
+
return (t * alpha).softmax(dim=dim)
|
32 |
+
|
33 |
+
|
34 |
+
def route_args(router, args, depth):
|
35 |
+
routed_args = [(dict(), dict()) for _ in range(depth)]
|
36 |
+
matched_keys = [key for key in args.keys() if key in router]
|
37 |
+
|
38 |
+
for key in matched_keys:
|
39 |
+
val = args[key]
|
40 |
+
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
41 |
+
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
|
42 |
+
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
43 |
+
return routed_args
|
44 |
+
|
45 |
+
|
46 |
+
# classes
|
47 |
+
class SequentialSequence(nn.Module):
|
48 |
+
def __init__(self, layers, args_route={}, layer_dropout=0.):
|
49 |
+
super().__init__()
|
50 |
+
assert all(len(route) == len(layers) for route in
|
51 |
+
args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
|
52 |
+
self.layers = layers
|
53 |
+
self.args_route = args_route
|
54 |
+
self.layer_dropout = layer_dropout
|
55 |
+
|
56 |
+
def forward(self, x, **kwargs):
|
57 |
+
args = route_args(self.args_route, kwargs, len(self.layers))
|
58 |
+
layers_and_args = list(zip(self.layers, args))
|
59 |
+
|
60 |
+
for (f, g), (f_args, g_args) in layers_and_args:
|
61 |
+
x = x + f(x, **f_args)
|
62 |
+
x = x + g(x, **g_args)
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class DivideMax(nn.Module):
|
67 |
+
def __init__(self, dim):
|
68 |
+
super().__init__()
|
69 |
+
self.dim = dim
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
maxes = x.amax(dim=self.dim, keepdim=True).detach()
|
73 |
+
return x / maxes
|
74 |
+
|
75 |
+
|
76 |
+
# https://arxiv.org/abs/2103.17239
|
77 |
+
class LayerScale(nn.Module):
|
78 |
+
def __init__(self, dim, depth, fn):
|
79 |
+
super().__init__()
|
80 |
+
if depth <= 18:
|
81 |
+
init_eps = 0.1
|
82 |
+
elif depth > 18 and depth <= 24:
|
83 |
+
init_eps = 1e-5
|
84 |
+
else:
|
85 |
+
init_eps = 1e-6
|
86 |
+
|
87 |
+
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
88 |
+
self.scale = nn.Parameter(scale)
|
89 |
+
self.fn = fn
|
90 |
+
|
91 |
+
def forward(self, x, **kwargs):
|
92 |
+
return self.fn(x, **kwargs) * self.scale
|
93 |
+
|
94 |
+
|
95 |
+
# layer norm
|
96 |
+
|
97 |
+
|
98 |
+
class PreNorm(nn.Module):
|
99 |
+
def __init__(self, dim, fn, sandwich=False):
|
100 |
+
super().__init__()
|
101 |
+
self.norm = nn.LayerNorm(dim)
|
102 |
+
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
|
103 |
+
self.fn = fn
|
104 |
+
|
105 |
+
def forward(self, x, **kwargs):
|
106 |
+
x = self.norm(x)
|
107 |
+
x = self.fn(x, **kwargs)
|
108 |
+
return self.norm_out(x)
|
109 |
+
|
110 |
+
|
111 |
+
# feed forward
|
112 |
+
|
113 |
+
|
114 |
+
class GEGLU(nn.Module):
|
115 |
+
def forward(self, x):
|
116 |
+
x, gates = x.chunk(2, dim=-1)
|
117 |
+
return x * F.gelu(gates)
|
118 |
+
|
119 |
+
|
120 |
+
class FeedForward(nn.Module):
|
121 |
+
def __init__(self, dim, dropout=0., mult=4.):
|
122 |
+
super().__init__()
|
123 |
+
self.net = nn.Sequential(
|
124 |
+
nn.Linear(dim, dim * mult * 2),
|
125 |
+
GEGLU(),
|
126 |
+
nn.Dropout(dropout),
|
127 |
+
nn.Linear(dim * mult, dim)
|
128 |
+
)
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
return self.net(x)
|
132 |
+
|
133 |
+
|
134 |
+
# Attention
|
135 |
+
|
136 |
+
|
137 |
+
class Attention(nn.Module):
|
138 |
+
def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.):
|
139 |
+
super().__init__()
|
140 |
+
inner_dim = dim_head * heads
|
141 |
+
self.heads = heads
|
142 |
+
self.seq_len = seq_len
|
143 |
+
self.scale = dim_head ** -0.5
|
144 |
+
|
145 |
+
self.causal = causal
|
146 |
+
|
147 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
148 |
+
self.to_out = nn.Sequential(
|
149 |
+
nn.Linear(inner_dim, dim),
|
150 |
+
nn.Dropout(dropout)
|
151 |
+
)
|
152 |
+
|
153 |
+
def forward(self, x, mask=None):
|
154 |
+
b, n, _, h, device = *x.shape, self.heads, x.device
|
155 |
+
softmax = torch.softmax
|
156 |
+
|
157 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
158 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
|
159 |
+
|
160 |
+
q = q * self.scale
|
161 |
+
|
162 |
+
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
163 |
+
mask_value = max_neg_value(dots)
|
164 |
+
|
165 |
+
if exists(mask):
|
166 |
+
mask = rearrange(mask, 'b j -> b () () j')
|
167 |
+
dots.masked_fill_(~mask, mask_value)
|
168 |
+
del mask
|
169 |
+
|
170 |
+
if self.causal:
|
171 |
+
i, j = dots.shape[-2:]
|
172 |
+
mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
|
173 |
+
dots.masked_fill_(mask, mask_value)
|
174 |
+
|
175 |
+
attn = softmax(dots, dim=-1)
|
176 |
+
|
177 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
178 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
179 |
+
out = self.to_out(out)
|
180 |
+
return out
|
181 |
+
|
182 |
+
|
183 |
+
# main transformer class
|
184 |
+
class Transformer(nn.Module):
|
185 |
+
def __init__(
|
186 |
+
self,
|
187 |
+
*,
|
188 |
+
dim,
|
189 |
+
depth,
|
190 |
+
seq_len,
|
191 |
+
causal=True,
|
192 |
+
heads=8,
|
193 |
+
dim_head=64,
|
194 |
+
ff_mult=4,
|
195 |
+
attn_dropout=0.,
|
196 |
+
ff_dropout=0.,
|
197 |
+
sparse_attn=False,
|
198 |
+
sandwich_norm=False,
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
layers = nn.ModuleList([])
|
202 |
+
sparse_layer = cast_tuple(sparse_attn, depth)
|
203 |
+
|
204 |
+
for ind, sparse_attn in zip(range(depth), sparse_layer):
|
205 |
+
attn = Attention(dim, causal=causal, seq_len=seq_len, heads=heads, dim_head=dim_head, dropout=attn_dropout)
|
206 |
+
|
207 |
+
ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
|
208 |
+
|
209 |
+
layers.append(nn.ModuleList([
|
210 |
+
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
|
211 |
+
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm))
|
212 |
+
]))
|
213 |
+
|
214 |
+
execute_type = SequentialSequence
|
215 |
+
route_attn = ((True, False),) * depth
|
216 |
+
attn_route_map = {'mask': route_attn}
|
217 |
+
|
218 |
+
self.layers = execute_type(layers, args_route=attn_route_map)
|
219 |
+
|
220 |
+
def forward(self, x, **kwargs):
|
221 |
+
return self.layers(x, **kwargs)
|
ruth_tts_transformer/models/vocoder.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
MAX_WAV_VALUE = 32768.0
|
6 |
+
|
7 |
+
class KernelPredictor(torch.nn.Module):
|
8 |
+
''' Kernel predictor for the location-variable convolutions'''
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
cond_channels,
|
13 |
+
conv_in_channels,
|
14 |
+
conv_out_channels,
|
15 |
+
conv_layers,
|
16 |
+
conv_kernel_size=3,
|
17 |
+
kpnet_hidden_channels=64,
|
18 |
+
kpnet_conv_size=3,
|
19 |
+
kpnet_dropout=0.0,
|
20 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
21 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
22 |
+
):
|
23 |
+
'''
|
24 |
+
Args:
|
25 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
26 |
+
conv_in_channels (int): number of channel for the input sequence,
|
27 |
+
conv_out_channels (int): number of channel for the output sequence,
|
28 |
+
conv_layers (int): number of layers
|
29 |
+
'''
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.conv_in_channels = conv_in_channels
|
33 |
+
self.conv_out_channels = conv_out_channels
|
34 |
+
self.conv_kernel_size = conv_kernel_size
|
35 |
+
self.conv_layers = conv_layers
|
36 |
+
|
37 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
38 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
39 |
+
|
40 |
+
self.input_conv = nn.Sequential(
|
41 |
+
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
42 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
43 |
+
)
|
44 |
+
|
45 |
+
self.residual_convs = nn.ModuleList()
|
46 |
+
padding = (kpnet_conv_size - 1) // 2
|
47 |
+
for _ in range(3):
|
48 |
+
self.residual_convs.append(
|
49 |
+
nn.Sequential(
|
50 |
+
nn.Dropout(kpnet_dropout),
|
51 |
+
nn.utils.weight_norm(
|
52 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
53 |
+
bias=True)),
|
54 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
55 |
+
nn.utils.weight_norm(
|
56 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
57 |
+
bias=True)),
|
58 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.kernel_conv = nn.utils.weight_norm(
|
62 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
|
63 |
+
self.bias_conv = nn.utils.weight_norm(
|
64 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
|
65 |
+
|
66 |
+
def forward(self, c):
|
67 |
+
'''
|
68 |
+
Args:
|
69 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
70 |
+
'''
|
71 |
+
batch, _, cond_length = c.shape
|
72 |
+
c = self.input_conv(c)
|
73 |
+
for residual_conv in self.residual_convs:
|
74 |
+
residual_conv.to(c.device)
|
75 |
+
c = c + residual_conv(c)
|
76 |
+
k = self.kernel_conv(c)
|
77 |
+
b = self.bias_conv(c)
|
78 |
+
kernels = k.contiguous().view(
|
79 |
+
batch,
|
80 |
+
self.conv_layers,
|
81 |
+
self.conv_in_channels,
|
82 |
+
self.conv_out_channels,
|
83 |
+
self.conv_kernel_size,
|
84 |
+
cond_length,
|
85 |
+
)
|
86 |
+
bias = b.contiguous().view(
|
87 |
+
batch,
|
88 |
+
self.conv_layers,
|
89 |
+
self.conv_out_channels,
|
90 |
+
cond_length,
|
91 |
+
)
|
92 |
+
|
93 |
+
return kernels, bias
|
94 |
+
|
95 |
+
def remove_weight_norm(self):
|
96 |
+
nn.utils.remove_weight_norm(self.input_conv[0])
|
97 |
+
nn.utils.remove_weight_norm(self.kernel_conv)
|
98 |
+
nn.utils.remove_weight_norm(self.bias_conv)
|
99 |
+
for block in self.residual_convs:
|
100 |
+
nn.utils.remove_weight_norm(block[1])
|
101 |
+
nn.utils.remove_weight_norm(block[3])
|
102 |
+
|
103 |
+
|
104 |
+
class LVCBlock(torch.nn.Module):
|
105 |
+
'''the location-variable convolutions'''
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
in_channels,
|
110 |
+
cond_channels,
|
111 |
+
stride,
|
112 |
+
dilations=[1, 3, 9, 27],
|
113 |
+
lReLU_slope=0.2,
|
114 |
+
conv_kernel_size=3,
|
115 |
+
cond_hop_length=256,
|
116 |
+
kpnet_hidden_channels=64,
|
117 |
+
kpnet_conv_size=3,
|
118 |
+
kpnet_dropout=0.0,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.cond_hop_length = cond_hop_length
|
123 |
+
self.conv_layers = len(dilations)
|
124 |
+
self.conv_kernel_size = conv_kernel_size
|
125 |
+
|
126 |
+
self.kernel_predictor = KernelPredictor(
|
127 |
+
cond_channels=cond_channels,
|
128 |
+
conv_in_channels=in_channels,
|
129 |
+
conv_out_channels=2 * in_channels,
|
130 |
+
conv_layers=len(dilations),
|
131 |
+
conv_kernel_size=conv_kernel_size,
|
132 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
133 |
+
kpnet_conv_size=kpnet_conv_size,
|
134 |
+
kpnet_dropout=kpnet_dropout,
|
135 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
|
136 |
+
)
|
137 |
+
|
138 |
+
self.convt_pre = nn.Sequential(
|
139 |
+
nn.LeakyReLU(lReLU_slope),
|
140 |
+
nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
|
141 |
+
padding=stride // 2 + stride % 2, output_padding=stride % 2)),
|
142 |
+
)
|
143 |
+
|
144 |
+
self.conv_blocks = nn.ModuleList()
|
145 |
+
for dilation in dilations:
|
146 |
+
self.conv_blocks.append(
|
147 |
+
nn.Sequential(
|
148 |
+
nn.LeakyReLU(lReLU_slope),
|
149 |
+
nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
|
150 |
+
padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
|
151 |
+
nn.LeakyReLU(lReLU_slope),
|
152 |
+
)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x, c):
|
156 |
+
''' forward propagation of the location-variable convolutions.
|
157 |
+
Args:
|
158 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
159 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
163 |
+
'''
|
164 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
165 |
+
|
166 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
167 |
+
kernels, bias = self.kernel_predictor(c)
|
168 |
+
|
169 |
+
for i, conv in enumerate(self.conv_blocks):
|
170 |
+
output = conv(x) # (B, c_g, stride * L')
|
171 |
+
|
172 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
173 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
174 |
+
|
175 |
+
output = self.location_variable_convolution(output, k, b,
|
176 |
+
hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
|
177 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
178 |
+
output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
183 |
+
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
184 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
185 |
+
Args:
|
186 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
187 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
188 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
189 |
+
dilation (int): the dilation of convolution.
|
190 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
191 |
+
Returns:
|
192 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
193 |
+
'''
|
194 |
+
batch, _, in_length = x.shape
|
195 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
196 |
+
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
197 |
+
|
198 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
199 |
+
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
200 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
201 |
+
|
202 |
+
if hop_size < dilation:
|
203 |
+
x = F.pad(x, (0, dilation), 'constant', 0)
|
204 |
+
x = x.unfold(3, dilation,
|
205 |
+
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
206 |
+
x = x[:, :, :, :, :hop_size]
|
207 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
208 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
209 |
+
|
210 |
+
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
211 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
212 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
213 |
+
o = o + bias
|
214 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
215 |
+
|
216 |
+
return o
|
217 |
+
|
218 |
+
def remove_weight_norm(self):
|
219 |
+
self.kernel_predictor.remove_weight_norm()
|
220 |
+
nn.utils.remove_weight_norm(self.convt_pre[1])
|
221 |
+
for block in self.conv_blocks:
|
222 |
+
nn.utils.remove_weight_norm(block[1])
|
223 |
+
|
224 |
+
|
225 |
+
class UnivNetGenerator(nn.Module):
|
226 |
+
"""UnivNet Generator"""
|
227 |
+
|
228 |
+
def __init__(self, noise_dim=64, channel_size=16, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
|
229 |
+
# Below are MEL configurations options that this generator requires.
|
230 |
+
hop_length=256, n_mel_channels=100):
|
231 |
+
super(UnivNetGenerator, self).__init__()
|
232 |
+
self.mel_channel = n_mel_channels
|
233 |
+
self.noise_dim = noise_dim
|
234 |
+
self.hop_length = hop_length
|
235 |
+
channel_size = channel_size
|
236 |
+
kpnet_conv_size = kpnet_conv_size
|
237 |
+
|
238 |
+
self.res_stack = nn.ModuleList()
|
239 |
+
hop_length = 1
|
240 |
+
for stride in strides:
|
241 |
+
hop_length = stride * hop_length
|
242 |
+
self.res_stack.append(
|
243 |
+
LVCBlock(
|
244 |
+
channel_size,
|
245 |
+
n_mel_channels,
|
246 |
+
stride=stride,
|
247 |
+
dilations=dilations,
|
248 |
+
lReLU_slope=lReLU_slope,
|
249 |
+
cond_hop_length=hop_length,
|
250 |
+
kpnet_conv_size=kpnet_conv_size
|
251 |
+
)
|
252 |
+
)
|
253 |
+
|
254 |
+
self.conv_pre = \
|
255 |
+
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
256 |
+
|
257 |
+
self.conv_post = nn.Sequential(
|
258 |
+
nn.LeakyReLU(lReLU_slope),
|
259 |
+
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
260 |
+
nn.Tanh(),
|
261 |
+
)
|
262 |
+
|
263 |
+
def forward(self, c, z):
|
264 |
+
'''
|
265 |
+
Args:
|
266 |
+
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
|
267 |
+
z (Tensor): the noise sequence (batch, noise_dim, in_length)
|
268 |
+
|
269 |
+
'''
|
270 |
+
z = self.conv_pre(z) # (B, c_g, L)
|
271 |
+
|
272 |
+
for res_block in self.res_stack:
|
273 |
+
res_block.to(z.device)
|
274 |
+
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
|
275 |
+
|
276 |
+
z = self.conv_post(z) # (B, 1, L * 256)
|
277 |
+
|
278 |
+
return z
|
279 |
+
|
280 |
+
def eval(self, inference=False):
|
281 |
+
super(UnivNetGenerator, self).eval()
|
282 |
+
# don't remove weight norm while validation in training loop
|
283 |
+
if inference:
|
284 |
+
self.remove_weight_norm()
|
285 |
+
|
286 |
+
def remove_weight_norm(self):
|
287 |
+
nn.utils.remove_weight_norm(self.conv_pre)
|
288 |
+
|
289 |
+
for layer in self.conv_post:
|
290 |
+
if len(layer.state_dict()) != 0:
|
291 |
+
nn.utils.remove_weight_norm(layer)
|
292 |
+
|
293 |
+
for res_block in self.res_stack:
|
294 |
+
res_block.remove_weight_norm()
|
295 |
+
|
296 |
+
def inference(self, c, z=None):
|
297 |
+
# pad input mel with zeros to cut artifact
|
298 |
+
# see https://github.com/seungwonpark/melgan/issues/8
|
299 |
+
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
300 |
+
mel = torch.cat((c, zero), dim=2)
|
301 |
+
|
302 |
+
if z is None:
|
303 |
+
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
|
304 |
+
|
305 |
+
audio = self.forward(mel, z)
|
306 |
+
audio = audio[:, :, :-(self.hop_length * 10)]
|
307 |
+
audio = audio.clamp(min=-1, max=1)
|
308 |
+
return audio
|
309 |
+
|
310 |
+
|
311 |
+
if __name__ == '__main__':
|
312 |
+
model = UnivNetGenerator()
|
313 |
+
|
314 |
+
c = torch.randn(3, 100, 10)
|
315 |
+
z = torch.randn(3, 64, 10)
|
316 |
+
print(c.shape)
|
317 |
+
|
318 |
+
y = model(c, z)
|
319 |
+
print(y.shape)
|
320 |
+
assert y.shape == torch.Size([3, 1, 2560])
|
321 |
+
|
322 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
323 |
+
print(pytorch_total_params)
|
ruth_tts_transformer/models/xtransformers.py
ADDED
@@ -0,0 +1,1248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections import namedtuple
|
3 |
+
from functools import partial
|
4 |
+
from inspect import isfunction
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from torch import nn, einsum
|
10 |
+
|
11 |
+
DEFAULT_DIM_HEAD = 64
|
12 |
+
|
13 |
+
Intermediates = namedtuple('Intermediates', [
|
14 |
+
'pre_softmax_attn',
|
15 |
+
'post_softmax_attn'
|
16 |
+
])
|
17 |
+
|
18 |
+
LayerIntermediates = namedtuple('Intermediates', [
|
19 |
+
'hiddens',
|
20 |
+
'attn_intermediates',
|
21 |
+
'past_key_values',
|
22 |
+
])
|
23 |
+
|
24 |
+
|
25 |
+
# helpers
|
26 |
+
|
27 |
+
def exists(val):
|
28 |
+
return val is not None
|
29 |
+
|
30 |
+
|
31 |
+
def default(val, d):
|
32 |
+
if exists(val):
|
33 |
+
return val
|
34 |
+
return d() if isfunction(d) else d
|
35 |
+
|
36 |
+
|
37 |
+
def cast_tuple(val, depth):
|
38 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
39 |
+
|
40 |
+
|
41 |
+
class always():
|
42 |
+
def __init__(self, val):
|
43 |
+
self.val = val
|
44 |
+
|
45 |
+
def __call__(self, *args, **kwargs):
|
46 |
+
return self.val
|
47 |
+
|
48 |
+
|
49 |
+
class not_equals():
|
50 |
+
def __init__(self, val):
|
51 |
+
self.val = val
|
52 |
+
|
53 |
+
def __call__(self, x, *args, **kwargs):
|
54 |
+
return x != self.val
|
55 |
+
|
56 |
+
|
57 |
+
class equals():
|
58 |
+
def __init__(self, val):
|
59 |
+
self.val = val
|
60 |
+
|
61 |
+
def __call__(self, x, *args, **kwargs):
|
62 |
+
return x == self.val
|
63 |
+
|
64 |
+
|
65 |
+
def max_neg_value(tensor):
|
66 |
+
return -torch.finfo(tensor.dtype).max
|
67 |
+
|
68 |
+
|
69 |
+
def l2norm(t):
|
70 |
+
return F.normalize(t, p=2, dim=-1)
|
71 |
+
|
72 |
+
|
73 |
+
# init helpers
|
74 |
+
|
75 |
+
def init_zero_(layer):
|
76 |
+
nn.init.constant_(layer.weight, 0.)
|
77 |
+
if exists(layer.bias):
|
78 |
+
nn.init.constant_(layer.bias, 0.)
|
79 |
+
|
80 |
+
|
81 |
+
# keyword argument helpers
|
82 |
+
|
83 |
+
def pick_and_pop(keys, d):
|
84 |
+
values = list(map(lambda key: d.pop(key), keys))
|
85 |
+
return dict(zip(keys, values))
|
86 |
+
|
87 |
+
|
88 |
+
def group_dict_by_key(cond, d):
|
89 |
+
return_val = [dict(), dict()]
|
90 |
+
for key in d.keys():
|
91 |
+
match = bool(cond(key))
|
92 |
+
ind = int(not match)
|
93 |
+
return_val[ind][key] = d[key]
|
94 |
+
return (*return_val,)
|
95 |
+
|
96 |
+
|
97 |
+
def string_begins_with(prefix, str):
|
98 |
+
return str.startswith(prefix)
|
99 |
+
|
100 |
+
|
101 |
+
def group_by_key_prefix(prefix, d):
|
102 |
+
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
103 |
+
|
104 |
+
|
105 |
+
def groupby_prefix_and_trim(prefix, d):
|
106 |
+
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
107 |
+
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
108 |
+
return kwargs_without_prefix, kwargs
|
109 |
+
|
110 |
+
|
111 |
+
# activations
|
112 |
+
|
113 |
+
class ReluSquared(nn.Module):
|
114 |
+
def forward(self, x):
|
115 |
+
return F.relu(x) ** 2
|
116 |
+
|
117 |
+
|
118 |
+
# positional embeddings
|
119 |
+
|
120 |
+
class AbsolutePositionalEmbedding(nn.Module):
|
121 |
+
def __init__(self, dim, max_seq_len):
|
122 |
+
super().__init__()
|
123 |
+
self.scale = dim ** -0.5
|
124 |
+
self.emb = nn.Embedding(max_seq_len, dim)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
n = torch.arange(x.shape[1], device=x.device)
|
128 |
+
pos_emb = self.emb(n)
|
129 |
+
pos_emb = rearrange(pos_emb, 'n d -> () n d')
|
130 |
+
return pos_emb * self.scale
|
131 |
+
|
132 |
+
|
133 |
+
class FixedPositionalEmbedding(nn.Module):
|
134 |
+
def __init__(self, dim):
|
135 |
+
super().__init__()
|
136 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
137 |
+
self.register_buffer('inv_freq', inv_freq)
|
138 |
+
|
139 |
+
def forward(self, x, seq_dim=1, offset=0):
|
140 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
141 |
+
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
142 |
+
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
143 |
+
return rearrange(emb, 'n d -> () n d')
|
144 |
+
|
145 |
+
|
146 |
+
class RelativePositionBias(nn.Module):
|
147 |
+
def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
|
148 |
+
super().__init__()
|
149 |
+
self.scale = scale
|
150 |
+
self.causal = causal
|
151 |
+
self.num_buckets = num_buckets
|
152 |
+
self.max_distance = max_distance
|
153 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
|
157 |
+
ret = 0
|
158 |
+
n = -relative_position
|
159 |
+
if not causal:
|
160 |
+
num_buckets //= 2
|
161 |
+
ret += (n < 0).long() * num_buckets
|
162 |
+
n = torch.abs(n)
|
163 |
+
else:
|
164 |
+
n = torch.max(n, torch.zeros_like(n))
|
165 |
+
|
166 |
+
max_exact = num_buckets // 2
|
167 |
+
is_small = n < max_exact
|
168 |
+
|
169 |
+
val_if_large = max_exact + (
|
170 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
171 |
+
).long()
|
172 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
173 |
+
|
174 |
+
ret += torch.where(is_small, n, val_if_large)
|
175 |
+
return ret
|
176 |
+
|
177 |
+
def forward(self, qk_dots):
|
178 |
+
i, j, device = *qk_dots.shape[-2:], qk_dots.device
|
179 |
+
q_pos = torch.arange(i, dtype=torch.long, device=device)
|
180 |
+
k_pos = torch.arange(j, dtype=torch.long, device=device)
|
181 |
+
rel_pos = k_pos[None, :] - q_pos[:, None]
|
182 |
+
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
|
183 |
+
max_distance=self.max_distance)
|
184 |
+
values = self.relative_attention_bias(rp_bucket)
|
185 |
+
bias = rearrange(values, 'i j h -> () h i j')
|
186 |
+
return qk_dots + (bias * self.scale)
|
187 |
+
|
188 |
+
|
189 |
+
class AlibiPositionalBias(nn.Module):
|
190 |
+
def __init__(self, heads, **kwargs):
|
191 |
+
super().__init__()
|
192 |
+
self.heads = heads
|
193 |
+
slopes = torch.Tensor(self._get_slopes(heads))
|
194 |
+
slopes = rearrange(slopes, 'h -> () h () ()')
|
195 |
+
self.register_buffer('slopes', slopes, persistent=False)
|
196 |
+
self.register_buffer('bias', None, persistent=False)
|
197 |
+
|
198 |
+
@staticmethod
|
199 |
+
def _get_slopes(heads):
|
200 |
+
def get_slopes_power_of_2(n):
|
201 |
+
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
202 |
+
ratio = start
|
203 |
+
return [start * ratio ** i for i in range(n)]
|
204 |
+
|
205 |
+
if math.log2(heads).is_integer():
|
206 |
+
return get_slopes_power_of_2(heads)
|
207 |
+
|
208 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
|
209 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
|
210 |
+
:heads - closest_power_of_2]
|
211 |
+
|
212 |
+
def forward(self, qk_dots):
|
213 |
+
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
|
214 |
+
|
215 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
216 |
+
return qk_dots + self.bias[..., :j]
|
217 |
+
|
218 |
+
bias = torch.arange(j, device=device)
|
219 |
+
bias = rearrange(bias, 'j -> () () () j')
|
220 |
+
bias = bias * self.slopes
|
221 |
+
|
222 |
+
num_heads_unalibied = h - bias.shape[1]
|
223 |
+
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
|
224 |
+
|
225 |
+
self.register_buffer('bias', bias, persistent=False)
|
226 |
+
return qk_dots + self.bias
|
227 |
+
|
228 |
+
|
229 |
+
class LearnedAlibiPositionalBias(AlibiPositionalBias):
|
230 |
+
def __init__(self, heads, bidirectional=False):
|
231 |
+
super().__init__(heads)
|
232 |
+
los_slopes = torch.log(self.slopes)
|
233 |
+
self.learned_logslopes = nn.Parameter(los_slopes)
|
234 |
+
|
235 |
+
self.bidirectional = bidirectional
|
236 |
+
if self.bidirectional:
|
237 |
+
self.learned_logslopes_future = nn.Parameter(los_slopes)
|
238 |
+
|
239 |
+
def forward(self, qk_dots):
|
240 |
+
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
|
241 |
+
|
242 |
+
def get_slopes(param):
|
243 |
+
return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
|
244 |
+
|
245 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
246 |
+
bias = self.bias[..., :i, :j]
|
247 |
+
else:
|
248 |
+
i_arange = torch.arange(i, device=device)
|
249 |
+
j_arange = torch.arange(j, device=device)
|
250 |
+
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
|
251 |
+
self.register_buffer('bias', bias, persistent=False)
|
252 |
+
|
253 |
+
if self.bidirectional:
|
254 |
+
past_slopes = get_slopes(self.learned_logslopes)
|
255 |
+
future_slopes = get_slopes(self.learned_logslopes_future)
|
256 |
+
bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
|
257 |
+
else:
|
258 |
+
slopes = get_slopes(self.learned_logslopes)
|
259 |
+
bias = bias * slopes
|
260 |
+
|
261 |
+
return qk_dots + bias
|
262 |
+
|
263 |
+
|
264 |
+
class RotaryEmbedding(nn.Module):
|
265 |
+
def __init__(self, dim):
|
266 |
+
super().__init__()
|
267 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
268 |
+
self.register_buffer('inv_freq', inv_freq)
|
269 |
+
|
270 |
+
def forward(self, max_seq_len, device):
|
271 |
+
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
|
272 |
+
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
273 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
274 |
+
return rearrange(emb, 'n d -> () () n d')
|
275 |
+
|
276 |
+
|
277 |
+
def rotate_half(x):
|
278 |
+
x = rearrange(x, '... (j d) -> ... j d', j=2)
|
279 |
+
x1, x2 = x.unbind(dim=-2)
|
280 |
+
return torch.cat((-x2, x1), dim=-1)
|
281 |
+
|
282 |
+
|
283 |
+
def apply_rotary_pos_emb(t, freqs):
|
284 |
+
seq_len = t.shape[-2]
|
285 |
+
freqs = freqs[:, :, -seq_len:]
|
286 |
+
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
|
287 |
+
|
288 |
+
|
289 |
+
# norms
|
290 |
+
|
291 |
+
class Scale(nn.Module):
|
292 |
+
def __init__(self, value, fn):
|
293 |
+
super().__init__()
|
294 |
+
self.value = value
|
295 |
+
self.fn = fn
|
296 |
+
|
297 |
+
def forward(self, x, **kwargs):
|
298 |
+
out = self.fn(x, **kwargs)
|
299 |
+
scale_fn = lambda t: t * self.value
|
300 |
+
|
301 |
+
if not isinstance(out, tuple):
|
302 |
+
return scale_fn(out)
|
303 |
+
|
304 |
+
return (scale_fn(out[0]), *out[1:])
|
305 |
+
|
306 |
+
|
307 |
+
class Rezero(nn.Module):
|
308 |
+
def __init__(self, fn):
|
309 |
+
super().__init__()
|
310 |
+
self.fn = fn
|
311 |
+
self.g = nn.Parameter(torch.zeros(1))
|
312 |
+
|
313 |
+
def forward(self, x, **kwargs):
|
314 |
+
out = self.fn(x, **kwargs)
|
315 |
+
rezero_fn = lambda t: t * self.g
|
316 |
+
|
317 |
+
if not isinstance(out, tuple):
|
318 |
+
return rezero_fn(out)
|
319 |
+
|
320 |
+
return (rezero_fn(out[0]), *out[1:])
|
321 |
+
|
322 |
+
|
323 |
+
class ScaleNorm(nn.Module):
|
324 |
+
def __init__(self, dim, eps=1e-5):
|
325 |
+
super().__init__()
|
326 |
+
self.scale = dim ** -0.5
|
327 |
+
self.eps = eps
|
328 |
+
self.g = nn.Parameter(torch.ones(1))
|
329 |
+
|
330 |
+
def forward(self, x):
|
331 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
332 |
+
return x / norm.clamp(min=self.eps) * self.g
|
333 |
+
|
334 |
+
|
335 |
+
class RMSNorm(nn.Module):
|
336 |
+
def __init__(self, dim, eps=1e-8):
|
337 |
+
super().__init__()
|
338 |
+
self.scale = dim ** -0.5
|
339 |
+
self.eps = eps
|
340 |
+
self.g = nn.Parameter(torch.ones(dim))
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
344 |
+
return x / norm.clamp(min=self.eps) * self.g
|
345 |
+
|
346 |
+
|
347 |
+
class RMSScaleShiftNorm(nn.Module):
|
348 |
+
def __init__(self, dim, eps=1e-8):
|
349 |
+
super().__init__()
|
350 |
+
self.scale = dim ** -0.5
|
351 |
+
self.eps = eps
|
352 |
+
self.g = nn.Parameter(torch.ones(dim))
|
353 |
+
self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
|
354 |
+
|
355 |
+
def forward(self, x, norm_scale_shift_inp):
|
356 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
357 |
+
norm = x / norm.clamp(min=self.eps) * self.g
|
358 |
+
|
359 |
+
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
|
360 |
+
scale, shift = torch.chunk(ss_emb, 2, dim=1)
|
361 |
+
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
362 |
+
return h
|
363 |
+
|
364 |
+
|
365 |
+
# residual and residual gates
|
366 |
+
|
367 |
+
class Residual(nn.Module):
|
368 |
+
def __init__(self, dim, scale_residual=False):
|
369 |
+
super().__init__()
|
370 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
371 |
+
|
372 |
+
def forward(self, x, residual):
|
373 |
+
if exists(self.residual_scale):
|
374 |
+
residual = residual * self.residual_scale
|
375 |
+
|
376 |
+
return x + residual
|
377 |
+
|
378 |
+
|
379 |
+
class GRUGating(nn.Module):
|
380 |
+
def __init__(self, dim, scale_residual=False):
|
381 |
+
super().__init__()
|
382 |
+
self.gru = nn.GRUCell(dim, dim)
|
383 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
384 |
+
|
385 |
+
def forward(self, x, residual):
|
386 |
+
if exists(self.residual_scale):
|
387 |
+
residual = residual * self.residual_scale
|
388 |
+
|
389 |
+
gated_output = self.gru(
|
390 |
+
rearrange(x, 'b n d -> (b n) d'),
|
391 |
+
rearrange(residual, 'b n d -> (b n) d')
|
392 |
+
)
|
393 |
+
|
394 |
+
return gated_output.reshape_as(x)
|
395 |
+
|
396 |
+
|
397 |
+
# token shifting
|
398 |
+
|
399 |
+
def shift(t, amount, mask=None):
|
400 |
+
if amount == 0:
|
401 |
+
return t
|
402 |
+
|
403 |
+
if exists(mask):
|
404 |
+
t = t.masked_fill(~mask[..., None], 0.)
|
405 |
+
|
406 |
+
return F.pad(t, (0, 0, amount, -amount), value=0.)
|
407 |
+
|
408 |
+
|
409 |
+
class ShiftTokens(nn.Module):
|
410 |
+
def __init__(self, shifts, fn):
|
411 |
+
super().__init__()
|
412 |
+
self.fn = fn
|
413 |
+
self.shifts = tuple(shifts)
|
414 |
+
|
415 |
+
def forward(self, x, **kwargs):
|
416 |
+
mask = kwargs.get('mask', None)
|
417 |
+
shifts = self.shifts
|
418 |
+
segments = len(shifts)
|
419 |
+
feats_per_shift = x.shape[-1] // segments
|
420 |
+
splitted = x.split(feats_per_shift, dim=-1)
|
421 |
+
segments_to_shift, rest = splitted[:segments], splitted[segments:]
|
422 |
+
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
|
423 |
+
x = torch.cat((*segments_to_shift, *rest), dim=-1)
|
424 |
+
return self.fn(x, **kwargs)
|
425 |
+
|
426 |
+
|
427 |
+
# feedforward
|
428 |
+
|
429 |
+
class GLU(nn.Module):
|
430 |
+
def __init__(self, dim_in, dim_out, activation):
|
431 |
+
super().__init__()
|
432 |
+
self.act = activation
|
433 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
434 |
+
|
435 |
+
def forward(self, x):
|
436 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
437 |
+
return x * self.act(gate)
|
438 |
+
|
439 |
+
|
440 |
+
class FeedForward(nn.Module):
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
dim,
|
444 |
+
dim_out=None,
|
445 |
+
mult=4,
|
446 |
+
glu=False,
|
447 |
+
relu_squared=False,
|
448 |
+
post_act_ln=False,
|
449 |
+
dropout=0.,
|
450 |
+
zero_init_output=False
|
451 |
+
):
|
452 |
+
super().__init__()
|
453 |
+
inner_dim = int(dim * mult)
|
454 |
+
dim_out = default(dim_out, dim)
|
455 |
+
activation = ReluSquared() if relu_squared else nn.GELU()
|
456 |
+
|
457 |
+
project_in = nn.Sequential(
|
458 |
+
nn.Linear(dim, inner_dim),
|
459 |
+
activation
|
460 |
+
) if not glu else GLU(dim, inner_dim, activation)
|
461 |
+
|
462 |
+
self.net = nn.Sequential(
|
463 |
+
project_in,
|
464 |
+
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
|
465 |
+
nn.Dropout(dropout),
|
466 |
+
nn.Linear(inner_dim, dim_out)
|
467 |
+
)
|
468 |
+
|
469 |
+
# init last linear layer to 0
|
470 |
+
if zero_init_output:
|
471 |
+
init_zero_(self.net[-1])
|
472 |
+
|
473 |
+
def forward(self, x):
|
474 |
+
return self.net(x)
|
475 |
+
|
476 |
+
|
477 |
+
# attention.
|
478 |
+
|
479 |
+
class Attention(nn.Module):
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dim,
|
483 |
+
dim_head=DEFAULT_DIM_HEAD,
|
484 |
+
heads=8,
|
485 |
+
causal=False,
|
486 |
+
talking_heads=False,
|
487 |
+
head_scale=False,
|
488 |
+
collab_heads=False,
|
489 |
+
collab_compression=.3,
|
490 |
+
sparse_topk=None,
|
491 |
+
use_entmax15=False,
|
492 |
+
num_mem_kv=0,
|
493 |
+
dropout=0.,
|
494 |
+
on_attn=False,
|
495 |
+
gate_values=False,
|
496 |
+
zero_init_output=False,
|
497 |
+
max_attend_past=None,
|
498 |
+
qk_norm=False,
|
499 |
+
scale_init_value=None,
|
500 |
+
rel_pos_bias=False,
|
501 |
+
rel_pos_num_buckets=32,
|
502 |
+
rel_pos_max_distance=128,
|
503 |
+
):
|
504 |
+
super().__init__()
|
505 |
+
self.scale = dim_head ** -0.5
|
506 |
+
|
507 |
+
self.heads = heads
|
508 |
+
self.causal = causal
|
509 |
+
self.max_attend_past = max_attend_past
|
510 |
+
|
511 |
+
qk_dim = v_dim = dim_head * heads
|
512 |
+
|
513 |
+
# collaborative heads
|
514 |
+
self.collab_heads = collab_heads
|
515 |
+
if self.collab_heads:
|
516 |
+
qk_dim = int(collab_compression * qk_dim)
|
517 |
+
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
|
518 |
+
|
519 |
+
self.to_q = nn.Linear(dim, qk_dim, bias=False)
|
520 |
+
self.to_k = nn.Linear(dim, qk_dim, bias=False)
|
521 |
+
self.to_v = nn.Linear(dim, v_dim, bias=False)
|
522 |
+
|
523 |
+
self.dropout = nn.Dropout(dropout)
|
524 |
+
|
525 |
+
# add GLU gating for aggregated values, from alphafold2
|
526 |
+
self.to_v_gate = None
|
527 |
+
if gate_values:
|
528 |
+
self.to_v_gate = nn.Linear(dim, v_dim)
|
529 |
+
nn.init.constant_(self.to_v_gate.weight, 0)
|
530 |
+
nn.init.constant_(self.to_v_gate.bias, 1)
|
531 |
+
|
532 |
+
# cosine sim attention
|
533 |
+
self.qk_norm = qk_norm
|
534 |
+
if qk_norm:
|
535 |
+
scale_init_value = default(scale_init_value,
|
536 |
+
-3) # if not provided, initialize as though it were sequence length of 1024
|
537 |
+
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
|
538 |
+
|
539 |
+
# talking heads
|
540 |
+
self.talking_heads = talking_heads
|
541 |
+
if talking_heads:
|
542 |
+
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
543 |
+
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
544 |
+
|
545 |
+
# head scaling
|
546 |
+
self.head_scale = head_scale
|
547 |
+
if head_scale:
|
548 |
+
self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
|
549 |
+
|
550 |
+
# explicit topk sparse attention
|
551 |
+
self.sparse_topk = sparse_topk
|
552 |
+
|
553 |
+
# entmax
|
554 |
+
self.attn_fn = F.softmax
|
555 |
+
|
556 |
+
# add memory key / values
|
557 |
+
self.num_mem_kv = num_mem_kv
|
558 |
+
if num_mem_kv > 0:
|
559 |
+
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
560 |
+
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
561 |
+
|
562 |
+
# attention on attention
|
563 |
+
self.attn_on_attn = on_attn
|
564 |
+
self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
|
565 |
+
|
566 |
+
self.rel_pos_bias = rel_pos_bias
|
567 |
+
if rel_pos_bias:
|
568 |
+
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
569 |
+
self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
|
570 |
+
num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
|
571 |
+
|
572 |
+
# init output projection 0
|
573 |
+
if zero_init_output:
|
574 |
+
init_zero_(self.to_out)
|
575 |
+
|
576 |
+
def forward(
|
577 |
+
self,
|
578 |
+
x,
|
579 |
+
context=None,
|
580 |
+
mask=None,
|
581 |
+
context_mask=None,
|
582 |
+
attn_mask=None,
|
583 |
+
sinusoidal_emb=None,
|
584 |
+
rotary_pos_emb=None,
|
585 |
+
prev_attn=None,
|
586 |
+
mem=None,
|
587 |
+
layer_past=None,
|
588 |
+
):
|
589 |
+
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
|
590 |
+
context)
|
591 |
+
kv_input = default(context, x)
|
592 |
+
|
593 |
+
q_input = x
|
594 |
+
k_input = kv_input
|
595 |
+
v_input = kv_input
|
596 |
+
|
597 |
+
if exists(mem):
|
598 |
+
k_input = torch.cat((mem, k_input), dim=-2)
|
599 |
+
v_input = torch.cat((mem, v_input), dim=-2)
|
600 |
+
|
601 |
+
if exists(sinusoidal_emb):
|
602 |
+
# in shortformer, the query would start at a position offset depending on the past cached memory
|
603 |
+
offset = k_input.shape[-2] - q_input.shape[-2]
|
604 |
+
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
605 |
+
k_input = k_input + sinusoidal_emb(k_input)
|
606 |
+
|
607 |
+
q = self.to_q(q_input)
|
608 |
+
k = self.to_k(k_input)
|
609 |
+
v = self.to_v(v_input)
|
610 |
+
|
611 |
+
if not collab_heads:
|
612 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
613 |
+
else:
|
614 |
+
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
|
615 |
+
k = rearrange(k, 'b n d -> b () n d')
|
616 |
+
v = rearrange(v, 'b n (h d) -> b h n d', h=h)
|
617 |
+
|
618 |
+
if layer_past is not None:
|
619 |
+
past_key, past_value = layer_past
|
620 |
+
k = torch.cat([past_key, k], dim=-2)
|
621 |
+
v = torch.cat([past_value, v], dim=-2)
|
622 |
+
k_cache = k
|
623 |
+
v_cache = v
|
624 |
+
|
625 |
+
if exists(rotary_pos_emb) and not has_context:
|
626 |
+
l = rotary_pos_emb.shape[-1]
|
627 |
+
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
|
628 |
+
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
|
629 |
+
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
|
630 |
+
|
631 |
+
input_mask = None
|
632 |
+
if any(map(exists, (mask, context_mask))):
|
633 |
+
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
634 |
+
k_mask = q_mask if not exists(context) else context_mask
|
635 |
+
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
636 |
+
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
637 |
+
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
638 |
+
input_mask = q_mask * k_mask
|
639 |
+
|
640 |
+
if self.num_mem_kv > 0:
|
641 |
+
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
642 |
+
k = torch.cat((mem_k, k), dim=-2)
|
643 |
+
v = torch.cat((mem_v, v), dim=-2)
|
644 |
+
if exists(input_mask):
|
645 |
+
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
646 |
+
|
647 |
+
if collab_heads:
|
648 |
+
k = k.expand(-1, h, -1, -1)
|
649 |
+
|
650 |
+
if self.qk_norm:
|
651 |
+
q, k = map(l2norm, (q, k))
|
652 |
+
scale = 1 / (self.scale.exp().clamp(min=1e-2))
|
653 |
+
|
654 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
|
655 |
+
mask_value = max_neg_value(dots)
|
656 |
+
|
657 |
+
if exists(prev_attn):
|
658 |
+
dots = dots + prev_attn
|
659 |
+
|
660 |
+
pre_softmax_attn = dots.clone()
|
661 |
+
|
662 |
+
if talking_heads:
|
663 |
+
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
664 |
+
|
665 |
+
if self.rel_pos_bias:
|
666 |
+
dots = self.rel_pos(dots)
|
667 |
+
|
668 |
+
if exists(input_mask):
|
669 |
+
dots.masked_fill_(~input_mask, mask_value)
|
670 |
+
del input_mask
|
671 |
+
|
672 |
+
if exists(attn_mask):
|
673 |
+
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
|
674 |
+
if attn_mask.ndim == 2:
|
675 |
+
attn_mask = rearrange(attn_mask, 'i j -> () () i j')
|
676 |
+
elif attn_mask.ndim == 3:
|
677 |
+
attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
|
678 |
+
dots.masked_fill_(~attn_mask, mask_value)
|
679 |
+
|
680 |
+
if exists(self.max_attend_past):
|
681 |
+
i, j = dots.shape[-2:]
|
682 |
+
range_q = torch.arange(j - i, j, device=device)
|
683 |
+
range_k = torch.arange(j, device=device)
|
684 |
+
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
|
685 |
+
mask = dist > self.max_attend_past
|
686 |
+
dots.masked_fill_(mask, mask_value)
|
687 |
+
del mask
|
688 |
+
|
689 |
+
if self.causal:
|
690 |
+
i, j = dots.shape[-2:]
|
691 |
+
r = torch.arange(i, device=device)
|
692 |
+
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
693 |
+
mask = F.pad(mask, (j - i, 0), value=False)
|
694 |
+
dots.masked_fill_(mask, mask_value)
|
695 |
+
del mask
|
696 |
+
|
697 |
+
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
698 |
+
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
699 |
+
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
700 |
+
mask = dots < vk
|
701 |
+
dots.masked_fill_(mask, mask_value)
|
702 |
+
del mask
|
703 |
+
|
704 |
+
attn = self.attn_fn(dots, dim=-1)
|
705 |
+
post_softmax_attn = attn.clone()
|
706 |
+
|
707 |
+
attn = self.dropout(attn)
|
708 |
+
|
709 |
+
if talking_heads:
|
710 |
+
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
711 |
+
|
712 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
713 |
+
|
714 |
+
if head_scale:
|
715 |
+
out = out * self.head_scale_params
|
716 |
+
|
717 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
718 |
+
|
719 |
+
if exists(self.to_v_gate):
|
720 |
+
gates = self.to_v_gate(x)
|
721 |
+
out = out * gates.sigmoid()
|
722 |
+
|
723 |
+
intermediates = Intermediates(
|
724 |
+
pre_softmax_attn=pre_softmax_attn,
|
725 |
+
post_softmax_attn=post_softmax_attn
|
726 |
+
)
|
727 |
+
|
728 |
+
return self.to_out(out), intermediates, k_cache, v_cache
|
729 |
+
|
730 |
+
|
731 |
+
class AttentionLayers(nn.Module):
|
732 |
+
def __init__(
|
733 |
+
self,
|
734 |
+
dim,
|
735 |
+
depth,
|
736 |
+
heads=8,
|
737 |
+
causal=False,
|
738 |
+
cross_attend=False,
|
739 |
+
only_cross=False,
|
740 |
+
use_scalenorm=False,
|
741 |
+
use_rms_scaleshift_norm=False,
|
742 |
+
use_rmsnorm=False,
|
743 |
+
use_rezero=False,
|
744 |
+
alibi_pos_bias=False,
|
745 |
+
alibi_num_heads=None,
|
746 |
+
alibi_learned=False,
|
747 |
+
position_infused_attn=False,
|
748 |
+
rotary_pos_emb=False,
|
749 |
+
rotary_emb_dim=None,
|
750 |
+
custom_layers=None,
|
751 |
+
sandwich_coef=None,
|
752 |
+
par_ratio=None,
|
753 |
+
residual_attn=False,
|
754 |
+
cross_residual_attn=False,
|
755 |
+
macaron=False,
|
756 |
+
pre_norm=True,
|
757 |
+
gate_residual=False,
|
758 |
+
scale_residual=False,
|
759 |
+
shift_tokens=0,
|
760 |
+
sandwich_norm=False,
|
761 |
+
use_qk_norm_attn=False,
|
762 |
+
qk_norm_attn_seq_len=None,
|
763 |
+
zero_init_branch_output=False,
|
764 |
+
**kwargs
|
765 |
+
):
|
766 |
+
super().__init__()
|
767 |
+
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
768 |
+
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
769 |
+
|
770 |
+
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
771 |
+
|
772 |
+
self.dim = dim
|
773 |
+
self.depth = depth
|
774 |
+
self.layers = nn.ModuleList([])
|
775 |
+
self.causal = causal
|
776 |
+
|
777 |
+
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
|
778 |
+
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
|
779 |
+
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
780 |
+
|
781 |
+
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
|
782 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
|
783 |
+
|
784 |
+
assert not (
|
785 |
+
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
|
786 |
+
|
787 |
+
if alibi_pos_bias:
|
788 |
+
alibi_num_heads = default(alibi_num_heads, heads)
|
789 |
+
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
|
790 |
+
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
|
791 |
+
self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
|
792 |
+
else:
|
793 |
+
self.rel_pos = None
|
794 |
+
|
795 |
+
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
|
796 |
+
self.pre_norm = pre_norm
|
797 |
+
self.sandwich_norm = sandwich_norm
|
798 |
+
|
799 |
+
self.residual_attn = residual_attn
|
800 |
+
self.cross_residual_attn = cross_residual_attn
|
801 |
+
self.cross_attend = cross_attend
|
802 |
+
|
803 |
+
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
804 |
+
norm_class = RMSNorm if use_rmsnorm else norm_class
|
805 |
+
norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
|
806 |
+
norm_fn = partial(norm_class, dim)
|
807 |
+
|
808 |
+
norm_fn = nn.Identity if use_rezero else norm_fn
|
809 |
+
branch_fn = Rezero if use_rezero else None
|
810 |
+
|
811 |
+
if cross_attend and not only_cross:
|
812 |
+
default_block = ('a', 'c', 'f')
|
813 |
+
elif cross_attend and only_cross:
|
814 |
+
default_block = ('c', 'f')
|
815 |
+
else:
|
816 |
+
default_block = ('a', 'f')
|
817 |
+
|
818 |
+
if macaron:
|
819 |
+
default_block = ('f',) + default_block
|
820 |
+
|
821 |
+
# qk normalization
|
822 |
+
|
823 |
+
if use_qk_norm_attn:
|
824 |
+
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
|
825 |
+
qk_norm_attn_seq_len) else None
|
826 |
+
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
|
827 |
+
|
828 |
+
# zero init
|
829 |
+
|
830 |
+
if zero_init_branch_output:
|
831 |
+
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
|
832 |
+
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
|
833 |
+
|
834 |
+
# calculate layer block order
|
835 |
+
|
836 |
+
if exists(custom_layers):
|
837 |
+
layer_types = custom_layers
|
838 |
+
elif exists(par_ratio):
|
839 |
+
par_depth = depth * len(default_block)
|
840 |
+
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
841 |
+
default_block = tuple(filter(not_equals('f'), default_block))
|
842 |
+
par_attn = par_depth // par_ratio
|
843 |
+
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
844 |
+
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
845 |
+
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
846 |
+
par_block = default_block + ('f',) * (par_width - len(default_block))
|
847 |
+
par_head = par_block * par_attn
|
848 |
+
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
849 |
+
elif exists(sandwich_coef):
|
850 |
+
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
851 |
+
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
852 |
+
else:
|
853 |
+
layer_types = default_block * depth
|
854 |
+
|
855 |
+
self.layer_types = layer_types
|
856 |
+
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
857 |
+
|
858 |
+
# calculate token shifting
|
859 |
+
|
860 |
+
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
861 |
+
|
862 |
+
# iterate and construct layers
|
863 |
+
|
864 |
+
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
865 |
+
is_last_layer = ind == (len(self.layer_types) - 1)
|
866 |
+
|
867 |
+
if layer_type == 'a':
|
868 |
+
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
869 |
+
elif layer_type == 'c':
|
870 |
+
layer = Attention(dim, heads=heads, **attn_kwargs)
|
871 |
+
elif layer_type == 'f':
|
872 |
+
layer = FeedForward(dim, **ff_kwargs)
|
873 |
+
layer = layer if not macaron else Scale(0.5, layer)
|
874 |
+
else:
|
875 |
+
raise Exception(f'invalid layer type {layer_type}')
|
876 |
+
|
877 |
+
if layer_shift_tokens > 0:
|
878 |
+
shift_range_upper = layer_shift_tokens + 1
|
879 |
+
shift_range_lower = -layer_shift_tokens if not causal else 0
|
880 |
+
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
|
881 |
+
|
882 |
+
if exists(branch_fn):
|
883 |
+
layer = branch_fn(layer)
|
884 |
+
|
885 |
+
residual_fn = GRUGating if gate_residual else Residual
|
886 |
+
residual = residual_fn(dim, scale_residual=scale_residual)
|
887 |
+
|
888 |
+
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
|
889 |
+
|
890 |
+
pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
|
891 |
+
post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
|
892 |
+
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
|
893 |
+
|
894 |
+
norms = nn.ModuleList([
|
895 |
+
pre_branch_norm,
|
896 |
+
post_branch_norm,
|
897 |
+
post_main_norm
|
898 |
+
])
|
899 |
+
|
900 |
+
self.layers.append(nn.ModuleList([
|
901 |
+
norms,
|
902 |
+
layer,
|
903 |
+
residual
|
904 |
+
]))
|
905 |
+
|
906 |
+
def forward(
|
907 |
+
self,
|
908 |
+
x,
|
909 |
+
context=None,
|
910 |
+
full_context=None, # for passing a list of hidden states from an encoder
|
911 |
+
mask=None,
|
912 |
+
context_mask=None,
|
913 |
+
attn_mask=None,
|
914 |
+
mems=None,
|
915 |
+
return_hiddens=False,
|
916 |
+
norm_scale_shift_inp=None,
|
917 |
+
past_key_values=None,
|
918 |
+
expected_seq_len=None,
|
919 |
+
):
|
920 |
+
|
921 |
+
assert not (self.cross_attend ^ (exists(context) or exists(
|
922 |
+
full_context))), 'context must be passed in if cross_attend is set to True'
|
923 |
+
assert context is None or full_context is None, 'only one of full_context or context can be provided'
|
924 |
+
|
925 |
+
hiddens = []
|
926 |
+
intermediates = []
|
927 |
+
prev_attn = None
|
928 |
+
prev_cross_attn = None
|
929 |
+
|
930 |
+
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
931 |
+
norm_args = {}
|
932 |
+
if exists(norm_scale_shift_inp):
|
933 |
+
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
|
934 |
+
|
935 |
+
rotary_pos_emb = None
|
936 |
+
if exists(self.rotary_pos_emb):
|
937 |
+
if not self.training and self.causal:
|
938 |
+
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
|
939 |
+
elif expected_seq_len is None:
|
940 |
+
expected_seq_len = 0
|
941 |
+
seq_len = x.shape[1]
|
942 |
+
if past_key_values is not None:
|
943 |
+
seq_len += past_key_values[0][0].shape[-2]
|
944 |
+
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
|
945 |
+
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
946 |
+
|
947 |
+
present_key_values = []
|
948 |
+
cross_attn_count = 0
|
949 |
+
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
950 |
+
if layer_type == 'a':
|
951 |
+
layer_mem = mems.pop(0) if mems else None
|
952 |
+
|
953 |
+
residual = x
|
954 |
+
|
955 |
+
pre_branch_norm, post_branch_norm, post_main_norm = norm
|
956 |
+
|
957 |
+
if exists(pre_branch_norm):
|
958 |
+
x = pre_branch_norm(x, **norm_args)
|
959 |
+
|
960 |
+
if layer_type == 'a' or layer_type == 'c':
|
961 |
+
if past_key_values is not None:
|
962 |
+
layer_kv = past_key_values.pop(0)
|
963 |
+
layer_past = tuple(s.to(x.device) for s in layer_kv)
|
964 |
+
else:
|
965 |
+
layer_past = None
|
966 |
+
|
967 |
+
if layer_type == 'a':
|
968 |
+
out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
969 |
+
prev_attn, layer_mem, layer_past)
|
970 |
+
elif layer_type == 'c':
|
971 |
+
if exists(full_context):
|
972 |
+
out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
|
973 |
+
None, prev_attn, None, layer_past)
|
974 |
+
else:
|
975 |
+
out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
|
976 |
+
elif layer_type == 'f':
|
977 |
+
out = block(x)
|
978 |
+
|
979 |
+
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
|
980 |
+
present_key_values.append((k.detach(), v.detach()))
|
981 |
+
|
982 |
+
if exists(post_branch_norm):
|
983 |
+
out = post_branch_norm(out, **norm_args)
|
984 |
+
|
985 |
+
x = residual_fn(out, residual)
|
986 |
+
|
987 |
+
if layer_type in ('a', 'c'):
|
988 |
+
intermediates.append(inter)
|
989 |
+
|
990 |
+
if layer_type == 'a' and self.residual_attn:
|
991 |
+
prev_attn = inter.pre_softmax_attn
|
992 |
+
elif layer_type == 'c' and self.cross_residual_attn:
|
993 |
+
prev_cross_attn = inter.pre_softmax_attn
|
994 |
+
|
995 |
+
if exists(post_main_norm):
|
996 |
+
x = post_main_norm(x, **norm_args)
|
997 |
+
|
998 |
+
if layer_type == 'c':
|
999 |
+
cross_attn_count += 1
|
1000 |
+
|
1001 |
+
if layer_type == 'f':
|
1002 |
+
hiddens.append(x)
|
1003 |
+
|
1004 |
+
if return_hiddens:
|
1005 |
+
intermediates = LayerIntermediates(
|
1006 |
+
hiddens=hiddens,
|
1007 |
+
attn_intermediates=intermediates,
|
1008 |
+
past_key_values=present_key_values
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
return x, intermediates
|
1012 |
+
|
1013 |
+
return x
|
1014 |
+
|
1015 |
+
|
1016 |
+
class Encoder(AttentionLayers):
|
1017 |
+
def __init__(self, **kwargs):
|
1018 |
+
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
1019 |
+
super().__init__(causal=False, **kwargs)
|
1020 |
+
|
1021 |
+
|
1022 |
+
class Decoder(AttentionLayers):
|
1023 |
+
def __init__(self, **kwargs):
|
1024 |
+
assert 'causal' not in kwargs, 'cannot set causality on decoder'
|
1025 |
+
super().__init__(causal=True, **kwargs)
|
1026 |
+
|
1027 |
+
|
1028 |
+
class CrossAttender(AttentionLayers):
|
1029 |
+
def __init__(self, **kwargs):
|
1030 |
+
super().__init__(cross_attend=True, only_cross=True, **kwargs)
|
1031 |
+
|
1032 |
+
|
1033 |
+
class ViTransformerWrapper(nn.Module):
|
1034 |
+
def __init__(
|
1035 |
+
self,
|
1036 |
+
*,
|
1037 |
+
image_size,
|
1038 |
+
patch_size,
|
1039 |
+
attn_layers,
|
1040 |
+
num_classes=None,
|
1041 |
+
dropout=0.,
|
1042 |
+
emb_dropout=0.
|
1043 |
+
):
|
1044 |
+
super().__init__()
|
1045 |
+
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
|
1046 |
+
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
1047 |
+
dim = attn_layers.dim
|
1048 |
+
num_patches = (image_size // patch_size) ** 2
|
1049 |
+
patch_dim = 3 * patch_size ** 2
|
1050 |
+
|
1051 |
+
self.patch_size = patch_size
|
1052 |
+
|
1053 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
1054 |
+
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
1055 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
1056 |
+
self.dropout = nn.Dropout(emb_dropout)
|
1057 |
+
|
1058 |
+
self.attn_layers = attn_layers
|
1059 |
+
self.norm = nn.LayerNorm(dim)
|
1060 |
+
self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
|
1061 |
+
|
1062 |
+
def forward(
|
1063 |
+
self,
|
1064 |
+
img,
|
1065 |
+
return_embeddings=False
|
1066 |
+
):
|
1067 |
+
p = self.patch_size
|
1068 |
+
|
1069 |
+
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
1070 |
+
x = self.patch_to_embedding(x)
|
1071 |
+
b, n, _ = x.shape
|
1072 |
+
|
1073 |
+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
1074 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
1075 |
+
x = x + self.pos_embedding[:, :(n + 1)]
|
1076 |
+
x = self.dropout(x)
|
1077 |
+
|
1078 |
+
x = self.attn_layers(x)
|
1079 |
+
x = self.norm(x)
|
1080 |
+
|
1081 |
+
if not exists(self.mlp_head) or return_embeddings:
|
1082 |
+
return x
|
1083 |
+
|
1084 |
+
return self.mlp_head(x[:, 0])
|
1085 |
+
|
1086 |
+
|
1087 |
+
class TransformerWrapper(nn.Module):
|
1088 |
+
def __init__(
|
1089 |
+
self,
|
1090 |
+
*,
|
1091 |
+
num_tokens,
|
1092 |
+
max_seq_len,
|
1093 |
+
attn_layers,
|
1094 |
+
emb_dim=None,
|
1095 |
+
max_mem_len=0.,
|
1096 |
+
shift_mem_down=0,
|
1097 |
+
emb_dropout=0.,
|
1098 |
+
num_memory_tokens=None,
|
1099 |
+
tie_embedding=False,
|
1100 |
+
use_pos_emb=True
|
1101 |
+
):
|
1102 |
+
super().__init__()
|
1103 |
+
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
1104 |
+
|
1105 |
+
dim = attn_layers.dim
|
1106 |
+
emb_dim = default(emb_dim, dim)
|
1107 |
+
|
1108 |
+
self.max_seq_len = max_seq_len
|
1109 |
+
self.max_mem_len = max_mem_len
|
1110 |
+
self.shift_mem_down = shift_mem_down
|
1111 |
+
|
1112 |
+
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
1113 |
+
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
1114 |
+
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
1115 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
1116 |
+
|
1117 |
+
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
1118 |
+
self.attn_layers = attn_layers
|
1119 |
+
self.norm = nn.LayerNorm(dim)
|
1120 |
+
|
1121 |
+
self.init_()
|
1122 |
+
|
1123 |
+
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
1124 |
+
|
1125 |
+
# memory tokens (like [cls]) from Memory Transformers paper
|
1126 |
+
num_memory_tokens = default(num_memory_tokens, 0)
|
1127 |
+
self.num_memory_tokens = num_memory_tokens
|
1128 |
+
if num_memory_tokens > 0:
|
1129 |
+
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
1130 |
+
|
1131 |
+
def init_(self):
|
1132 |
+
nn.init.kaiming_normal_(self.token_emb.weight)
|
1133 |
+
|
1134 |
+
def forward(
|
1135 |
+
self,
|
1136 |
+
x,
|
1137 |
+
return_embeddings=False,
|
1138 |
+
mask=None,
|
1139 |
+
return_hiddens=False,
|
1140 |
+
return_attn=False,
|
1141 |
+
mems=None,
|
1142 |
+
use_cache=False,
|
1143 |
+
**kwargs
|
1144 |
+
):
|
1145 |
+
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
1146 |
+
x = self.token_emb(x)
|
1147 |
+
x = x + self.pos_emb(x)
|
1148 |
+
x = self.emb_dropout(x)
|
1149 |
+
|
1150 |
+
x = self.project_emb(x)
|
1151 |
+
|
1152 |
+
if num_mem > 0:
|
1153 |
+
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
1154 |
+
x = torch.cat((mem, x), dim=1)
|
1155 |
+
|
1156 |
+
# auto-handle masking after appending memory tokens
|
1157 |
+
if exists(mask):
|
1158 |
+
mask = F.pad(mask, (num_mem, 0), value=True)
|
1159 |
+
|
1160 |
+
if self.shift_mem_down and exists(mems):
|
1161 |
+
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
|
1162 |
+
mems = [*mems_r, *mems_l]
|
1163 |
+
|
1164 |
+
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
1165 |
+
x = self.norm(x)
|
1166 |
+
|
1167 |
+
mem, x = x[:, :num_mem], x[:, num_mem:]
|
1168 |
+
|
1169 |
+
out = self.to_logits(x) if not return_embeddings else x
|
1170 |
+
|
1171 |
+
if return_hiddens:
|
1172 |
+
hiddens = intermediates.hiddens
|
1173 |
+
return out, hiddens
|
1174 |
+
|
1175 |
+
res = [out]
|
1176 |
+
if return_attn:
|
1177 |
+
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
1178 |
+
res.append(attn_maps)
|
1179 |
+
if use_cache:
|
1180 |
+
res.append(intermediates.past_key_values)
|
1181 |
+
|
1182 |
+
if len(res) > 1:
|
1183 |
+
return tuple(res)
|
1184 |
+
return res[0]
|
1185 |
+
|
1186 |
+
|
1187 |
+
class ContinuousTransformerWrapper(nn.Module):
|
1188 |
+
def __init__(
|
1189 |
+
self,
|
1190 |
+
*,
|
1191 |
+
max_seq_len,
|
1192 |
+
attn_layers,
|
1193 |
+
dim_in=None,
|
1194 |
+
dim_out=None,
|
1195 |
+
emb_dim=None,
|
1196 |
+
emb_dropout=0.,
|
1197 |
+
use_pos_emb=True
|
1198 |
+
):
|
1199 |
+
super().__init__()
|
1200 |
+
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
1201 |
+
|
1202 |
+
dim = attn_layers.dim
|
1203 |
+
|
1204 |
+
self.max_seq_len = max_seq_len
|
1205 |
+
|
1206 |
+
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
|
1207 |
+
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
1208 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
1209 |
+
|
1210 |
+
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
|
1211 |
+
|
1212 |
+
self.attn_layers = attn_layers
|
1213 |
+
self.norm = nn.LayerNorm(dim)
|
1214 |
+
|
1215 |
+
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
|
1216 |
+
|
1217 |
+
def forward(
|
1218 |
+
self,
|
1219 |
+
x,
|
1220 |
+
return_embeddings=False,
|
1221 |
+
mask=None,
|
1222 |
+
return_attn=False,
|
1223 |
+
mems=None,
|
1224 |
+
use_cache=False,
|
1225 |
+
**kwargs
|
1226 |
+
):
|
1227 |
+
b, n, _, device = *x.shape, x.device
|
1228 |
+
|
1229 |
+
x = self.project_in(x)
|
1230 |
+
x = x + self.pos_emb(x)
|
1231 |
+
x = self.emb_dropout(x)
|
1232 |
+
|
1233 |
+
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
1234 |
+
x = self.norm(x)
|
1235 |
+
|
1236 |
+
out = self.project_out(x) if not return_embeddings else x
|
1237 |
+
|
1238 |
+
res = [out]
|
1239 |
+
if return_attn:
|
1240 |
+
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
1241 |
+
res.append(attn_maps)
|
1242 |
+
if use_cache:
|
1243 |
+
res.append(intermediates.past_key_values)
|
1244 |
+
|
1245 |
+
if len(res) > 1:
|
1246 |
+
return tuple(res)
|
1247 |
+
return res[0]
|
1248 |
+
|
ruth_tts_transformer/utils/__init__.py
ADDED
File without changes
|