#!/usr/bin/env python3
# coding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.module.biaffine import Biaffine


class EdgeClassifier(nn.Module):
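    """Edge scorer combining an optional presence head and an optional label head.

    The presence head produces one biaffine score per ordered node pair and the
    label head one score per label per pair; when `initialize` is set, the output
    biases start at the log-odds of the empirical frequencies taken from `dataset`.
    """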
    def __init__(self, dataset, args, initialize: bool, presence: bool, label: bool):
        super(EdgeClassifier, self).__init__()

        self.presence = presence
        if self.presence:
            if initialize:
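                # start the presence bias at the log-odds (logit) of the empirical edge frequency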
                presence_init = torch.tensor([dataset.edge_presence_freq])
                presence_init = (presence_init / (1.0 - presence_init)).log()
            else:
                presence_init = None

            self.edge_presence = EdgeBiaffine(
                args.hidden_size, args.hidden_size_edge_presence, 1, args.dropout_edge_presence, bias_init=presence_init
            )

        self.label = label
        if self.label:
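            # per-label biases start at the log-odds of each label's empirical frequency (when initializing)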
            label_init = (dataset.edge_label_freqs / (1.0 - dataset.edge_label_freqs)).log() if initialize else None
            n_labels = len(dataset.edge_label_field.vocab)
            self.edge_label = EdgeBiaffine(
                args.hidden_size, args.hidden_size_edge_label, n_labels, args.dropout_edge_label, bias_init=label_init
            )

    def forward(self, x):
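        """Score node states `x` of shape (B, T, H); disabled heads yield None."""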
        presence, label = None, None

        if self.presence:
            presence = self.edge_presence(x).squeeze(-1)  # shape: (B, T, T)
        if self.label:
            label = self.edge_label(x)  # shape: (B, T, T, n_labels)

        return presence, label


class EdgeBiaffine(nn.Module):
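    """Biaffine scoring head over pairs of node states.

    A shared bottleneck projection splits each node state into a "predecessor"
    and a "current" view, which are combined biaffinely into a score for every
    ordered node pair and every output dimension.
    """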
    def __init__(self, hidden_dim, bottleneck_dim, output_dim, dropout, bias_init=None):
        super(EdgeBiaffine, self).__init__()
        self.hidden = nn.Linear(hidden_dim, 2 * bottleneck_dim)
        self.output = Biaffine(bottleneck_dim, output_dim, bias_init=bias_init)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(F.elu(self.hidden(x)))  # shape: (B, T, 2H)
        predecessors, current = x.chunk(2, dim=-1)  # shape: (B, T, H), (B, T, H)
        edge = self.output(current, predecessors)  # shape: (B, T, T, O)
        return edge
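

if __name__ == "__main__":
    # Minimal smoke test (a hypothetical sketch, not part of the original
    # training pipeline). With presence=True, label=False and initialize=False,
    # the `dataset` argument is never accessed, so None suffices here; the
    # hyperparameter values below are arbitrary placeholders.
    from argparse import Namespace

    args = Namespace(
        hidden_size=8,
        hidden_size_edge_presence=4,
        dropout_edge_presence=0.1,
    )
    classifier = EdgeClassifier(None, args, initialize=False, presence=True, label=False)

    x = torch.randn(2, 5, args.hidden_size)  # (batch, nodes, hidden)
    presence_scores, label_scores = classifier(x)
    print(presence_scores.shape)  # expected: torch.Size([2, 5, 5])
    print(label_scores)           # None, since the label head is disabled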