|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Base class for all quantizers. |
|
""" |
|
|
|
from dataclasses import dataclass, field |
|
import typing as tp |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
@dataclass
class QuantizedResult:
    """Container for the values returned by a quantizer's forward pass.

    Attributes:
        x: Quantized (or approximately quantized) representation.
        codes: Quantized codes that go with ``x``.
        bandwidth: Bandwidth reported by the quantizer.
        penalty: Optional penalty term for the loss; ``None`` when omitted.
        metrics: Metrics used to update logging; a new empty dict is created
            for each instance when omitted.
    """

    # Required fields, in positional order.
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor

    # Optional fields.
    penalty: tp.Optional[torch.Tensor] = field(default=None)
    metrics: dict = field(default_factory=dict)
|
|
|
|
|
class BaseQuantizer(nn.Module):
    """Common interface for quantizer modules.

    Every method and property defined here raises ``NotImplementedError``;
    subclasses are expected to supply the actual behavior.
    """

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Quantize ``x`` and return the result with its bookkeeping values.

        The returned ``QuantizedResult`` holds the quantized (or approximately
        quantized) representation, the quantized codes, the bandwidth, any
        penalty term for the loss, and a dict of metrics to update logging.

        Args:
            x: Input tensor to quantize.
            frame_rate: Frame rate; it must be passed so that the bandwidth
                is properly computed.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input tensor ``x`` into codes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Turn ``codes`` back into the quantized representation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def total_codebooks(self):
        """Total number of codebooks.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def num_codebooks(self):
        """Number of codebooks that are currently active.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks to ``n``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
class DummyQuantizer(BaseQuantizer):
    """Pass-through quantizer: no quantization is actually performed."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Return ``x`` unchanged, wrapped in a ``QuantizedResult``.

        The codes are ``x`` with a singleton dimension inserted at index 1.
        Penalty and metrics are left at their defaults.

        Args:
            x: Input tensor, returned as-is as the quantized representation.
            frame_rate: Frame rate used in the bandwidth computation.
        """
        codes = x.unsqueeze(dim=1)
        batch_size = len(x)
        # Element count of the codes, scaled by 32 and the frame rate, divided
        # by 1000, then averaged over the first dimension of `x`.
        # NOTE(review): 32 is presumably bits per float32 value and /1000 a
        # conversion to kilo-units -- confirm.
        rate = codes.numel() * 32 * frame_rate / 1000 / batch_size
        # `.to(x)` gives the bandwidth tensor the dtype and device of `x`.
        bandwidth = torch.tensor(rate).to(x)
        return QuantizedResult(x, codes, bandwidth)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with a singleton dimension inserted at index 1.

        No quantization is done, so the codes hold the same values as the
        input and as the resulting quantized representation.
        """
        return x.unsqueeze(dim=1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Squeeze dimension 1 of ``codes`` to recover the representation.

        No quantization is done, so the codes hold the same values as the
        quantized representation.
        """
        return codes.squeeze(dim=1)

    @property
    def total_codebooks(self):
        """Total number of codebooks; always 1 for the dummy quantizer."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks; always equal to ``total_codebooks``."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Reject any attempt to change the number of active codebooks.

        Raises:
            AttributeError: Always; the dummy quantizer's codebook count
                cannot be overridden.
        """
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
|
|