from typing import Optional

import torch
import torch.nn.functional as F

from fp12 import to_fp12, fp12_to_fp16, FP12_MAX


def get_param(data: torch.Tensor):
    """Encode a tensor as an (exponent, fraction) fp12 pair of frozen Parameters."""
    if data.abs().max() >= FP12_MAX:
        print('[WARN] max(abs(data)) >= FP12_MAX')
    exp, frac = to_fp12(data)
    # The packed components are storage only; gradients never flow through them.
    exp = torch.nn.Parameter(exp, requires_grad=False)
    frac = torch.nn.Parameter(frac, requires_grad=False)
    return exp, frac


class Linear(torch.nn.Module):
    """Drop-in replacement for torch.nn.Linear that stores its weights in fp12."""

    def __init__(self, base: torch.nn.Linear) -> None:
        super().__init__()
        self.weight = get_param(base.weight)
        self.weight_shape = base.weight.shape
        if base.bias is not None:
            self.bias = get_param(base.bias)
            self.bias_shape = base.bias.shape
        else:
            self.bias = None
            self.bias_shape = None
        self.to(base.weight.device)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Decode fp12 -> fp16 on the fly, then run the standard linear op.
        weight = fp12_to_fp16(*self.weight).reshape(self.weight_shape)
        bias = (fp12_to_fp16(*self.bias).reshape(self.bias_shape)
                if self.bias is not None else None)
        return F.linear(x, weight, bias)

    def _apply(self, fn, recurse=True):
        # The (exp, frac) pairs are plain attributes, not registered parameters,
        # so .to() / .cuda() / .half() must be forwarded to them by hand.
        super()._apply(fn, recurse)
        self.weight = [fn(p) for p in self.weight]
        if self.bias is not None:
            self.bias = [fn(p) for p in self.bias]
        return self


class Conv2d(torch.nn.Module):
    """Drop-in replacement for torch.nn.Conv2d that stores its weights in fp12."""

    def __init__(self, base: torch.nn.Conv2d) -> None:
        super().__init__()
        self.weight = get_param(base.weight)
        self.weight_shape = base.weight.shape
        if base.bias is not None:
            self.bias = get_param(base.bias)
            self.bias_shape = base.bias.shape
        else:
            self.bias = None
            self.bias_shape = None
        # Copy the convolution hyperparameters from the wrapped layer.
        self.padding_mode = base.padding_mode
        self._reversed_padding_repeated_twice = base._reversed_padding_repeated_twice
        self.stride = base.stride
        self.dilation = base.dilation
        self.groups = base.groups
        self.padding = base.padding
        self.to(base.weight.device)

    def _conv_forward(self, input: torch.Tensor, weight: torch.Tensor,
                      bias: Optional[torch.Tensor]):
        # Mirror torch.nn.Conv2d._conv_forward for non-'zeros' padding modes:
        # pad explicitly first, then convolve with zero padding.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice,
                                  mode=self.padding_mode),
                            weight, bias, self.stride, (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Decode fp12 -> fp16 on the fly, then run the standard convolution.
        weight = fp12_to_fp16(*self.weight).reshape(self.weight_shape)
        bias = (fp12_to_fp16(*self.bias).reshape(self.bias_shape)
                if self.bias is not None else None)
        return self._conv_forward(x, weight, bias)

    def _apply(self, fn, recurse=True):
        # Forward .to() / .cuda() / .half() to the unregistered fp12 components.
        super()._apply(fn, recurse)
        self.weight = [fn(p) for p in self.weight]
        if self.bias is not None:
            self.bias = [fn(p) for p in self.bias]
        return self
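
# A minimal usage sketch: recursively swap the stock layers of an existing
# model for the fp12-backed wrappers above. `convert_to_fp12` is an
# illustrative helper, not part of the fp12 package itself.
def convert_to_fp12(model: torch.nn.Module) -> torch.nn.Module:
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(model, name, Linear(child))
        elif isinstance(child, torch.nn.Conv2d):
            setattr(model, name, Conv2d(child))
        else:
            convert_to_fp12(child)  # recurse into containers / nested modules
    return model

# Intended call site (assuming `net` already holds the fp16 weights to pack):
#     net = convert_to_fp12(net)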