ssa-perin / model /module /biaffine.py
larkkin's picture
Add code and readme
c45d283
raw
history blame
No virus
677 Bytes
#!/usr/bin/env python3
# coding=utf-8
import torch.nn as nn
from model.module.bilinear import Bilinear
class Biaffine(nn.Module):
def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
super(Biaffine, self).__init__()
self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)
if bias_init is not None:
self.bilinear.bias.data = bias_init
def forward(self, x, y):
return self.bilinear(x, y) + self.linear_1(x).unsqueeze(2) + self.linear_2(y).unsqueeze(1)