| import torch |
| import numpy as np |
| from typing import Literal |
| from sentence_transformers.models import Module |
|
|
|
|
| class Quantizer(torch.nn.Module): |
| def __init__(self, hard: bool = True): |
| """ |
| Args: |
| hard: Whether to use hard or soft quantization. Defaults to True. |
| """ |
| super().__init__() |
| self._hard = hard |
|
|
| def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| def forward(self, x, *args, **kwargs) -> torch.Tensor: |
| soft = self._soft_quantize(x, *args, **kwargs) |
|
|
| if not self._hard: |
| result = soft |
| else: |
| result = ( |
| self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach() |
| ) |
|
|
| return result |
|
|
|
|
| class Int8TanhQuantizer(Quantizer): |
| def __init__( |
| self, |
| hard: bool = True, |
| ): |
| super().__init__(hard=hard) |
| self.qmin = -128 |
| self.qmax = 127 |
|
|
| def _soft_quantize(self, x, *args, **kwargs): |
| return torch.tanh(x) |
|
|
| def _hard_quantize(self, x, *args, **kwargs): |
| soft = self._soft_quantize(x) |
| int_x = torch.round(soft * self.qmax) |
| int_x = torch.clamp(int_x, self.qmin, self.qmax) |
| return int_x |
|
|
|
|
| class BinaryTanhQuantizer(Quantizer): |
| def __init__( |
| self, |
| hard: bool = True, |
| scale: float = 1.0, |
| ): |
| super().__init__(hard) |
| self._scale = scale |
|
|
| def _soft_quantize(self, x, *args, **kwargs): |
| return torch.tanh(self._scale * x) |
|
|
| def _hard_quantize(self, x, *args, **kwargs): |
| return torch.where(x >= 0, 1.0, -1.0) |
| |
|
|
class PackedBinaryQuantizer:
    """Binarize a tensor by sign and pack the bits into uint8 bytes.

    Produces the "ubinary" representation: each group of 8 sign bits along
    the last axis becomes one byte (most significant bit first), so the last
    dimension shrinks by a factor of 8 (rounded up, zero-padded).
    Not differentiable — intended for inference-time compression only.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Pack the sign bits of ``x`` into a uint8 tensor on ``x``'s device."""
        # `x >= 0` already yields the boolean mask; the original
        # np.where(arr >= 0, True, False) was a redundant np.where and
        # materialized an extra float copy on the way.
        bits = (x >= 0).cpu().numpy()
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
|
|
|
|
class FlexibleQuantizer(Module):
    """Sentence-transformers module applying a selectable quantization scheme
    to the "sentence_embedding" entry of the feature dict."""

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()
        self._packed_binary_quantizer = PackedBinaryQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["int8", "binary", "ubinary"] = "int8",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return the dict.

        Raises:
            ValueError: If ``quantization`` is not a supported scheme.
        """
        # Dispatch table keeps the three schemes symmetric.
        quantizers = {
            "int8": self._int8_quantizer,
            "binary": self._binary_quantizer,
            "ubinary": self._packed_binary_quantizer,
        }
        if quantization not in quantizers:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
            )
        features["sentence_embedding"] = quantizers[quantization](
            features["sentence_embedding"]
        )
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: nothing on disk to restore, so a fresh instance suffices.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless module: nothing to persist.
        return
|
|