| from typing import Any, Optional | |
| import torch | |
| from torch import nn | |
| from torch.ao.quantization import QConfig | |
# Public names exported by this module.
__all__ = [
    "QuantStub",
    "DeQuantStub",
    "QuantWrapper",
]
class QuantStub(nn.Module):
    r"""Placeholder marking the point where a float tensor gets quantized.

    The module itself is an identity: ``forward`` returns its input unchanged.
    Before calibration it plays the role of an observation point for the
    input tensor, and it is swapped for ``nnq.Quantize`` in ``convert``.

    Args:
        qconfig: quantization configuration for the tensor. If it is not
            provided (``None``), no ``qconfig`` attribute is set on this
            module and the qconfig is taken from parent modules instead.
    """

    def __init__(self, qconfig: Optional[QConfig] = None):
        super().__init__()
        # Only set the attribute when a qconfig is actually given, so that an
        # absent qconfig can be picked up from the parent modules later.
        # Explicit `is not None` instead of a truthiness test on an Optional.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity until this stub is replaced during `convert`.
        return x
class DeQuantStub(nn.Module):
    r"""Placeholder marking the point where a quantized tensor gets dequantized.

    Before calibration this module is an identity: ``forward`` returns its
    input unchanged. It is swapped for ``nnq.DeQuantize`` in ``convert``.

    Args:
        qconfig: quantization configuration for the tensor. If it is not
            provided (``None``), no ``qconfig`` attribute is set on this
            module and the qconfig is taken from parent modules instead.
    """

    # Annotated as Optional[QConfig] to match QuantStub: both stubs receive
    # the same qconfig object (e.g. from QuantWrapper).
    def __init__(self, qconfig: Optional[QConfig] = None):
        super().__init__()
        # Only set the attribute when a qconfig is actually given, so that an
        # absent qconfig can be picked up from the parent modules later.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity until this stub is replaced during `convert`.
        return x
class QuantWrapper(nn.Module):
    r"""Wrap a module so its input is quantized and its output dequantized.

    ``forward`` runs ``quant`` -> wrapped ``module`` -> ``dequant``. The
    ``quantization`` utility functions use this wrapper to add the quant and
    dequant modules around a module. Before the ``convert`` function runs,
    ``QuantStub`` just acts as an observer of the input tensor. After
    ``convert``, it is swapped for ``nnq.Quantize``, which performs the
    actual quantization. ``DeQuantStub`` is treated the same way.

    The wrapped module's ``qconfig`` attribute (if present) is handed to both
    stubs, and the wrapper adopts the wrapped module's training mode.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module: nn.Module):
        super().__init__()
        shared_qconfig = getattr(module, "qconfig", None)
        # nn.Module.__setattr__ registers these as child modules; the
        # registration order (quant, dequant, module) is preserved.
        self.quant = QuantStub(shared_qconfig)
        self.dequant = DeQuantStub(shared_qconfig)
        self.module = module
        # Propagate the wrapped module's mode to the wrapper and both stubs.
        self.train(module.training)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        quantized_in = self.quant(X)
        wrapped_out = self.module(quantized_in)
        return self.dequant(wrapped_out)