File size: 677 Bytes
c45d283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#!/usr/bin/env python3
# coding=utf-8

import torch.nn as nn
from model.module.bilinear import Bilinear


class Biaffine(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
        super(Biaffine, self).__init__()

        self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
        self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)

        self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)
        if bias_init is not None:
            self.bilinear.bias.data = bias_init

    def forward(self, x, y):
        return self.bilinear(x, y) + self.linear_1(x).unsqueeze(2) + self.linear_2(y).unsqueeze(1)