###########################################################################
# NLP demo software by HyperbeeAI. #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai #
###########################################################################
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai"
print("imported layers.py")
print(license_statement)
print("")
import torch, sys
import torch.nn as nn
import numpy as np
from torch.autograd import Function
from functions import quantization, clamping_hw, linear_functional
class ai85_base(nn.Module):
    def __init__(
        self,
        operation_module  = None,
        operation_fcnl    = None,
        activation_module = None,
        output_width_30b  = False
    ):
        super().__init__()
        self.op     = operation_module
        self.op_fcn = operation_fcnl
        self.act    = activation_module
        self.wide   = output_width_30b
        self.quantize_Q_d_8b   = None
        self.quantize_Q_u_wb   = None
        self.quantize_Q_d_wide = None
        self.clamp_C_hw_8b     = None
        self.clamp_C_hw_wide   = None
        self.output_shift        = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
        self.weight_bits         = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
        self.bias_bits           = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
        self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
        self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
        self.shift_quantile      = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
        weight_bits    = self.weight_bits
        bias_bits      = self.bias_bits
        shift_quantile = self.shift_quantile
        self.configure_layer_base(weight_bits, bias_bits, shift_quantile)

    def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
        self.quantize_Q_d_8b   = quantization(xb = 8,           mode='down', wide=False)  # 8 here is activation bits
        self.quantize_Q_u_wb   = quantization(xb = weight_bits, mode='up',   wide=False)
        self.quantize_Q_d_wide = quantization(xb = 8,           mode='down', wide=True)   # 8 here is activation bits, but it's wide, so check inside
        self.clamp_C_hw_8b     = clamping_hw(xb = 8,    wide=False)  # 8 here is activation bits
        self.clamp_C_hw_wide   = clamping_hw(xb = None, wide=True)   # None to avoid misleading info on the # of bits, check inside
        self.weight_bits    = nn.Parameter(torch.Tensor([ weight_bits ]),    requires_grad=False)
        self.bias_bits      = nn.Parameter(torch.Tensor([ bias_bits ]),      requires_grad=False)
        self.shift_quantile = nn.Parameter(torch.Tensor([ shift_quantile ]), requires_grad=False)

    def forward(self, x):
        w   = self.op.weight
        b   = self.op.bias
        los = self.output_shift
        s_o = 2**(los)
        w_q = self.quantize_Q_u_wb(w)
        b_q = self.quantize_Q_u_wb(b)
        x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding)  # convolution / linear
        x = x * s_o
        if self.act is not None:
            x = self.act(x)
        if self.wide and (self.act is None):
            x = self.quantize_Q_d_wide(x)
            x = self.clamp_C_hw_wide(x)
            ### The +5 here is the 5 fractional bits the chip adds to the number in wide mode.
            ### We divide the number back here to get it into range. ai8x-training does not do this
            ### until the synthesis/deployment phase for some reason, and there they do a +1 bit. Why?
            x = x / (2**5)   # this is a simulation of chip behavior
            x = x / 128.0    # this is ours, for convenience; this part is done outside the chip since it's the step before table lookup
            x = x / 2.0      # this is ours, for convenience; this part is done outside the chip since it's the step before table lookup
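            # Net effect of the three divisions above: x / (2**5 * 128 * 2) = x / 2**13,
            # which brings the wide accumulator back into roughly [-1, 1) ahead of the table lookup.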
        else:
            x = self.quantize_Q_d_8b(x)
            x = self.clamp_C_hw_8b(x)
        return x
class ai85_conv1d(ai85_base):
    def __init__(
        self,
        C_in_channels      = None,
        D_out_channels     = None,
        K_kernel_dimension = None,
        padding            = 0,
        activation         = None,
        output_width_30b   = False,
    ):
        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print('wrong activation type in model. only {relu} is acceptable. exiting')
            sys.exit()
        operation_mdl = nn.Conv1d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension, stride=1, padding=padding, bias=True)
        operation_fcn = nn.functional.conv1d
        super().__init__(
            activation_module = activation_fcn,
            operation_module  = operation_mdl,
            operation_fcnl    = operation_fcn,
            output_width_30b  = output_width_30b,
        )
class ai85_add(nn.Module):
    def __init__(self):
        super().__init__()
        self.clamp_C_hw_8b = clamping_hw(xb = 8, wide=False)  # 8 here is activation bits

    def forward(self, x, res):
        x = self.clamp_C_hw_8b(x + res)
        return x
class ai85_fullyconnected(ai85_base):
    def __init__(
        self,
        in_features      = None,
        out_features     = None,
        activation       = None,
        output_width_30b = False
    ):
        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print('wrong activation type in model. only {relu} is acceptable. exiting')
            sys.exit()
        operation_mdl = nn.Linear(in_features, out_features, bias=True)
        operation_fcn = linear_functional
        super().__init__(
            activation_module = activation_fcn,
            operation_module  = operation_mdl,
            operation_fcnl    = operation_fcn,
            output_width_30b  = output_width_30b
        )
        # Define dummy arguments to make Linear and conv compatible in ai85_base.
        # The name "op" here refers to self.op in the base class (ai85_base).
        self.op.stride  = None
        self.op.padding = None
class lpre(nn.Module):
    def __init__(self):
        super().__init__()
        self.ee1 = nn.Embedding(16384, 64)  # input (token) embedding: 16384 entries of width 64
        self.ee2 = nn.Embedding(48, 64)     # positional embedding: 48 entries of width 64
        self.quantize = quantization(xb = 8, mode='updown', wide=False)

    def forward(self, x, sp1, sp2, sb):
        # Position indices sp1..sp2, repeated for each of the sb sequences in the batch
        pp    = torch.arange(sp1, sp2).unsqueeze(0).repeat(sb, 1).to(x.device)
        ee2_d = self.ee2(pp)
        ee1_d = self.ee1(x)
        ed    = ee1_d + ee2_d
        # Normalize the summed embeddings into [-1, 1) using the worst-case weight range, then quantize
        min_w = self.ee2.weight.data.min() + self.ee1.weight.data.min()
        max_w = self.ee2.weight.data.max() + self.ee1.weight.data.max()
        t = (ed - min_w) / (max_w - min_w)
        t = t.add(-0.5).mul(2.0)
        t = self.quantize(t)
        t = t.clamp(min=-1.0, max=1.0 - (1.0 / 128.0))
        # Scale onto the signed 8-bit integer grid used by the hardware layers
        t = t.mul(2**(8 - 1)).add(0.5).floor().clamp(min=-128, max=127)
        return t.permute(0, 2, 1)  # (batch, sequence, channels) -> (batch, channels, sequence)
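
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how these layers might be instantiated and exercised, assuming the companion
# `functions` module (quantization, clamping_hw, linear_functional) behaves as
# its use above suggests. The shapes below (8 channels, sequence length 16,
# batch of 2) are illustrative assumptions, not values from a deployed model.
if __name__ == "__main__":
    x = torch.randn(2, 8, 16)  # (batch, channels, sequence), assumed shape

    conv = ai85_conv1d(C_in_channels=8, D_out_channels=12, K_kernel_dimension=3,
                       padding=1, activation='relu')
    y = conv(x)                # quantized conv1d + ReLU, expected on the 8-bit activation grid
    print("conv1d out:", y.shape)

    fc = ai85_fullyconnected(in_features=16, out_features=4, activation=None,
                             output_width_30b=True)
    z = fc(y[:, 0, :])         # wide (30-bit) output path, rescaled before table lookup
    print("fc out:", z.shape)

    add = ai85_add()           # residual add with hardware-style clamping
    print("add out:", add(y, y).shape)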