|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from model.module.biaffine import Biaffine |
|
|
|
|
|
class AnchorClassifier(nn.Module):
    """Score label/token pairs with a biaffine layer.

    Both inputs are projected from ``args.hidden_size`` to
    ``args.hidden_size_anchor`` (linear -> ELU -> dropout) and then scored by
    a ``Biaffine`` layer constructed with output size 1, giving one logit per
    (label, token) pair after the trailing channel is squeezed away.
    """

    def __init__(self, dataset, args, initialize: bool, bias=True, mode="anchor"):
        """Build the projections and the biaffine scorer.

        Args:
            dataset: object exposing a ``f"{mode}_freq"`` attribute (e.g.
                ``anchor_freq``); it is only read when ``bias and initialize``.
                Presumably a positive-class frequency in [0, 1] — TODO confirm.
            args: config providing ``hidden_size``, ``hidden_size_anchor`` and
                ``dropout_anchor``.
            initialize: whether to initialise the output bias from the
                dataset frequency.
            bias: forwarded to ``Biaffine``.
            mode: prefix of the dataset attribute holding the frequency.
        """
        super().__init__()

        self.token_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
        self.label_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
        self.dropout = nn.Dropout(args.dropout_anchor)

        if bias and initialize:
            # Log-odds (inverse sigmoid) of the frequency; presumably so that an
            # untrained sigmoid output starts near that frequency.
            bias_init = self._prior_logit(getattr(dataset, f"{mode}_freq"))
        else:
            bias_init = None

        self.output = Biaffine(args.hidden_size_anchor, 1, bias=bias, bias_init=bias_init)

    @staticmethod
    def _prior_logit(freq, eps=1e-6):
        """Return ``log(freq / (1 - freq))`` as a 1-element tensor.

        ``freq`` is clamped to ``[eps, 1 - eps]`` first; a frequency of exactly
        0 or 1 would otherwise yield an infinite (non-finite) bias.
        """
        prior = torch.tensor([freq])
        if not prior.is_floating_point():
            # torch.logit requires a floating tensor; integers are promoted,
            # matching what the true division in the log-odds formula does.
            prior = prior.float()
        return torch.logit(prior, eps=eps)

    def forward(self, label, tokens, encoder_mask):
        """Return one logit per (label, token) pair.

        Args:
            label: label representations whose last dim is ``args.hidden_size``;
                presumably (batch, n_labels, hidden) — TODO confirm.
            tokens: token representations whose last dim is
                ``args.hidden_size``; presumably (batch, n_tokens, hidden).
            encoder_mask: boolean mask over tokens; positions where it is True
                are filled with ``-inf``. It is unsqueezed at dim 1, presumably
                to broadcast across the label dimension.

        Returns:
            The biaffine scores with the trailing size-1 channel removed and
            masked positions set to ``-inf``.
        """
        tokens = self.dropout(F.elu(self.token_f(tokens)))
        label = self.dropout(F.elu(self.label_f(label)))

        anchor = self.output(label, tokens).squeeze(-1)
        anchor = anchor.masked_fill(encoder_mask.unsqueeze(1), float("-inf"))
        return anchor
|
|