# Source: LCM/utils/nn/modules/linear.py
# Author: shivrajanand -- "Upload folder using huggingface_hub" (commit e8f4897, verified)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class BiLinear(nn.Module):
    '''
    Bi-linear layer.

    For every output feature ``k`` it computes
    ``x_l^T U[k] x_r + W_l[k] . x_l + W_r[k] . x_r + bias[k]``:
    a full bilinear interaction between the two inputs, plus a separate
    linear term for each input, plus an optional additive bias.
    '''

    def __init__(self, left_features, right_features, out_features, bias=True):
        '''
        Args:
            left_features: size of left input
            right_features: size of right input
            out_features: size of output
            bias: If set to False, the layer will not learn an additive bias.
                Default: True
        '''
        super(BiLinear, self).__init__()
        self.left_features = left_features
        self.right_features = right_features
        self.out_features = out_features

        # U: [out, left, right] -- the weight layout F.bilinear expects.
        self.U = Parameter(torch.empty(self.out_features, self.left_features, self.right_features))
        # W_l / W_r: [out, in] -- the weight layout F.linear expects.
        self.W_l = Parameter(torch.empty(self.out_features, self.left_features))
        # W_r is applied to the right input, so its input dimension must be
        # right_features (sizing it with left_features broke forward() whenever
        # the two input sizes differed).
        self.W_r = Parameter(torch.empty(self.out_features, self.right_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        '''Initialise weights with Xavier-uniform and the bias (if present) with zeros.'''
        nn.init.xavier_uniform_(self.W_l)
        nn.init.xavier_uniform_(self.W_r)
        # bias is None when the layer was built with bias=False; initialising
        # it unconditionally would crash the constructor.
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)
        nn.init.xavier_uniform_(self.U)

    def forward(self, input_left, input_right):
        '''
        Args:
            input_left: Tensor
                the left input tensor with shape = [batch1, batch2, ..., left_features]
            input_right: Tensor
                the right input tensor with shape = [batch1, batch2, ..., right_features]
        Returns: Tensor
            the output tensor with shape = [batch1, batch2, ..., out_features]
        Raises:
            AssertionError: if the leading (batch) dimensions of the two inputs differ.
        '''
        left_size = input_left.size()
        right_size = input_right.size()
        assert left_size[:-1] == right_size[:-1], \
            "batch size of left and right inputs mis-match: (%s, %s)" % (left_size[:-1], right_size[:-1])
        batch_size = int(np.prod(left_size[:-1]))

        # Flatten to matrices [batch_size, left_features] / [batch_size, right_features].
        # reshape (not view) so non-contiguous inputs such as transposed tensors
        # are accepted; the explicit batch size still rejects wrong feature sizes.
        input_left = input_left.reshape(batch_size, self.left_features)
        input_right = input_right.reshape(batch_size, self.right_features)

        # output: [batch_size, out_features]
        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.W_l, None) + F.linear(input_right, self.W_r, None)

        # Restore the leading dimensions: [batch1, batch2, ..., out_features].
        return output.view(left_size[:-1] + (self.out_features, ))

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'in1_features=' + str(self.left_features) \
            + ', in2_features=' + str(self.right_features) \
            + ', out_features=' + str(self.out_features) + ')'