Create musiclm_pytorch.py
Browse files- musiclm_pytorch.py +555 -0
musiclm_pytorch.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn, einsum
|
4 |
+
|
5 |
+
from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
|
6 |
+
|
7 |
+
from audiolm_pytorch import AudioLM
|
8 |
+
|
9 |
+
from x_clip.tokenizer import tokenizer
|
10 |
+
from vector_quantize_pytorch import ResidualVQ
|
11 |
+
|
12 |
+
from einops import rearrange, repeat, reduce, pack, unpack
|
13 |
+
|
14 |
+
from beartype.typing import List, Optional, Tuple
|
15 |
+
from beartype import beartype
|
16 |
+
|
17 |
+
# functions
|
18 |
+
|
19 |
+
def exists(val):
|
20 |
+
return val is not None
|
21 |
+
|
22 |
+
def default(val, d):
|
23 |
+
return val if exists(val) else d
|
24 |
+
|
25 |
+
def round_down_nearest_multiple(n, divisor):
|
26 |
+
return n // divisor * divisor
|
27 |
+
|
28 |
+
# tensor functions
|
29 |
+
|
30 |
+
def log(t, eps = 1e-20):
|
31 |
+
return torch.log(t.clamp(min = eps))
|
32 |
+
|
33 |
+
def l2norm(t):
|
34 |
+
return F.normalize(t, p = 2, dim = -1)
|
35 |
+
|
36 |
+
# 2d sinusoidal positional embedding
|
37 |
+
# simple vit paper shows it is good enough compared to learned
|
38 |
+
|
39 |
+
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
40 |
+
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
41 |
+
|
42 |
+
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
43 |
+
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
44 |
+
|
45 |
+
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
46 |
+
omega = 1. / (temperature ** omega)
|
47 |
+
|
48 |
+
y = y.flatten()[:, None] * omega[None, :]
|
49 |
+
x = x.flatten()[:, None] * omega[None, :]
|
50 |
+
|
51 |
+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
52 |
+
pe = pe.type(dtype)
|
53 |
+
|
54 |
+
return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
|
55 |
+
|
56 |
+
# biasless layernorm
|
57 |
+
|
58 |
+
class LayerNorm(nn.Module):
|
59 |
+
def __init__(self, dim):
|
60 |
+
super().__init__()
|
61 |
+
self.gamma = nn.Parameter(torch.ones(dim))
|
62 |
+
self.register_buffer('beta', torch.zeros(dim))
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
66 |
+
|
67 |
+
# feedforward
|
68 |
+
|
69 |
+
class GEGLU(nn.Module):
|
70 |
+
def forward(self, x):
|
71 |
+
x, gate = x.chunk(2, dim = -1)
|
72 |
+
return F.gelu(gate) * x
|
73 |
+
|
74 |
+
def FeedForward(dim, mult = 4, dropout = 0.):
|
75 |
+
dim_hidden = int(dim * mult * 2 / 3)
|
76 |
+
|
77 |
+
return nn.Sequential(
|
78 |
+
LayerNorm(dim),
|
79 |
+
nn.Linear(dim, dim_hidden * 2, bias = False),
|
80 |
+
GEGLU(),
|
81 |
+
nn.Dropout(dropout),
|
82 |
+
nn.Linear(dim_hidden, dim, bias = False)
|
83 |
+
)
|
84 |
+
|
85 |
+
# attention
|
86 |
+
|
87 |
+
class Attention(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
dim,
|
91 |
+
causal = False,
|
92 |
+
dim_head = 64,
|
93 |
+
heads = 8,
|
94 |
+
dropout = 0.
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.heads = heads
|
98 |
+
self.scale = dim_head ** -0.5
|
99 |
+
self.causal = causal
|
100 |
+
inner_dim = dim_head * heads
|
101 |
+
|
102 |
+
self.norm = LayerNorm(dim)
|
103 |
+
|
104 |
+
self.attn_dropout = nn.Dropout(dropout)
|
105 |
+
|
106 |
+
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
107 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
108 |
+
|
109 |
+
self.to_out = nn.Sequential(
|
110 |
+
nn.Linear(inner_dim, dim, bias = False),
|
111 |
+
nn.Dropout(dropout)
|
112 |
+
)
|
113 |
+
|
114 |
+
def forward(
|
115 |
+
self,
|
116 |
+
x,
|
117 |
+
mask = None
|
118 |
+
):
|
119 |
+
b, n, _, device = *x.shape, x.device
|
120 |
+
|
121 |
+
# prenorm
|
122 |
+
|
123 |
+
x = self.norm(x)
|
124 |
+
|
125 |
+
# project for queries, keys, values
|
126 |
+
|
127 |
+
q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
|
128 |
+
|
129 |
+
# split for multi-headed attention
|
130 |
+
|
131 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
|
132 |
+
|
133 |
+
q = q * self.scale
|
134 |
+
|
135 |
+
# similarities
|
136 |
+
|
137 |
+
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
138 |
+
|
139 |
+
if exists(mask):
|
140 |
+
mask = rearrange(mask, 'b j -> b 1 1 j')
|
141 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
142 |
+
|
143 |
+
if self.causal:
|
144 |
+
i, j = sim.shape[-2:]
|
145 |
+
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
|
146 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
147 |
+
|
148 |
+
# attention
|
149 |
+
|
150 |
+
attn = sim.softmax(dim = -1)
|
151 |
+
attn = self.attn_dropout(attn)
|
152 |
+
|
153 |
+
# aggregate
|
154 |
+
|
155 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
156 |
+
|
157 |
+
# merge heads
|
158 |
+
|
159 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
160 |
+
return self.to_out(out)
|
161 |
+
|
162 |
+
# transformer
|
163 |
+
|
164 |
+
class Transformer(nn.Module):
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
dim,
|
168 |
+
depth,
|
169 |
+
dim_head = 64,
|
170 |
+
heads = 8,
|
171 |
+
attn_dropout = 0.,
|
172 |
+
ff_mult = 4,
|
173 |
+
ff_dropout = 0.
|
174 |
+
):
|
175 |
+
super().__init__()
|
176 |
+
self.layers = nn.ModuleList([])
|
177 |
+
for _ in range(depth):
|
178 |
+
self.layers.append(nn.ModuleList([
|
179 |
+
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
|
180 |
+
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
|
181 |
+
]))
|
182 |
+
|
183 |
+
def forward(self, x, mask = None):
|
184 |
+
|
185 |
+
for attn, ff in self.layers:
|
186 |
+
x = attn(x, mask = mask) + x
|
187 |
+
x = ff(x) + x
|
188 |
+
|
189 |
+
return x
|
190 |
+
|
191 |
+
# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778
|
192 |
+
|
193 |
+
def pair(t):
|
194 |
+
return (t, t) if not isinstance(t, tuple) else t
|
195 |
+
|
196 |
+
class AudioSpectrogramTransformer(nn.Module):
|
197 |
+
def __init__(
|
198 |
+
self,
|
199 |
+
dim,
|
200 |
+
depth,
|
201 |
+
patch_size = 16,
|
202 |
+
dim_head = 64,
|
203 |
+
heads = 8,
|
204 |
+
attn_dropout = 0.,
|
205 |
+
ff_mult = 4,
|
206 |
+
ff_dropout = 0.,
|
207 |
+
spec_n_fft = 128,
|
208 |
+
spec_power = 2,
|
209 |
+
spec_win_length = 24,
|
210 |
+
spec_hop_length = None,
|
211 |
+
spec_pad = 0,
|
212 |
+
spec_center = True,
|
213 |
+
spec_pad_mode = 'reflect',
|
214 |
+
spec_aug_stretch_factor = 0.8,
|
215 |
+
spec_aug_freq_mask = 80,
|
216 |
+
spec_aug_time_mask = 80
|
217 |
+
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
self.dim = dim
|
221 |
+
|
222 |
+
self.patch_size = pair(patch_size)
|
223 |
+
self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)
|
224 |
+
|
225 |
+
self.spec = Spectrogram(
|
226 |
+
n_fft = spec_n_fft,
|
227 |
+
power = spec_power,
|
228 |
+
win_length = spec_win_length,
|
229 |
+
hop_length = spec_hop_length,
|
230 |
+
pad = spec_pad,
|
231 |
+
center = spec_center,
|
232 |
+
pad_mode = spec_pad_mode
|
233 |
+
)
|
234 |
+
|
235 |
+
# SpecAugment - seems to be widely used in audio field https://arxiv.org/abs/1904.08779
|
236 |
+
|
237 |
+
self.aug = torch.nn.Sequential(
|
238 |
+
TimeStretch(spec_aug_stretch_factor, fixed_rate=True),
|
239 |
+
FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
|
240 |
+
TimeMasking(time_mask_param = spec_aug_time_mask),
|
241 |
+
)
|
242 |
+
|
243 |
+
self.transformer = Transformer(
|
244 |
+
dim = dim,
|
245 |
+
depth = depth,
|
246 |
+
dim_head = dim_head,
|
247 |
+
heads = heads,
|
248 |
+
attn_dropout = attn_dropout,
|
249 |
+
ff_mult = ff_mult,
|
250 |
+
ff_dropout = ff_dropout
|
251 |
+
)
|
252 |
+
|
253 |
+
self.norm = LayerNorm(dim)
|
254 |
+
|
255 |
+
def forward(self, x):
|
256 |
+
x = self.spec(x)
|
257 |
+
|
258 |
+
if self.training:
|
259 |
+
x = self.aug(x)
|
260 |
+
|
261 |
+
# automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
|
262 |
+
|
263 |
+
height, width = x.shape[-2:]
|
264 |
+
patch_height, patch_width = self.patch_size
|
265 |
+
|
266 |
+
rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))
|
267 |
+
|
268 |
+
if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
|
269 |
+
print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')
|
270 |
+
|
271 |
+
x = x[..., :rounded_height, :rounded_width]
|
272 |
+
|
273 |
+
# to patches
|
274 |
+
|
275 |
+
x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
|
276 |
+
x = self.to_patch_tokens(x)
|
277 |
+
|
278 |
+
# 2d sinusoidal positional embedding
|
279 |
+
|
280 |
+
x = rearrange(x, 'b c h w -> b h w c')
|
281 |
+
x = x + posemb_sincos_2d(x)
|
282 |
+
|
283 |
+
# attention, what else
|
284 |
+
|
285 |
+
x = rearrange(x, 'b ... c -> b (...) c')
|
286 |
+
|
287 |
+
x = self.transformer(x)
|
288 |
+
|
289 |
+
# final global average and norm (most recent papers show this is superior to CLS token)
|
290 |
+
|
291 |
+
x = reduce(x, 'b n d -> b d', 'mean')
|
292 |
+
|
293 |
+
return self.norm(x)
|
294 |
+
|
295 |
+
# text transformer
|
296 |
+
|
297 |
+
@beartype
|
298 |
+
class TextTransformer(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
dim,
|
302 |
+
depth,
|
303 |
+
num_tokens = tokenizer.vocab_size,
|
304 |
+
max_seq_len = 256,
|
305 |
+
dim_head = 64,
|
306 |
+
heads = 8,
|
307 |
+
attn_dropout = 0.,
|
308 |
+
ff_dropout = 0.,
|
309 |
+
ff_mult = 4,
|
310 |
+
pad_id = 0
|
311 |
+
):
|
312 |
+
super().__init__()
|
313 |
+
self.dim = dim
|
314 |
+
|
315 |
+
self.token_emb = nn.Embedding(num_tokens, dim)
|
316 |
+
self.pos_emb = nn.Embedding(max_seq_len, dim)
|
317 |
+
|
318 |
+
self.cls_token = nn.Parameter(torch.randn(dim))
|
319 |
+
|
320 |
+
self.transformer = Transformer(
|
321 |
+
dim = dim,
|
322 |
+
depth = depth,
|
323 |
+
dim_head = dim_head,
|
324 |
+
heads = heads,
|
325 |
+
attn_dropout = attn_dropout,
|
326 |
+
ff_dropout = ff_dropout,
|
327 |
+
ff_mult = ff_mult
|
328 |
+
)
|
329 |
+
|
330 |
+
self.pad_id = pad_id
|
331 |
+
self.norm = LayerNorm(dim)
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
x = None,
|
336 |
+
raw_texts: Optional[List[str]] = None,
|
337 |
+
mask = None
|
338 |
+
):
|
339 |
+
assert exists(x) ^ exists(raw_texts)
|
340 |
+
|
341 |
+
if exists(raw_texts):
|
342 |
+
x = tokenizer.tokenize(raw_texts)
|
343 |
+
|
344 |
+
if not exists(mask):
|
345 |
+
mask = x != self.pad_id
|
346 |
+
|
347 |
+
b, n, device = *x.shape, x.device
|
348 |
+
|
349 |
+
# token embedding + positional embedding
|
350 |
+
|
351 |
+
x = self.token_emb(x)
|
352 |
+
x = x + self.pos_emb(torch.arange(n, device = device))
|
353 |
+
|
354 |
+
# cls tokens, as in bert
|
355 |
+
|
356 |
+
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
|
357 |
+
x, ps = pack([cls_tokens, x], 'b * d')
|
358 |
+
|
359 |
+
# account for attending to cls token with self attention mask
|
360 |
+
|
361 |
+
mask = F.pad(mask, (1, 0), value = True)
|
362 |
+
|
363 |
+
# attention
|
364 |
+
|
365 |
+
x = self.transformer(x, mask = mask)
|
366 |
+
|
367 |
+
# unpack the cls tokens
|
368 |
+
|
369 |
+
cls_tokens, _ = unpack(x, ps, 'b * d')
|
370 |
+
|
371 |
+
return self.norm(cls_tokens)
|
372 |
+
|
373 |
+
# main classes
|
374 |
+
|
375 |
+
@beartype
|
376 |
+
class MuLaN(nn.Module):
|
377 |
+
def __init__(
|
378 |
+
self,
|
379 |
+
audio_transformer: AudioSpectrogramTransformer,
|
380 |
+
text_transformer: TextTransformer,
|
381 |
+
dim_latent = 128, # they use 128
|
382 |
+
decoupled_contrastive_learning = True, # think this was used, make it optional
|
383 |
+
):
|
384 |
+
super().__init__()
|
385 |
+
self.dim_latent = dim_latent
|
386 |
+
|
387 |
+
self.audio = audio_transformer
|
388 |
+
self.text = text_transformer
|
389 |
+
|
390 |
+
self.temperature = nn.Parameter(torch.tensor(1.))
|
391 |
+
|
392 |
+
self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
|
393 |
+
self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)
|
394 |
+
|
395 |
+
self.decoupled_contrastive_learning = decoupled_contrastive_learning
|
396 |
+
|
397 |
+
def get_audio_latents(
|
398 |
+
self,
|
399 |
+
wavs
|
400 |
+
):
|
401 |
+
audio_embeds = self.audio(wavs)
|
402 |
+
audio_latents = self.audio_to_latents(audio_embeds)
|
403 |
+
return l2norm(audio_latents)
|
404 |
+
|
405 |
+
def get_text_latents(
|
406 |
+
self,
|
407 |
+
texts = None,
|
408 |
+
raw_texts: Optional[List[str]] = None
|
409 |
+
):
|
410 |
+
text_embeds = self.text(texts)
|
411 |
+
text_latents = self.text_to_latents(text_embeds)
|
412 |
+
return l2norm(text_latents)
|
413 |
+
|
414 |
+
def forward(
|
415 |
+
self,
|
416 |
+
wavs,
|
417 |
+
texts = None,
|
418 |
+
raw_texts: Optional[List[str]] = None,
|
419 |
+
return_similarities = False
|
420 |
+
):
|
421 |
+
batch, device = wavs.shape[0], wavs.device
|
422 |
+
|
423 |
+
audio_latents = self.get_audio_latents(wavs)
|
424 |
+
text_latents = self.get_text_latents(texts, raw_texts = raw_texts)
|
425 |
+
|
426 |
+
cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)
|
427 |
+
|
428 |
+
assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'
|
429 |
+
|
430 |
+
if return_similarities:
|
431 |
+
return cosine_sim
|
432 |
+
|
433 |
+
cosine_sim = cosine_sim * self.temperature.exp()
|
434 |
+
|
435 |
+
cosine_sim_exp = cosine_sim.exp()
|
436 |
+
|
437 |
+
numerator = cosine_sim_exp.diag()
|
438 |
+
|
439 |
+
if self.decoupled_contrastive_learning:
|
440 |
+
eye = torch.eye(batch, device = device)
|
441 |
+
cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)
|
442 |
+
|
443 |
+
denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')
|
444 |
+
|
445 |
+
contrastive_loss = -log(numerator / denominator)
|
446 |
+
return contrastive_loss.mean()
|
447 |
+
|
448 |
+
# music lm
|
449 |
+
|
450 |
+
@beartype
|
451 |
+
class MuLaNEmbedQuantizer(nn.Module):
|
452 |
+
def __init__(
|
453 |
+
self,
|
454 |
+
mulan: MuLaN,
|
455 |
+
conditioning_dims: Tuple[int, ...],
|
456 |
+
rq_num_quantizers = 8,
|
457 |
+
rq_ema_decay = 0.9,
|
458 |
+
codebook_size = 1024,
|
459 |
+
namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),
|
460 |
+
|
461 |
+
):
|
462 |
+
super().__init__()
|
463 |
+
self.mulan = mulan
|
464 |
+
|
465 |
+
assert len(namespaces) > 0
|
466 |
+
self.namespaces = namespaces
|
467 |
+
self.conditioning_dims = conditioning_dims
|
468 |
+
|
469 |
+
assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'
|
470 |
+
|
471 |
+
dim = mulan.dim_latent
|
472 |
+
|
473 |
+
self.rq = ResidualVQ(
|
474 |
+
dim = dim,
|
475 |
+
num_quantizers = rq_num_quantizers,
|
476 |
+
codebook_size = codebook_size,
|
477 |
+
decay = rq_ema_decay,
|
478 |
+
commitment_weight = 0, # only use EMA to update codebooks
|
479 |
+
kmeans_init = True,
|
480 |
+
threshold_ema_dead_code = 2,
|
481 |
+
quantize_dropout = False # no quantize dropout
|
482 |
+
)
|
483 |
+
|
484 |
+
self.dim = dim
|
485 |
+
self.num_codebooks = rq_num_quantizers
|
486 |
+
|
487 |
+
self.cond_embeddings = nn.ParameterDict({})
|
488 |
+
|
489 |
+
for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
|
490 |
+
cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
|
491 |
+
nn.init.normal_(cond_embeddings, std = 0.02)
|
492 |
+
|
493 |
+
self.cond_embeddings[namespace] = cond_embeddings
|
494 |
+
|
495 |
+
self.set_default_namespace(namespaces[0])
|
496 |
+
|
497 |
+
def set_default_namespace(self, namespace):
|
498 |
+
self._default_namespace = namespace
|
499 |
+
|
500 |
+
def forward(
|
501 |
+
self,
|
502 |
+
wavs = None,
|
503 |
+
texts = None,
|
504 |
+
namespace = None
|
505 |
+
):
|
506 |
+
assert exists(wavs) ^ exists(texts)
|
507 |
+
|
508 |
+
namespace = default(namespace, self._default_namespace)
|
509 |
+
assert namespace in self.namespaces, f'namespace {namespace} not found'
|
510 |
+
cond_embeddings = self.cond_embeddings[namespace]
|
511 |
+
|
512 |
+
with torch.no_grad():
|
513 |
+
self.mulan.eval()
|
514 |
+
|
515 |
+
# sound and language live in joint embedding space because of contrastive learning
|
516 |
+
|
517 |
+
if exists(wavs):
|
518 |
+
latents = self.mulan.get_audio_latents(wavs)
|
519 |
+
elif exists(texts):
|
520 |
+
latents = self.mulan.get_text_latents(texts)
|
521 |
+
|
522 |
+
_, indices, _ = self.rq(latents)
|
523 |
+
|
524 |
+
batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]
|
525 |
+
|
526 |
+
cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
|
527 |
+
indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)
|
528 |
+
|
529 |
+
cond_embeddings = cond_embeddings.gather(2, indices)
|
530 |
+
return rearrange(cond_embeddings, 'b q 1 d -> b q d')
|
531 |
+
|
532 |
+
@beartype
|
533 |
+
class MusicLM(nn.Module):
|
534 |
+
def __init__(
|
535 |
+
self,
|
536 |
+
audio_lm: AudioLM,
|
537 |
+
mulan_embed_quantizer: MuLaNEmbedQuantizer
|
538 |
+
):
|
539 |
+
super().__init__()
|
540 |
+
self.mulan_embed_quantizer = mulan_embed_quantizer
|
541 |
+
self.audio_lm = audio_lm
|
542 |
+
|
543 |
+
@torch.no_grad()
|
544 |
+
def forward(
|
545 |
+
self,
|
546 |
+
raw_texts: List[str],
|
547 |
+
**audio_lm_kwargs
|
548 |
+
):
|
549 |
+
self.eval()
|
550 |
+
|
551 |
+
texts = tokenizer.tokenize(raw_texts)
|
552 |
+
cond_tokens = self.mulan_embed_quantizer(texts = texts)
|
553 |
+
|
554 |
+
wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs)
|
555 |
+
return wavs
|