|
|
|
|
|
|
|
import torch
import torch.nn as nn

from model.module.bilinear import Bilinear
|
|
|
|
|
class Biaffine(nn.Module):
    """Biaffine scorer: a bilinear term plus a separate linear term per input.

    ``forward(x, y)`` returns ``bilinear(x, y) + linear_1(x) + linear_2(y)``
    (with broadcasting axes inserted for the linear terms). The two linear
    projections have no bias; the only bias term, if any, is the one owned
    by the ``Bilinear`` sub-module.

    Args:
        input_dim: Feature size of both ``x`` and ``y``.
        output_dim: Number of scores produced per (x, y) pair.
        bias: Passed to ``Bilinear``; whether it carries a bias term.
        bias_init: Optional initial value for the bilinear bias. It is
            copied into the existing parameter (converted to that
            parameter's dtype/device and broadcast to its shape), so the
            module never aliases the caller's tensor.

    Raises:
        ValueError: If ``bias_init`` is given but the bilinear layer has no
            bias (e.g. ``bias=False``).
    """

    def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
        super().__init__()

        # Per-input linear terms, deliberately bias-free.
        self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
        self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)

        self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)

        if bias_init is not None:
            bilinear_bias = getattr(self.bilinear, "bias", None)
            if bilinear_bias is None:
                raise ValueError(
                    "bias_init was given but the bilinear layer has no bias "
                    "(was bias=False?)"
                )
            # Copy in place rather than rebinding ``.data``: this keeps the
            # parameter's shape/dtype/device and avoids sharing storage with
            # the caller's tensor (which optimizer steps would then mutate).
            with torch.no_grad():
                bilinear_bias.copy_(torch.as_tensor(bias_init))

    def forward(self, x, y):
        """Return biaffine scores for ``x`` and ``y``.

        NOTE(review): the ``unsqueeze`` axes assume ``x`` is
        (batch, len_x, input_dim) and ``y`` is (batch, len_y, input_dim),
        so the linear terms broadcast against a pairwise score tensor of
        shape (batch, len_x, len_y, output_dim) — confirm against
        ``Bilinear``'s output shape.
        """
        pair_scores = self.bilinear(x, y)
        # x-term gets a singleton axis at position 2, y-term at position 1,
        # so each broadcasts along the other input's length axis.
        x_scores = self.linear_1(x).unsqueeze(2)
        y_scores = self.linear_2(y).unsqueeze(1)
        return pair_scores + x_scores + y_scores
|
|