unpairedelectron07 commited on
Commit
8d183a7
1 Parent(s): 3d857a9

Upload 3 files

Browse files
audiocraft/quantization/base.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Base class for all quantizers.
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+
18
+ @dataclass
19
+ class QuantizedResult:
20
+ x: torch.Tensor
21
+ codes: torch.Tensor
22
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
23
+ penalty: tp.Optional[torch.Tensor] = None
24
+ metrics: dict = field(default_factory=dict)
25
+
26
+
27
+ class BaseQuantizer(nn.Module):
28
+ """Base class for quantizers.
29
+ """
30
+
31
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
32
+ """
33
+ Given input tensor x, returns first the quantized (or approximately quantized)
34
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
35
+ Finally, this returns a dict of metrics to update logging etc.
36
+ Frame rate must be passed so that the bandwidth is properly computed.
37
+ """
38
+ raise NotImplementedError()
39
+
40
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
41
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
42
+ raise NotImplementedError()
43
+
44
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
45
+ """Decode the given codes to the quantized representation."""
46
+ raise NotImplementedError()
47
+
48
+ @property
49
+ def total_codebooks(self):
50
+ """Total number of codebooks."""
51
+ raise NotImplementedError()
52
+
53
+ @property
54
+ def num_codebooks(self):
55
+ """Number of active codebooks."""
56
+ raise NotImplementedError()
57
+
58
+ def set_num_codebooks(self, n: int):
59
+ """Set the number of active codebooks."""
60
+ raise NotImplementedError()
61
+
62
+
63
+ class DummyQuantizer(BaseQuantizer):
64
+ """Fake quantizer that actually does not perform any quantization.
65
+ """
66
+ def __init__(self):
67
+ super().__init__()
68
+
69
+ def forward(self, x: torch.Tensor, frame_rate: int):
70
+ q = x.unsqueeze(1)
71
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
72
+
73
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
74
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
75
+ In the case of the DummyQuantizer, the codes are actually identical
76
+ to the input and resulting quantized representation as no quantization is done.
77
+ """
78
+ return x.unsqueeze(1)
79
+
80
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
81
+ """Decode the given codes to the quantized representation.
82
+ In the case of the DummyQuantizer, the codes are actually identical
83
+ to the input and resulting quantized representation as no quantization is done.
84
+ """
85
+ return codes.squeeze(1)
86
+
87
+ @property
88
+ def total_codebooks(self):
89
+ """Total number of codebooks."""
90
+ return 1
91
+
92
+ @property
93
+ def num_codebooks(self):
94
+ """Total number of codebooks."""
95
+ return self.total_codebooks
96
+
97
+ def set_num_codebooks(self, n: int):
98
+ """Set the number of active codebooks."""
99
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
audiocraft/quantization/core_vq.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ from einops import rearrange, repeat
10
+ import flashy
11
+ import torch
12
+ from torch import nn, einsum
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def exists(val: tp.Optional[tp.Any]) -> bool:
17
+ return val is not None
18
+
19
+
20
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
21
+ return val if exists(val) else d
22
+
23
+
24
+ def l2norm(t):
25
+ return F.normalize(t, p=2, dim=-1)
26
+
27
+
28
+ def ema_inplace(moving_avg, new, decay: float):
29
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
30
+
31
+
32
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
33
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
34
+
35
+
36
+ def uniform_init(*shape: int):
37
+ t = torch.empty(shape)
38
+ nn.init.kaiming_uniform_(t)
39
+ return t
40
+
41
+
42
+ def sample_vectors(samples, num: int):
43
+ num_samples, device = samples.shape[0], samples.device
44
+
45
+ if num_samples >= num:
46
+ indices = torch.randperm(num_samples, device=device)[:num]
47
+ else:
48
+ indices = torch.randint(0, num_samples, (num,), device=device)
49
+
50
+ return samples[indices]
51
+
52
+
53
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
54
+ dim, dtype = samples.shape[-1], samples.dtype
55
+
56
+ means = sample_vectors(samples, num_clusters)
57
+
58
+ for _ in range(num_iters):
59
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
60
+ means, "c d -> () c d"
61
+ )
62
+ dists = -(diffs ** 2).sum(dim=-1)
63
+
64
+ buckets = dists.max(dim=-1).indices
65
+ bins = torch.bincount(buckets, minlength=num_clusters)
66
+ zero_mask = bins == 0
67
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
68
+
69
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
70
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
71
+ new_means = new_means / bins_min_clamped[..., None]
72
+
73
+ means = torch.where(zero_mask[..., None], means, new_means)
74
+
75
+ return means, bins
76
+
77
+
78
+ def orthogonal_loss_fn(t):
79
+ # eq (2) from https://arxiv.org/abs/2112.00384
80
+ n = t.shape[0]
81
+ normed_codes = l2norm(t)
82
+ identity = torch.eye(n, device=t.device)
83
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
84
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
85
+
86
+
87
+ class EuclideanCodebook(nn.Module):
88
+ """Codebook with Euclidean distance.
89
+
90
+ Args:
91
+ dim (int): Dimension.
92
+ codebook_size (int): Codebook size.
93
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
94
+ If set to true, run the k-means algorithm on the first training batch and use
95
+ the learned centroids as initialization.
96
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
97
+ decay (float): Decay for exponential moving average over the codebooks.
98
+ epsilon (float): Epsilon value for numerical stability.
99
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
100
+ that have an exponential moving average cluster size less than the specified threshold with
101
+ randomly selected vector from the current batch.
102
+ """
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ codebook_size: int,
107
+ kmeans_init: int = False,
108
+ kmeans_iters: int = 10,
109
+ decay: float = 0.8,
110
+ epsilon: float = 1e-5,
111
+ threshold_ema_dead_code: int = 2,
112
+ ):
113
+ super().__init__()
114
+ self.decay = decay
115
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
116
+ embed = init_fn(codebook_size, dim)
117
+
118
+ self.codebook_size = codebook_size
119
+
120
+ self.kmeans_iters = kmeans_iters
121
+ self.epsilon = epsilon
122
+ self.threshold_ema_dead_code = threshold_ema_dead_code
123
+
124
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
125
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
126
+ self.register_buffer("embed", embed)
127
+ self.register_buffer("embed_avg", embed.clone())
128
+
129
+ @torch.jit.ignore
130
+ def init_embed_(self, data):
131
+ if self.inited:
132
+ return
133
+
134
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
135
+ self.embed.data.copy_(embed)
136
+ self.embed_avg.data.copy_(embed.clone())
137
+ self.cluster_size.data.copy_(cluster_size)
138
+ self.inited.data.copy_(torch.Tensor([True]))
139
+ # Make sure all buffers across workers are in sync after initialization
140
+ flashy.distrib.broadcast_tensors(self.buffers())
141
+
142
+ def replace_(self, samples, mask):
143
+ modified_codebook = torch.where(
144
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
145
+ )
146
+ self.embed.data.copy_(modified_codebook)
147
+
148
+ def expire_codes_(self, batch_samples):
149
+ if self.threshold_ema_dead_code == 0:
150
+ return
151
+
152
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
153
+ if not torch.any(expired_codes):
154
+ return
155
+
156
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
157
+ self.replace_(batch_samples, mask=expired_codes)
158
+ flashy.distrib.broadcast_tensors(self.buffers())
159
+
160
+ def preprocess(self, x):
161
+ x = rearrange(x, "... d -> (...) d")
162
+ return x
163
+
164
+ def quantize(self, x):
165
+ embed = self.embed.t()
166
+ dist = -(
167
+ x.pow(2).sum(1, keepdim=True)
168
+ - 2 * x @ embed
169
+ + embed.pow(2).sum(0, keepdim=True)
170
+ )
171
+ embed_ind = dist.max(dim=-1).indices
172
+ return embed_ind
173
+
174
+ def postprocess_emb(self, embed_ind, shape):
175
+ return embed_ind.view(*shape[:-1])
176
+
177
+ def dequantize(self, embed_ind):
178
+ quantize = F.embedding(embed_ind, self.embed)
179
+ return quantize
180
+
181
+ def encode(self, x):
182
+ shape = x.shape
183
+ # pre-process
184
+ x = self.preprocess(x)
185
+ # quantize
186
+ embed_ind = self.quantize(x)
187
+ # post-process
188
+ embed_ind = self.postprocess_emb(embed_ind, shape)
189
+ return embed_ind
190
+
191
+ def decode(self, embed_ind):
192
+ quantize = self.dequantize(embed_ind)
193
+ return quantize
194
+
195
+ def forward(self, x):
196
+ shape, dtype = x.shape, x.dtype
197
+ x = self.preprocess(x)
198
+ self.init_embed_(x)
199
+
200
+ embed_ind = self.quantize(x)
201
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
202
+ embed_ind = self.postprocess_emb(embed_ind, shape)
203
+ quantize = self.dequantize(embed_ind)
204
+
205
+ if self.training:
206
+ # We do the expiry of code at that point as buffers are in sync
207
+ # and all the workers will take the same decision.
208
+ self.expire_codes_(x)
209
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
210
+ embed_sum = x.t() @ embed_onehot
211
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
212
+ cluster_size = (
213
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
214
+ * self.cluster_size.sum()
215
+ )
216
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
217
+ self.embed.data.copy_(embed_normalized)
218
+
219
+ return quantize, embed_ind
220
+
221
+
222
+ class VectorQuantization(nn.Module):
223
+ """Vector quantization implementation.
224
+ Currently supports only euclidean distance.
225
+
226
+ Args:
227
+ dim (int): Dimension
228
+ codebook_size (int): Codebook size
229
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
230
+ decay (float): Decay for exponential moving average over the codebooks.
231
+ epsilon (float): Epsilon value for numerical stability.
232
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
233
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
234
+ threshold_ema_dead_code (int):
235
+ channels_last (bool): Channels are the last dimension in the input tensors.
236
+ commitment_weight (float): Weight for commitment loss.
237
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
238
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
239
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
240
+ for orthogonal regularization.
241
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242
+ that have an exponential moving average cluster size less than the specified threshold with
243
+ randomly selected vector from the current batch.
244
+ """
245
+ def __init__(
246
+ self,
247
+ dim: int,
248
+ codebook_size: int,
249
+ codebook_dim: tp.Optional[int] = None,
250
+ decay: float = 0.8,
251
+ epsilon: float = 1e-5,
252
+ kmeans_init: bool = False,
253
+ kmeans_iters: int = 10,
254
+ threshold_ema_dead_code: int = 2,
255
+ channels_last: bool = False,
256
+ commitment_weight: float = 1.,
257
+ orthogonal_reg_weight: float = 0.0,
258
+ orthogonal_reg_active_codes_only: bool = False,
259
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
260
+ ):
261
+ super().__init__()
262
+ _codebook_dim: int = default(codebook_dim, dim)
263
+
264
+ requires_projection = _codebook_dim != dim
265
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
266
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
267
+
268
+ self.epsilon = epsilon
269
+ self.commitment_weight = commitment_weight
270
+
271
+ self.orthogonal_reg_weight = orthogonal_reg_weight
272
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
273
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
274
+
275
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
276
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
277
+ decay=decay, epsilon=epsilon,
278
+ threshold_ema_dead_code=threshold_ema_dead_code)
279
+ self.codebook_size = codebook_size
280
+
281
+ self.channels_last = channels_last
282
+
283
+ @property
284
+ def codebook(self):
285
+ return self._codebook.embed
286
+
287
+ @property
288
+ def inited(self):
289
+ return self._codebook.inited
290
+
291
+ def _preprocess(self, x):
292
+ if not self.channels_last:
293
+ x = rearrange(x, "b d n -> b n d")
294
+ return x
295
+
296
+ def _postprocess(self, quantize):
297
+ if not self.channels_last:
298
+ quantize = rearrange(quantize, "b n d -> b d n")
299
+ return quantize
300
+
301
+ def encode(self, x):
302
+ x = self._preprocess(x)
303
+ x = self.project_in(x)
304
+ embed_in = self._codebook.encode(x)
305
+ return embed_in
306
+
307
+ def decode(self, embed_ind):
308
+ quantize = self._codebook.decode(embed_ind)
309
+ quantize = self.project_out(quantize)
310
+ quantize = self._postprocess(quantize)
311
+ return quantize
312
+
313
+ def forward(self, x):
314
+ device = x.device
315
+ x = self._preprocess(x)
316
+
317
+ x = self.project_in(x)
318
+ quantize, embed_ind = self._codebook(x)
319
+
320
+ if self.training:
321
+ quantize = x + (quantize - x).detach()
322
+
323
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
324
+
325
+ if self.training:
326
+ if self.commitment_weight > 0:
327
+ commit_loss = F.mse_loss(quantize.detach(), x)
328
+ loss = loss + commit_loss * self.commitment_weight
329
+
330
+ if self.orthogonal_reg_weight > 0:
331
+ codebook = self.codebook
332
+
333
+ if self.orthogonal_reg_active_codes_only:
334
+ # only calculate orthogonal loss for the activated codes for this batch
335
+ unique_code_ids = torch.unique(embed_ind)
336
+ codebook = codebook[unique_code_ids]
337
+
338
+ num_codes = codebook.shape[0]
339
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
340
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
341
+ codebook = codebook[rand_ids]
342
+
343
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
344
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
345
+
346
+ quantize = self.project_out(quantize)
347
+ quantize = self._postprocess(quantize)
348
+
349
+ return quantize, embed_ind, loss
350
+
351
+
352
+ class ResidualVectorQuantization(nn.Module):
353
+ """Residual vector quantization implementation.
354
+
355
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
356
+ """
357
+ def __init__(self, *, num_quantizers, **kwargs):
358
+ super().__init__()
359
+ self.layers = nn.ModuleList(
360
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
361
+ )
362
+
363
+ def forward(self, x, n_q: tp.Optional[int] = None):
364
+ quantized_out = 0.0
365
+ residual = x
366
+
367
+ all_losses = []
368
+ all_indices = []
369
+
370
+ n_q = n_q or len(self.layers)
371
+
372
+ for i, layer in enumerate(self.layers[:n_q]):
373
+ quantized, indices, loss = layer(residual)
374
+ quantized = quantized.detach()
375
+ residual = residual - quantized
376
+ quantized_out = quantized_out + quantized
377
+ all_indices.append(indices)
378
+ all_losses.append(loss)
379
+
380
+ if self.training:
381
+ # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
382
+ quantized_out = x + (quantized_out - x).detach()
383
+
384
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
385
+ return quantized_out, out_indices, out_losses
386
+
387
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
388
+ residual = x
389
+ all_indices = []
390
+ n_q = n_q or len(self.layers)
391
+ for layer in self.layers[:n_q]:
392
+ indices = layer.encode(residual)
393
+ quantized = layer.decode(indices)
394
+ residual = residual - quantized
395
+ all_indices.append(indices)
396
+ out_indices = torch.stack(all_indices)
397
+ return out_indices
398
+
399
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
400
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
401
+ for i, indices in enumerate(q_indices):
402
+ layer = self.layers[i]
403
+ quantized = layer.decode(indices)
404
+ quantized_out = quantized_out + quantized
405
+ return quantized_out
audiocraft/quantization/vq.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import torch
11
+
12
+ from .base import BaseQuantizer, QuantizedResult
13
+ from .core_vq import ResidualVectorQuantization
14
+
15
+
16
+ class ResidualVectorQuantizer(BaseQuantizer):
17
+ """Residual Vector Quantizer.
18
+
19
+ Args:
20
+ dimension (int): Dimension of the codebooks.
21
+ n_q (int): Number of residual vector quantizers used.
22
+ q_dropout (bool): Random quantizer drop out at train time.
23
+ bins (int): Codebook size.
24
+ decay (float): Decay for exponential moving average over the codebooks.
25
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
26
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
27
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
28
+ that have an exponential moving average cluster size less than the specified threshold with
29
+ randomly selected vector from the current batch.
30
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
31
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
32
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
33
+ for orthogonal regularization.
34
+ """
35
+ def __init__(
36
+ self,
37
+ dimension: int = 256,
38
+ n_q: int = 8,
39
+ q_dropout: bool = False,
40
+ bins: int = 1024,
41
+ decay: float = 0.99,
42
+ kmeans_init: bool = True,
43
+ kmeans_iters: int = 10,
44
+ threshold_ema_dead_code: int = 2,
45
+ orthogonal_reg_weight: float = 0.0,
46
+ orthogonal_reg_active_codes_only: bool = False,
47
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
48
+ ):
49
+ super().__init__()
50
+ self.max_n_q = n_q
51
+ self.n_q = n_q
52
+ self.q_dropout = q_dropout
53
+ self.dimension = dimension
54
+ self.bins = bins
55
+ self.decay = decay
56
+ self.kmeans_init = kmeans_init
57
+ self.kmeans_iters = kmeans_iters
58
+ self.threshold_ema_dead_code = threshold_ema_dead_code
59
+ self.orthogonal_reg_weight = orthogonal_reg_weight
60
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
61
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
62
+ self.vq = ResidualVectorQuantization(
63
+ dim=self.dimension,
64
+ codebook_size=self.bins,
65
+ num_quantizers=self.n_q,
66
+ decay=self.decay,
67
+ kmeans_init=self.kmeans_init,
68
+ kmeans_iters=self.kmeans_iters,
69
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
70
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
71
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
72
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
73
+ channels_last=False
74
+ )
75
+
76
+ def forward(self, x: torch.Tensor, frame_rate: int):
77
+ n_q = self.n_q
78
+ if self.training and self.q_dropout:
79
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
80
+ bw_per_q = math.log2(self.bins) * frame_rate / 1000
81
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
82
+ codes = codes.transpose(0, 1)
83
+ # codes is [B, K, T], with T frames, K nb of codebooks.
84
+ bw = torch.tensor(n_q * bw_per_q).to(x)
85
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
86
+
87
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
88
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
89
+ The RVQ encode method sets the appropriate number of quantizer to use
90
+ and returns indices for each quantizer.
91
+ """
92
+ n_q = self.n_q
93
+ codes = self.vq.encode(x, n_q=n_q)
94
+ codes = codes.transpose(0, 1)
95
+ # codes is [B, K, T], with T frames, K nb of codebooks.
96
+ return codes
97
+
98
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
99
+ """Decode the given codes to the quantized representation."""
100
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
101
+ codes = codes.transpose(0, 1)
102
+ quantized = self.vq.decode(codes)
103
+ return quantized
104
+
105
+ @property
106
+ def total_codebooks(self):
107
+ return self.max_n_q
108
+
109
+ @property
110
+ def num_codebooks(self):
111
+ return self.n_q
112
+
113
+ def set_num_codebooks(self, n: int):
114
+ assert n > 0 and n <= self.max_n_q
115
+ self.n_q = n