from typing import Optional
import torch
import torch.nn.functional as F

from fp12 import to_fp12, fp12_to_fp16, FP12_MAX


def get_param(data: torch.Tensor):
    """Encode a tensor into a frozen (exponent, fraction) FP12 parameter pair."""
    if data.abs().max() >= FP12_MAX:
        print('[WARN] max(abs(data)) >= FP12_MAX')

    exp, frac = to_fp12(data)

    # The encoded halves never receive gradients.
    exp = torch.nn.Parameter(exp, requires_grad=False)
    frac = torch.nn.Parameter(frac, requires_grad=False)

    return exp, frac


class Linear(torch.nn.Module):
    """Drop-in replacement for ``torch.nn.Linear`` with FP12-encoded weights."""

    def __init__(self, base: torch.nn.Linear) -> None:
        super().__init__()
        # Store weight/bias as (exp, frac) pairs and remember the original
        # shapes so the decoded tensors can be restored in forward().
        self.weight = get_param(base.weight)
        self.weight_shape = base.weight.shape
        if base.bias is not None:
            self.bias = get_param(base.bias)
            self.bias_shape = base.bias.shape
        else:
            self.bias = None
            self.bias_shape = None
        self.to(base.weight.device)
    
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Decode FP12 -> FP16 on the fly and fall back to the standard linear op.
        weight = fp12_to_fp16(*self.weight).reshape(self.weight_shape)
        bias = fp12_to_fp16(*self.bias).reshape(self.bias_shape) if self.bias is not None else None
        return F.linear(x, weight, bias)

    def _apply(self, fn, recurse=True):
        # The (exp, frac) pairs are plain attributes rather than registered
        # parameters, so fn (e.g. from .to()/.cuda()) must be applied manually.
        super()._apply(fn, recurse)
        self.weight = [fn(p) for p in self.weight]
        if self.bias is not None:
            self.bias = [fn(p) for p in self.bias]
        return self


class Conv2d(torch.nn.Module):
    """Drop-in replacement for ``torch.nn.Conv2d`` with FP12-encoded weights."""

    def __init__(self, base: torch.nn.Conv2d):
        super().__init__()
        self.weight = get_param(base.weight)
        self.weight_shape = base.weight.shape
        if base.bias is not None:
            self.bias = get_param(base.bias)
            self.bias_shape = base.bias.shape
        else:
            self.bias = None
            self.bias_shape = None

        # Convolution hyper-parameters are copied verbatim from the base layer.
        self.padding_mode = base.padding_mode
        self._reversed_padding_repeated_twice = base._reversed_padding_repeated_twice
        self.stride = base.stride
        self.dilation = base.dilation
        self.groups = base.groups
        self.padding = base.padding
        
        self.to(base.weight.device)
    
    def _conv_forward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        # Mirrors torch.nn.Conv2d._conv_forward: non-zero padding modes are
        # handled with an explicit F.pad before the convolution.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Decode FP12 -> FP16 on the fly and run the standard convolution.
        weight = fp12_to_fp16(*self.weight).reshape(self.weight_shape)
        bias = fp12_to_fp16(*self.bias).reshape(self.bias_shape) if self.bias is not None else None
        return self._conv_forward(x, weight, bias)
    
    def _apply(self, fn, recurse=True):
        # Same as Linear._apply: the (exp, frac) pairs are plain attributes,
        # so fn must be applied to them manually.
        super()._apply(fn, recurse)
        self.weight = [fn(p) for p in self.weight]
        if self.bias is not None:
            self.bias = [fn(p) for p in self.bias]
        return self
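

# Minimal usage sketch (not part of the original module): wrap an existing
# fp32 layer and run it on fp16 inputs. Assumes `fp12.to_fp12` /
# `fp12.fp12_to_fp16` round-trip as they are used above, accept tensors on the
# chosen device, and that the device supports fp16 matmul (prefer CUDA if CPU
# half precision is unavailable in your PyTorch build).
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    base = torch.nn.Linear(16, 8).to(device)
    layer = Linear(base)

    x = torch.randn(2, 16, dtype=torch.float16, device=device)
    y = layer(x)
    print(y.shape)  # expected: torch.Size([2, 8])

    # Rough FP12 reconstruction error against the original fp32 layer.
    print((y.float() - base(x.float())).abs().max().item())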