###########################################################################
# NLP demo software by HyperbeeAI.                                        #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai #
###########################################################################
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai"
print("imported layers.py")
print(license_statement)
print("")

import torch, sys
import torch.nn as nn
import numpy as np
from torch.autograd import Function
from functions import quantization, clamping_hw, linear_functional

class ai85_base(nn.Module):
    def __init__(
            self,
            operation_module  = None,
            operation_fcnl    = None,
            activation_module = None,
            output_width_30b  = False
    ):
        super().__init__()
        self.op               = operation_module
        self.op_fcn           = operation_fcnl
        self.act              = activation_module
        self.wide             = output_width_30b
        self.quantize_Q_d_8b    = None
        self.quantize_Q_u_wb    = None
        self.quantize_Q_d_wide  = None
        self.clamp_C_hw_8b      = None
        self.clamp_C_hw_wide    = None
        self.output_shift        = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) 
        self.weight_bits         = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False) 
        self.bias_bits           = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False) 
        self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) 
        self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) 
        self.shift_quantile      = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
        weight_bits      = self.weight_bits
        bias_bits        = self.bias_bits
        shift_quantile   = self.shift_quantile
        self.configure_layer_base( weight_bits, bias_bits, shift_quantile )

    def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
        self.quantize_Q_d_8b    = quantization(xb = 8,           mode ='down', wide=False) # 8 is the number of activation bits
        self.quantize_Q_u_wb    = quantization(xb = weight_bits, mode ='up',   wide=False)
        self.quantize_Q_d_wide  = quantization(xb = 8,           mode ='down', wide=True)  # 8 is the number of activation bits, but it's wide; see the implementation
        self.clamp_C_hw_8b      = clamping_hw(xb = 8,    wide=False) # 8 is the number of activation bits
        self.clamp_C_hw_wide    = clamping_hw(xb = None, wide=True)  # None to avoid a misleading bit count; see the implementation
        self.weight_bits     = nn.Parameter(torch.Tensor([ weight_bits    ]), requires_grad=False)
        self.bias_bits       = nn.Parameter(torch.Tensor([ bias_bits      ]), requires_grad=False)
        self.shift_quantile  = nn.Parameter(torch.Tensor([ shift_quantile ]), requires_grad=False)

    def forward(self, x):
        w = self.op.weight
        b = self.op.bias
        los = self.output_shift
        s_o = 2**los
        w_q = self.quantize_Q_u_wb(w)
        b_q = self.quantize_Q_u_wb(b)

        x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding) # convolution / linear
        x = x * s_o
        if self.act is not None:
            x = self.act(x)
        if self.wide and (self.act is None):
            x = self.quantize_Q_d_wide(x)
            x = self.clamp_C_hw_wide(x)
            ### In wide mode the chip adds 5 fractional bits to the result; divide by 2**5 here to
            ### bring the value back into range. Note: ai8x-training defers this scaling to the
            ### synthesis/deployment phase (and uses a +1 bit there); the reason is unclear.
            ### Combined with the two divisions below, the overall scale is 1/(2**5 * 128 * 2) = 1/8192.
            x = x / (2**5)  # simulation of chip behavior
            x = x / 128.0   # ours, for convenience; done off-chip since it is the step before table lookup
            x = x / 2.0     # ours, for convenience; done off-chip since it is the step before table lookup
        else:
            x = self.quantize_Q_d_8b(x)
            x = self.clamp_C_hw_8b(x)

        return x

class ai85_conv1d(ai85_base):
    def __init__(
            self,
            C_in_channels      = None,
            D_out_channels     = None,
            K_kernel_dimension = None,
            padding            = 0,
            activation         = None,
            output_width_30b   = False,
    ):

        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print("Unsupported activation type in model; only 'relu' is accepted. Exiting.")
            sys.exit()

        operation_mdl  = nn.Conv1d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension, stride=1, padding=padding, bias=True)
        operation_fcn  = nn.functional.conv1d

        super().__init__(
            activation_module  = activation_fcn,
            operation_module   = operation_mdl,
            operation_fcnl     = operation_fcn,
            output_width_30b   = output_width_30b,
        )

class ai85_add(nn.Module):
    def __init__(self ):
        super().__init__()
        self.clamp_C_hw_8b    = clamping_hw( xb = 8, wide=False) # 8 here is activation bits

    def forward(self, x, res):
        x = self.clamp_C_hw_8b(x+res)
        return x

class ai85_fullyconnected(ai85_base):
    def __init__(
            self,
            in_features        = None,
            out_features       = None,
            activation         = None,    
            output_width_30b   = False):
 
        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print("Unsupported activation type in model; only 'relu' is accepted. Exiting.")
            sys.exit()

        operation_mdl  = nn.Linear(in_features, out_features, bias=True)
        operation_fcn  = linear_functional

        super().__init__(
            activation_module  = activation_fcn,
            operation_module   = operation_mdl,
            operation_fcnl     = operation_fcn,
            output_width_30b   = output_width_30b
        )
        # Define dummy stride/padding attributes so that Linear and Conv1d are call-compatible
        # inside ai85_base.forward ("op" refers to self.op in the ai85_base superclass).
        self.op.stride  = None
        self.op.padding = None

class lpre(nn.Module):
    def __init__(self):
        super().__init__()
        self.ee1 = nn.Embedding(16384, 64)  # token embedding: 16384-entry vocabulary, 64 dims
        self.ee2 = nn.Embedding(48, 64)     # positional embedding: up to 48 positions, 64 dims
        self.quantize  = quantization(xb = 8, mode ='updown', wide=False)

    def forward(self, x, sp1, sp2, sb):
        # Position indices sp1..sp2-1, repeated for each of the sb sequences in the batch
        pp = torch.arange(sp1, sp2).unsqueeze(0).repeat(sb, 1).to(x.device)
        ee2_d = self.ee2(pp)  # positional embedding
        ee1_d = self.ee1(x)   # token embedding
        ed = ee1_d + ee2_d
        # Normalize the summed embeddings to roughly [-1, 1) using worst-case bounds from the weights
        min_w = self.ee2.weight.data.min() + self.ee1.weight.data.min()
        max_w = self.ee2.weight.data.max() + self.ee1.weight.data.max()
        t = (ed - min_w) / (max_w - min_w)
        t = t.add(-0.5).mul(2.0)
        t = self.quantize(t)
        t = t.clamp(min=-1.0, max=1.0 - (1.0 / 128.0))
        # Map to the signed 8-bit integer range [-128, 127]
        t = t.mul(2**(8-1)).add(0.5).floor().clamp(min=-128, max=127)
        # Return in channel-first layout: (batch, embedding_dim, sequence_length)
        return t.permute(0, 2, 1)
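
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes the companion `functions` module is importable, that its
# quantization/clamping helpers are shape-preserving, and that
# linear_functional applies the weight over the last input dimension like
# torch.nn.functional.linear. Batch size, sequence length and channel counts
# below are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len = 2, 48  # seq_len must not exceed the 48 positions of lpre's positional table

    # Token + positional embedding front-end:
    # (batch, seq_len) token ids -> (batch, 64, seq_len) quantized activations.
    embed  = lpre()
    tokens = torch.randint(0, 16384, (batch, seq_len))
    e = embed(tokens, 0, seq_len, batch)

    # A small quantized conv / fully-connected stack on top of the embeddings.
    conv = ai85_conv1d(C_in_channels=64, D_out_channels=32,
                       K_kernel_dimension=3, padding=1, activation='relu')
    fc   = ai85_fullyconnected(in_features=32, out_features=8)

    y = conv(e)                 # (batch, 32, seq_len)
    y = fc(y.permute(0, 2, 1))  # per-position fully-connected: (batch, seq_len, 8)
    print("output shape:", tuple(y.shape))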