|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from abc import ABC, abstractmethod | 
					
						
						|  | import typing as tp | 
					
						
						|  |  | 
					
						
						|  | from einops import rearrange | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  | from .. import quantization as qt | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CompressionModel(ABC, nn.Module): | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: | 
					
						
						|  | """See `EncodecModel.encode`""" | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): | 
					
						
						|  | """See `EncodecModel.decode`""" | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def channels(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def frame_rate(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def sample_rate(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def cardinality(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def num_codebooks(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def total_codebooks(self) -> int: | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def set_num_codebooks(self, n: int): | 
					
						
						|  | """Set the active number of codebooks used by the quantizer. | 
					
						
						|  | """ | 
					
						
						|  | ... | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class EncodecModel(CompressionModel): | 
					
						
						|  | """Encodec model operating on the raw waveform. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | encoder (nn.Module): Encoder network. | 
					
						
						|  | decoder (nn.Module): Decoder network. | 
					
						
						|  | quantizer (qt.BaseQuantizer): Quantizer network. | 
					
						
						|  | frame_rate (int): Frame rate for the latent representation. | 
					
						
						|  | sample_rate (int): Audio sample rate. | 
					
						
						|  | channels (int): Number of audio channels. | 
					
						
						|  | causal (bool): Whether to use a causal version of the model. | 
					
						
						|  | renormalize (bool): Whether to renormalize the audio before running the model. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | frame_rate: int = 0 | 
					
						
						|  | sample_rate: int = 0 | 
					
						
						|  | channels: int = 0 | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, | 
					
						
						|  | encoder: nn.Module, | 
					
						
						|  | decoder: nn.Module, | 
					
						
						|  | quantizer: qt.BaseQuantizer, | 
					
						
						|  | frame_rate: int, | 
					
						
						|  | sample_rate: int, | 
					
						
						|  | channels: int, | 
					
						
						|  | causal: bool = False, | 
					
						
						|  | renormalize: bool = False): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.encoder = encoder | 
					
						
						|  | self.decoder = decoder | 
					
						
						|  | self.quantizer = quantizer | 
					
						
						|  | self.frame_rate = frame_rate | 
					
						
						|  | self.sample_rate = sample_rate | 
					
						
						|  | self.channels = channels | 
					
						
						|  | self.renormalize = renormalize | 
					
						
						|  | self.causal = causal | 
					
						
						|  | if self.causal: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert not self.renormalize, 'Causal model does not support renormalize' | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def total_codebooks(self): | 
					
						
						|  | """Total number of quantizer codebooks available. | 
					
						
						|  | """ | 
					
						
						|  | return self.quantizer.total_codebooks | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def num_codebooks(self): | 
					
						
						|  | """Active number of codebooks used by the quantizer. | 
					
						
						|  | """ | 
					
						
						|  | return self.quantizer.num_codebooks | 
					
						
						|  |  | 
					
						
						|  | def set_num_codebooks(self, n: int): | 
					
						
						|  | """Set the active number of codebooks used by the quantizer. | 
					
						
						|  | """ | 
					
						
						|  | self.quantizer.set_num_codebooks(n) | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def cardinality(self): | 
					
						
						|  | """Cardinality of each codebook. | 
					
						
						|  | """ | 
					
						
						|  | return self.quantizer.bins | 
					
						
						|  |  | 
					
						
						|  | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: | 
					
						
						|  | scale: tp.Optional[torch.Tensor] | 
					
						
						|  | if self.renormalize: | 
					
						
						|  | mono = x.mean(dim=1, keepdim=True) | 
					
						
						|  | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() | 
					
						
						|  | scale = 1e-8 + volume | 
					
						
						|  | x = x / scale | 
					
						
						|  | scale = scale.view(-1, 1) | 
					
						
						|  | else: | 
					
						
						|  | scale = None | 
					
						
						|  | return x, scale | 
					
						
						|  |  | 
					
						
						|  | def postprocess(self, | 
					
						
						|  | x: torch.Tensor, | 
					
						
						|  | scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: | 
					
						
						|  | if scale is not None: | 
					
						
						|  | assert self.renormalize | 
					
						
						|  | x = x * scale.view(-1, 1, 1) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: | 
					
						
						|  | assert x.dim() == 3 | 
					
						
						|  | length = x.shape[-1] | 
					
						
						|  | x, scale = self.preprocess(x) | 
					
						
						|  |  | 
					
						
						|  | emb = self.encoder(x) | 
					
						
						|  | q_res = self.quantizer(emb, self.frame_rate) | 
					
						
						|  | out = self.decoder(q_res.x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert out.shape[-1] >= length, (out.shape[-1], length) | 
					
						
						|  | out = out[..., :length] | 
					
						
						|  |  | 
					
						
						|  | q_res.x = self.postprocess(out, scale) | 
					
						
						|  |  | 
					
						
						|  | return q_res | 
					
						
						|  |  | 
					
						
						|  | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: | 
					
						
						|  | """Encode the given input tensor to quantized representation along with scale parameter. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x (torch.Tensor): Float tensor of shape [B, C, T] | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: | 
					
						
						|  | codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. | 
					
						
						|  | scale a float tensor containing the scale for audio renormalizealization. | 
					
						
						|  | """ | 
					
						
						|  | assert x.dim() == 3 | 
					
						
						|  | x, scale = self.preprocess(x) | 
					
						
						|  | emb = self.encoder(x) | 
					
						
						|  | codes = self.quantizer.encode(emb) | 
					
						
						|  | return codes, scale | 
					
						
						|  |  | 
					
						
						|  | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): | 
					
						
						|  | """Decode the given codes to a reconstructed representation, using the scale to perform | 
					
						
						|  | audio denormalization if needed. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | codes (torch.Tensor): Int tensor of shape [B, K, T] | 
					
						
						|  | scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. | 
					
						
						|  | """ | 
					
						
						|  | emb = self.quantizer.decode(codes) | 
					
						
						|  | out = self.decoder(emb) | 
					
						
						|  | out = self.postprocess(out, scale) | 
					
						
						|  |  | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class FlattenedCompressionModel(CompressionModel): | 
					
						
						|  | """Wraps a CompressionModel and flatten its codebooks, e.g. | 
					
						
						|  | instead of returning [B, K, T], return [B, S, T * (K // S)] with | 
					
						
						|  | S the number of codebooks per step, and `K // S` the number of 'virtual steps' | 
					
						
						|  | for each real time step. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | model (CompressionModel): compression model to wrap. | 
					
						
						|  | codebooks_per_step (int): number of codebooks to keep per step, | 
					
						
						|  | this must divide the number of codebooks provided by the wrapped model. | 
					
						
						|  | extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, | 
					
						
						|  | if each codebook has a cardinality N, then the first codebook will | 
					
						
						|  | use the range [0, N - 1], and the second [N, 2 N - 1] etc. | 
					
						
						|  | On decoding, this can lead to potentially invalid sequences. | 
					
						
						|  | Any invalid entry will be silently remapped to the proper range | 
					
						
						|  | with a modulo. | 
					
						
						|  | """ | 
					
						
						|  | def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, | 
					
						
						|  | extend_cardinality: bool = True): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.model = model | 
					
						
						|  | self.codebooks_per_step = codebooks_per_step | 
					
						
						|  | self.extend_cardinality = extend_cardinality | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def total_codebooks(self): | 
					
						
						|  | return self.model.total_codebooks | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def num_codebooks(self): | 
					
						
						|  | """Active number of codebooks used by the quantizer. | 
					
						
						|  |  | 
					
						
						|  | ..Warning:: this reports the number of codebooks after the flattening | 
					
						
						|  | of the codebooks! | 
					
						
						|  | """ | 
					
						
						|  | assert self.model.num_codebooks % self.codebooks_per_step == 0 | 
					
						
						|  | return self.codebooks_per_step | 
					
						
						|  |  | 
					
						
						|  | def set_num_codebooks(self, n: int): | 
					
						
						|  | """Set the active number of codebooks used by the quantizer. | 
					
						
						|  |  | 
					
						
						|  | ..Warning:: this sets the number of codebooks **before** the flattening | 
					
						
						|  | of the codebooks. | 
					
						
						|  | """ | 
					
						
						|  | assert n % self.codebooks_per_step == 0 | 
					
						
						|  | self.model.set_num_codebooks(n) | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def num_virtual_steps(self) -> int: | 
					
						
						|  | """Return the number of virtual steps, e.g. one real step | 
					
						
						|  | will be split into that many steps. | 
					
						
						|  | """ | 
					
						
						|  | return self.model.num_codebooks // self.codebooks_per_step | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def frame_rate(self) -> int: | 
					
						
						|  | return self.model.frame_rate * self.num_virtual_steps | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def sample_rate(self) -> int: | 
					
						
						|  | return self.model.sample_rate | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def channels(self) -> int: | 
					
						
						|  | return self.model.channels | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def cardinality(self): | 
					
						
						|  | """Cardinality of each codebook. | 
					
						
						|  | """ | 
					
						
						|  | if self.extend_cardinality: | 
					
						
						|  | return self.model.cardinality * self.num_virtual_steps | 
					
						
						|  | else: | 
					
						
						|  | return self.model.cardinality | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: | 
					
						
						|  | raise NotImplementedError("Not supported, use encode and decode.") | 
					
						
						|  |  | 
					
						
						|  | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: | 
					
						
						|  | indices, scales = self.model.encode(x) | 
					
						
						|  | B, K, T = indices.shape | 
					
						
						|  | indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) | 
					
						
						|  | if self.extend_cardinality: | 
					
						
						|  | for virtual_step in range(1, self.num_virtual_steps): | 
					
						
						|  | indices[..., virtual_step] += self.model.cardinality * virtual_step | 
					
						
						|  | indices = rearrange(indices, 'b k t v -> b k (t v)') | 
					
						
						|  | return (indices, scales) | 
					
						
						|  |  | 
					
						
						|  | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): | 
					
						
						|  | B, K, T = codes.shape | 
					
						
						|  | assert T % self.num_virtual_steps == 0 | 
					
						
						|  | codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | codes = codes % self.model.cardinality | 
					
						
						|  | return self.model.decode(codes, scale) | 
					
						
						|  |  |