Hugo Flores committed on
Commit b54865d
1 Parent(s): 275afd0
requirements.txt CHANGED
@@ -26,5 +26,4 @@ jupyter-client==6.1.12
 tensorboardX
 gradio
 einops
-flash-attn
 frechet_audio_distance
setup.py CHANGED
@@ -20,7 +20,7 @@ setup(
     description="Generative Music Modeling.",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    author="Hugo Flores García",
+    author="Hugo Flores García, Prem Seetharaman",
     author_email="hfgacrcia@descript.com",
     url="https://github.com/descriptinc/lyrebird-vampnet",
     license="MIT",
@@ -37,7 +37,6 @@ setup(
         "google-cloud-logging==2.2.0",
         "torchmetrics>=0.7.3",
         "einops",
-        "flash-attn",
         "frechet_audio_distance"
     ],
 )
vampnet/__init__.py CHANGED
@@ -1,6 +1,6 @@
 
 from . import modules
 from . import scheduler
-from . import enchilada
+from .interface import Interface
 
 __version__ = "0.0.1"
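
With this change the package root re-exports the high-level interface, so downstream code can import it directly. A minimal sketch of the new import path (the checkpoint paths below are placeholders, not files shipped with this commit):

    from vampnet import Interface

    interface = Interface(
        coarse_ckpt="path/to/coarse.pth",        # placeholder checkpoint path
        coarse2fine_ckpt="path/to/c2f.pth",      # placeholder checkpoint path
        codec_ckpt="path/to/codec.pth",          # placeholder checkpoint path
        device="cpu",
    )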
vampnet/enchilada.py DELETED
@@ -1,179 +0,0 @@
-import os
-from pathlib import Path
-
-import torch
-from audiotools import AudioSignal
-
-from .modules.transformer import VampNet
-from lac.model.lac import LAC
-
-
-class TheWholeEnchilada:
-    def __init__(
-        self,
-        coarse_ckpt: str,
-        coarse2fine_ckpt: str,
-        codec_ckpt: str,
-        device: str = "cpu",
-    ):
-        self.codec = LAC.load(Path(codec_ckpt))
-        self.codec.eval()
-        self.codec.to(device)
-
-        self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
-        self.coarse.to(device)
-        self.coarse.eval()
-
-        self.coarse2fine = VampNet.load(
-            location=Path(coarse2fine_ckpt), map_location="cpu"
-        )
-        # FIXME
-        print(
-            f"WARNING: PATCHING coarse2fine seq_len to 288, for backwards compatibility with a specific jazzpop model. it used to be {self.coarse2fine.seq_len}"
-        )
-        self.coarse2fine.seq_len = 288
-
-        self.coarse2fine.to(device)
-        self.coarse2fine.eval()
-
-        self.device = device
-
-    def seconds_to_tokens(self, seconds: float):
-        return int(seconds * self.codec.sample_rate / self.codec.hop_length)
-
-    def to(self, device):
-        self.device = device
-        self.coarse.to(device)
-        self.coarse2fine.to(device)
-        self.codec.to(device)
-        return self
-
-    def encode(self, signal: AudioSignal):
-        with torch.inference_mode():
-            # coarse z
-            cz = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
-
-        return cz
-
-    def vamp(
-        self,
-        signal,
-        prefix_dur_s: float = 1.25,
-        suffix_dur_s: float = 1.25,
-        downsample_hint: bool = True,
-        downsample_factor: int = 4,
-        num_loops: int = 3,
-        **kwargs,
-    ):
-        """
-        Loop imputation of a signal.
-        """
-        signal.to(self.device).resample(self.codec.sample_rate).to_mono()
-
-        z = self.encode(signal)
-
-        cz = z[:, : self.coarse.n_codebooks, :].clone()
-        original_cz = cz.clone()
-        seq_len = original_cz.shape[-1]
-        assert (
-            seq_len == self.coarse.seq_len
-        ), f"expected seq_len {self.coarse.seq_len}, got {seq_len} for token sequence length. Is your signal the same duration as the model was trained with? "
-
-        vamp_hop_s = prefix_dur_s
-        vamp_hop = self.seconds_to_tokens(vamp_hop_s)
-
-        cmask = torch.ones_like(cz)
-
-        if downsample_hint:
-            # downsample by factor of 4
-            for i in range(cmask.shape[-1]):
-                if i % downsample_factor == 0:
-                    cmask[:, :, i] = 0
-
-        if prefix_dur_s > 0:
-            prefix_len = self.seconds_to_tokens(prefix_dur_s)
-            cmask[:, :, :prefix_len] = 0
-            print(f"prefix_len: {prefix_len}")
-        else:
-            prefix_len = 0
-
-        if suffix_dur_s > 0:
-            suffix_len = self.seconds_to_tokens(suffix_dur_s)
-            cmask[:, :, -suffix_len:] = 0
-            print(f"suffix_len: {suffix_len}")
-        else:
-            suffix_len = 0
-
-        prefix_z = cz[:, :, :prefix_len]
-
-        coarse_vamp = [prefix_z.clone()]
-        for i in range(num_loops):
-            sampled_cz = self.coarse.sample(
-                codec=self.codec,
-                time_steps=seq_len,
-                mask=cmask,
-                start_tokens=cz,
-                return_signal=False,
-                **kwargs,
-            )
-
-            new_prefix = sampled_cz[:, :, prefix_len : prefix_len + vamp_hop]
-            coarse_vamp.append(new_prefix.clone())
-
-            # replace the prefix in cz with the new prefix
-            # don't worry about a copy of the prefix still being
-            # in the mask area, since that will be masked out
-            cz[:, :, :vamp_hop] = new_prefix.clone()
-            print("to append and to prefix")
-
-        # we're done, so add the suffix
-        coarse_vamp.append(sampled_cz[:, :, prefix_len + vamp_hop :])
-
-        # concatenate the vamps
-        coarse_vamp = torch.cat(coarse_vamp, dim=-1)
-
-        # add a layer of
-        fine_prefix = z[:, self.coarse.n_codebooks :, :prefix_len]
-        fine_suffix = z[:, self.coarse.n_codebooks :, -suffix_len:]
-        fine_vamp = torch.randint(
-            0,
-            self.coarse2fine.vocab_size,
-            (
-                coarse_vamp.shape[0],
-                self.coarse2fine.n_predict_codebooks,
-                coarse_vamp.shape[-1],
-            ),
-        ).to(self.device)
-        fine_vamp[:, :, :prefix_len] = fine_prefix
-        fine_vamp[:, :, -suffix_len:] = fine_suffix
-
-        vamp_z = torch.cat([coarse_vamp, fine_vamp], dim=1)
-
-        # now we sample from the coarse2fine model
-        # to get the fine details
-        start_pos = 0
-
-        c2f_vamp = []
-        while start_pos < vamp_z.shape[-1]:
-            end_pos = min(start_pos + self.coarse2fine.seq_len, vamp_z.shape[-1])
-
-            c2fz = vamp_z[:, :, start_pos:end_pos]
-            self.coarse2fine: VampNet
-            sampled_c2fz = self.coarse2fine.sample(
-                codec=self.codec,
-                start_tokens=c2fz,
-                return_signal=False,
-                mask=None,
-            )
-            c2f_vamp.append(sampled_c2fz)
-            start_pos += self.coarse2fine.seq_len
-
-        c2f_vamp = torch.cat(c2f_vamp, dim=-1)
-
-        # make it a signal
-        vamp_signal = self.coarse2fine.to_signal(c2f_vamp, self.codec)
-
-        return {
-            "full": vamp_signal,
-            "coarse": self.coarse.to_signal(coarse_vamp, self.codec),
-        }
vampnet/interface.py ADDED
@@ -0,0 +1,332 @@
+import os
+from pathlib import Path
+import math
+
+import torch
+from audiotools import AudioSignal
+
+from .modules.transformer import VampNet
+from lac.model.lac import LAC
+
+
+class Interface:
+    def __init__(
+        self,
+        coarse_ckpt: str,
+        coarse2fine_ckpt: str,
+        codec_ckpt: str,
+        device: str = "cpu",
+        coarse_chunk_size_s: int = 5,
+        coarse2fine_chunk_size_s: int = 3,
+    ):
+        self.codec = LAC.load(Path(codec_ckpt))
+        self.codec.eval()
+        self.codec.to(device)
+
+        self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
+        self.coarse.to(device)
+        self.coarse.eval()
+        self.coarse.chunk_size_s = coarse_chunk_size_s
+
+        self.c2f = VampNet.load(
+            location=Path(coarse2fine_ckpt), map_location="cpu"
+        )
+        self.c2f.to(device)
+        self.c2f.eval()
+        self.c2f.chunk_size_s = coarse2fine_chunk_size_s
+
+        self.device = device
+
+    def s2t(self, seconds: float):
+        """seconds to tokens"""
+        return int(seconds * self.codec.sample_rate / self.codec.hop_length)
+
+    def to(self, device):
+        self.device = device
+        self.coarse.to(device)
+        self.c2f.to(device)
+        self.codec.to(device)
+        return self
+
+    def to_signal(self, z: torch.Tensor):
+        return self.coarse.to_signal(z, self.codec)
+
+    @torch.inference_mode()
+    def encode(self, signal: AudioSignal):
+        signal = signal.clone().to(self.device).resample(self.codec.sample_rate).to_mono()
+        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
+        return z
+
+    def coarse_to_fine(
+        self,
+        coarse_z: torch.Tensor,
+        **kwargs
+    ):
+        length = coarse_z.shape[-1]
+        chunk_len = self.s2t(self.c2f.chunk_size_s)
+        n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len)
+
+        # zero pad to chunk_len
+        if length % chunk_len != 0:
+            pad_len = chunk_len - (length % chunk_len)
+            coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len))
+
+        n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1]
+        if n_codebooks_to_append > 0:
+            coarse_z = torch.cat([
+                coarse_z,
+                torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device)
+            ], dim=1)
+
+        fine_z = []
+        for i in range(n_chunks):
+            chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
+            chunk = self.c2f.sample(
+                codec=self.codec,
+                time_steps=chunk_len,
+                start_tokens=chunk,
+                return_signal=False,
+            )
+            fine_z.append(chunk)
+
+        fine_z = torch.cat(fine_z, dim=-1)
+        return fine_z[:, :, :length].clone()
+
+    def coarse_vamp(
+        self,
+        signal,
+        prefix_dur_s: float = 1.25,
+        suffix_dur_s: float = 1.25,
+        num_loops: int = 3,
+        mode="impute",
+        downsample_factor: int = None,
+        debug=False,
+        **kwargs
+    ):
+        z = self.encode(signal)
+
+        assert signal.duration == self.coarse.chunk_size_s, "signal duration must match coarse chunk size for now"
+
+        # coarse z
+        cz = z[:, : self.coarse.n_codebooks, :].clone()
+        c_seq_len = cz.shape[-1]
+        n_prefix = self.s2t(prefix_dur_s)
+        n_suffix = self.s2t(suffix_dur_s)
+
+        # we'll keep the final codes sequence here
+        c_vamp = {
+            'prefix': [cz[:, :, :n_prefix].clone()],
+            'suffix': [cz[:, :, c_seq_len-n_suffix:].clone()]
+        }
+
+        _cz = cz.clone()
+        for _ in range(num_loops):
+            # add noise
+            cz_masked, cz_mask = self.coarse.add_noise(
+                _cz, r=0.0,
+                n_prefix=n_prefix,
+                n_suffix=n_suffix,
+                downsample_factor=downsample_factor
+            )
+            if debug:
+                print("tokens to infer")
+                self.to_signal(cz_masked).cpu().widget()
+
+            # sample!
+            cz_sampled = self.coarse.sample(
+                codec=self.codec,
+                time_steps=self.s2t(self.coarse.chunk_size_s),
+                start_tokens=_cz,
+                mask=cz_mask,
+                return_signal=False,
+                **kwargs
+            )
+
+            if debug:
+                print("tokens sampled")
+                self.to_signal(cz_sampled).cpu().widget()
+
+            cz_imputed = cz_sampled[:, :, n_prefix:c_seq_len-n_suffix].clone()
+
+            if mode == "impute":
+                # split the imputed codes into two halves
+                cz_imputed_a = cz_imputed[:, :, : cz_imputed.shape[-1] // 2].clone()
+                cz_imputed_b = cz_imputed[:, :, cz_imputed.shape[-1] // 2 :].clone()
+            elif mode == "continue":
+                cz_imputed_a = cz_imputed[:, :, : cz_imputed.shape[-1]].clone()
+                cz_imputed_b = _cz[:, :, :0].clone()  # empty
+            elif mode == "reverse-continue":
+                cz_imputed_a = _cz[:, :, :0].clone()  # empty
+                cz_imputed_b = cz_imputed[:, :, : cz_imputed.shape[-1]].clone()
+            else:
+                raise ValueError(f"mode {mode} not supported")
+
+            if debug:
+                # add to our c_vamp
+                if cz_imputed_a.shape[-1] > 0:
+                    print("new_prefix added")
+                    self.to_signal(cz_imputed_a).cpu().widget()
+                if cz_imputed_b.shape[-1] > 0:
+                    print("new_suffix added")
+                    self.to_signal(cz_imputed_b).cpu().widget()
+
+            c_vamp['prefix'].append(cz_imputed_a.clone())
+            c_vamp['suffix'].insert(0, cz_imputed_b.clone())
+
+            n_to_insert = c_seq_len - (cz_imputed_a.shape[-1] + cz_imputed_b.shape[-1])
+            to_insert = torch.zeros(cz_imputed_a.shape[0], cz_imputed_a.shape[1], n_to_insert).long().to(self.device)
+            _cz = torch.cat([cz_imputed_a, to_insert, cz_imputed_b], dim=-1)
+
+            if debug:
+                print("tokens to infer next round (area to insert in the middle)")
+                self.to_signal(_cz).cpu().widget()
+
+        prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
+        suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
+        c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
+        return c_vamp
+
+    def coarse_vamp_v2(
+        self,
+        signal,
+        prefix_dur_s: float = 1.25,
+        suffix_dur_s: float = 1.25,
+        num_loops: int = 3,
+        downsample_factor: int = None,
+        debug=False,
+        **kwargs
+    ):
+        z = self.encode(signal)
+
+        assert signal.duration == self.coarse.chunk_size_s, "signal duration must match coarse chunk size for now"
+
+        # coarse z
+        cz = z[:, : self.coarse.n_codebooks, :].clone()
+        c_seq_len = cz.shape[-1]
+        n_prefix = self.s2t(prefix_dur_s)
+        n_suffix = self.s2t(suffix_dur_s)
+
+        assert n_prefix + n_suffix < c_seq_len, "prefix and suffix must be smaller than the chunk size"
+
+        # we'll keep the final codes sequence here
+        c_vamp = {
+            'prefix': [cz[:, :, :n_prefix].clone()],
+            'suffix': [cz[:, :, c_seq_len-n_suffix:].clone()]
+        }
+
+        _cz = cz.clone()
+        cz_mask = None
+        for _ in range(num_loops):
+            # add noise
+            cz_masked, cz_mask = self.coarse.add_noise(
+                _cz, r=0.0,
+                n_prefix=n_prefix,
+                n_suffix=n_suffix,
+                downsample_factor=downsample_factor,
+                mask=cz_mask
+            )
+            if debug:
+                print("tokens to infer")
+                self.to_signal(cz_masked).cpu().widget()
+
+            # sample!
+            if debug:
+                print(f"mask: {cz_mask[:,0,:]}")
+                print(f"z: {_cz[:,0,:]}")
+            cz_sampled = self.coarse.sample(
+                codec=self.codec,
+                time_steps=self.s2t(self.coarse.chunk_size_s),
+                start_tokens=_cz,
+                mask=cz_mask,
+                return_signal=False,
+                **kwargs
+            )
+
+            if debug:
+                print("tokens sampled")
+                self.to_signal(cz_sampled).cpu().widget()
+
+            # the z that was generated
+            cz_generated = cz_sampled[:, :, n_prefix:c_seq_len-n_suffix].clone()
+            n_generated = cz_generated.shape[-1]
+
+            # create the new prefix and suffix
+            # we'll make sure that the number of prefix and suffix
+            # tokens is the same as the original
+            # but we do want to advance the sequence as much as we can
+            if n_prefix > 0 and n_suffix > 0:
+                # we have both prefix and suffix, so we'll split the generated
+                # codes in two halves
+                prefix_start_idx = n_generated // 2
+                prefix_stop_idx = prefix_start_idx + n_prefix
+                assert prefix_start_idx >= 0, "internal error"
+
+                suffix_start_idx = n_prefix + n_generated // 2
+                suffix_stop_idx = suffix_start_idx + n_suffix
+                assert suffix_stop_idx <= cz_sampled.shape[-1], "internal error"
+
+                cz_new_prefix = cz_sampled[:, :, prefix_start_idx:prefix_stop_idx].clone()
+                cz_new_suffix = cz_sampled[:, :, suffix_start_idx:suffix_stop_idx].clone()
+
+                c_vamp['prefix'].append(cz_generated[:, :, :n_generated // 2])
+                c_vamp['suffix'].insert(0, cz_generated[:, :, n_generated // 2:])
+
+            elif n_prefix > 0:
+                # we only have a prefix
+                prefix_start_idx = n_generated
+                prefix_stop_idx = prefix_start_idx + n_prefix
+
+                cz_new_prefix = cz_sampled[:, :, prefix_start_idx:prefix_stop_idx].clone()
+                cz_new_suffix = _cz[:, :, :0].clone()
+
+                c_vamp['prefix'].append(cz_generated)
+
+            elif n_suffix > 0:
+                # we only have a suffix, so everything starting at 0 is generated
+                suffix_stop_idx = max(n_generated, n_suffix)
+                suffix_start_idx = suffix_stop_idx - n_suffix
+
+                cz_new_prefix = _cz[:, :, :0].clone()
+                cz_new_suffix = cz_sampled[:, :, suffix_start_idx:suffix_stop_idx].clone()
+
+                c_vamp['suffix'].insert(0, cz_generated)
+
+            n_to_insert = c_seq_len - (cz_new_prefix.shape[-1] + cz_new_suffix.shape[-1])
+            to_insert = torch.zeros(cz_new_prefix.shape[0], cz_new_prefix.shape[1], n_to_insert).long().to(self.device)
+            _cz = torch.cat([cz_new_prefix, to_insert, cz_new_suffix], dim=-1)
+
+            to_insert_mask = torch.zeros_like(_cz).long().to(self.device)
+            to_insert_mask[:, :, cz_new_prefix.shape[-1]:cz_new_prefix.shape[-1] + n_to_insert] = 1
+            cz_mask = (cz_mask + to_insert_mask).bool().long()
+
+            if debug:
+                print("tokens to infer next round (area to insert in the middle)")
+                self.to_signal(_cz).cpu().widget()
+
+        prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
+        suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
+        c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
+        return c_vamp
+
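The new Interface splits the old TheWholeEnchilada.vamp() into separate coarse and coarse-to-fine stages. A minimal end-to-end sketch, assuming the `interface` object constructed in the __init__.py example above, placeholder file names, and a clip whose duration equals the coarse chunk size (5 s by default):

    from audiotools import AudioSignal

    signal = AudioSignal("input.wav")  # placeholder; duration must equal coarse_chunk_size_s

    # regenerate the middle of the clip while keeping a 1.25 s prefix and suffix
    coarse_z = interface.coarse_vamp_v2(
        signal, prefix_dur_s=1.25, suffix_dur_s=1.25, num_loops=3
    )

    # fill in the remaining fine codebooks chunk by chunk, then decode to audio
    full_z = interface.coarse_to_fine(coarse_z)
    out = interface.to_signal(full_z)
    out.cpu().write("vamp_output.wav")  # placeholder output path

s2t() converts the second-valued arguments into token counts: with a hypothetical 44.1 kHz codec sample rate and a hop length of 512, 1.25 s maps to int(1.25 * 44100 / 512) = 107 tokens.
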
vampnet/modules/base.py CHANGED
@@ -24,6 +24,9 @@ def gumbel_sample(t, temperature=1.0, dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
 
+def scalar_to_batch_tensor(x, batch_size):
+    return torch.tensor(x).repeat(batch_size)
+
 class VampBase(at.ml.BaseModel):
     def forward(self, x: torch.Tensor, r: torch.Tensor):
         raise NotImplementedError
@@ -36,20 +39,40 @@ class VampBase(at.ml.BaseModel):
         mask: Optional[torch.Tensor] = None,
         n_prefix: Optional[torch.Tensor] = None,
         n_suffix: Optional[torch.Tensor] = None,
+        downsample_factor: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
 
         if mask is None:
+            if not isinstance(r, torch.Tensor):
+                r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
             r = self.gamma(r)[:, None, None]
             probs = torch.ones_like(x) * r
 
             # if we have a prefix or suffix, set their mask prob to 0
             if n_prefix is not None:
+                if not isinstance(n_prefix, torch.Tensor):
+                    n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
                 for i, n in enumerate(n_prefix):
-                    probs[i, :, :n] = 0.0
+                    if n > 0:
+                        probs[i, :, :n] = 0.0
             if n_suffix is not None:
+                if not isinstance(n_suffix, torch.Tensor):
+                    n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
                 for i, n in enumerate(n_suffix):
-                    probs[i, :, -n:] = 0.0
+                    if n > 0:
+                        probs[i, :, -n:] = 0.0
+
+            # if we have a downsample factor, set the mask prob to 0
+            if downsample_factor is not None:
+                if not isinstance(downsample_factor, torch.Tensor):
+                    downsample_factor = scalar_to_batch_tensor(downsample_factor, x.shape[0])
+                for i, factor in enumerate(downsample_factor):
+                    if factor == 0:
+                        continue
+                    for j in range(probs.shape[-1]):
+                        if j % factor == 0:
+                            probs[i, :, j] = 0.0
 
             mask = torch.bernoulli(probs)
             mask = mask.round().long()
@@ -347,7 +370,9 @@ class VampBase(at.ml.BaseModel):
             if num_to_keep > 0:
                 probs = logits.softmax(dim=-1)
 
-                keep_probs = F.one_hot(z, self.vocab_size)[:, :, :]
+                # do mod self.vocab_size to make sure we don't sample from the mask token
+                # in case the mask token was in the og z
+                keep_probs = F.one_hot(z % self.vocab_size, self.vocab_size)[:, :, :]
 
                 probs = rearrange(
                     probs, "b (t c) p -> b c t p", c=n_infer_codebooks
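
To make the new masking controls concrete, here is a small sketch of how add_noise might behave with the added arguments, given a VampBase-derived `model` (hypothetical instance) and assuming the usual cosine gamma schedule where gamma(0.0) = 1, i.e. r=0.0 masks everything that is not explicitly protected (this matches how Interface.coarse_vamp calls it):

    import torch

    # a toy batch of coarse codes: batch=1, 4 codebooks, 20 timesteps
    x = torch.zeros(1, 4, 20).long()

    # scalar arguments are broadcast per batch item via scalar_to_batch_tensor,
    # so plain ints/floats are accepted alongside tensors
    x_masked, mask = model.add_noise(
        x, r=0.0,
        n_prefix=4,            # first 4 tokens are never masked
        n_suffix=4,            # last 4 tokens are never masked
        downsample_factor=4,   # every 4th token in between is kept as a hint
    )

    # under these assumptions mask[0, 0] is deterministic:
    # [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0]
    # 1 = position to be replaced/inferred, 0 = token kept from x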