File size: 2,663 Bytes
69524d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
from torch import nn
import constants as cst
class BiN(nn.Module):
    """Normalize an input along its temporal and feature axes and blend both.

    The input is standardized twice: once along the last axis (temporal) and
    once along the middle axis (feature). Each standardized tensor gets its
    own learnable scale and shift, and the two results are combined with the
    learnable scalar weights ``y1`` (feature branch) and ``y2`` (temporal
    branch).

    Args:
        d1: Size of the feature axis (axis 1 of the input).
        t1: Size of the temporal axis (axis 2 of the input).
    """

    # Standard deviations below this threshold are treated as 1 so that
    # (near-)constant slices do not produce inf/NaN after the division.
    _STD_FLOOR = 1e-4

    # Value a mixing weight is reset to when training drives it below zero.
    _WEIGHT_RESET = 0.01

    def __init__(self, d1: int, t1: int) -> None:
        super().__init__()
        self.t1 = t1
        self.d1 = d1

        # Shift (B1) and scale (l1) of the feature-axis normalization;
        # one value per temporal position.
        self.B1 = nn.Parameter(torch.empty(t1, 1))
        nn.init.constant_(self.B1, 0)
        self.l1 = nn.Parameter(torch.empty(t1, 1))
        nn.init.xavier_normal_(self.l1)

        # Shift (B2) and scale (l2) of the temporal-axis normalization;
        # one value per feature.
        self.B2 = nn.Parameter(torch.empty(d1, 1))
        nn.init.constant_(self.B2, 0)
        self.l2 = nn.Parameter(torch.empty(d1, 1))
        nn.init.xavier_normal_(self.l2)

        # Scalar mixing weights of the two branches, both starting at 0.5.
        self.y1 = nn.Parameter(torch.empty(1))
        nn.init.constant_(self.y1, 0.5)
        self.y2 = nn.Parameter(torch.empty(1))
        nn.init.constant_(self.y2, 0.5)

    @classmethod
    def _reset_if_negative(cls, weight: torch.Tensor) -> None:
        """Overwrite a negative scalar weight in place with a small positive value.

        The update is done in place (and only when actually needed) so the
        parameter object stays the same: an optimizer created earlier keeps
        tracking it, and it stays on whatever device it already lives on.
        """
        if weight[0] < 0:
            with torch.no_grad():
                weight.fill_(cls._WEIGHT_RESET)

    @classmethod
    def _safe_std(cls, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Return the std of ``x`` along ``dim`` (kept as a size-1 axis).

        Values below the floor are replaced by 1. ``torch.where`` is used
        instead of an in-place masked write so autograd is not disturbed.
        """
        std = torch.std(x, dim=dim, keepdim=True)
        return torch.where(std < cls._STD_FLOOR, torch.ones_like(std), std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two normalizations and return their weighted sum.

        Args:
            x: Tensor of shape ``(batch, d1, t1)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # If a mixing weight became negative, reset it to a small positive value.
        self._reset_if_negative(self.y1)
        self._reset_if_negative(self.y2)

        # Normalization along the temporal dimension (axis 2).
        # l2 / B2 have shape (d1, 1) and broadcast over batch and time.
        temporal_mean = x.mean(dim=2, keepdim=True)
        temporal_std = self._safe_std(x, dim=2)
        Z2 = (x - temporal_mean) / temporal_std
        X2 = self.l2 * Z2 + self.B2

        # Normalization along the feature dimension (axis 1).
        # l1.T / B1.T have shape (1, t1) and broadcast over batch and features.
        feature_mean = x.mean(dim=1, keepdim=True)
        feature_std = self._safe_std(x, dim=1)
        Z1 = (x - feature_mean) / feature_std
        X1 = self.l1.T * Z1 + self.B1.T

        # Weigh the importance of the feature and temporal normalizations.
        return self.y1 * X1 + self.y2 * X2