import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class TextEncoderBiGRUCo(nn.Module):
    """Bidirectional GRU text encoder that fuses word and POS embeddings
    into a single sentence-level code."""

    def __init__(self, word_size: int, pos_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()

        # Project POS one-hot vectors into the word-embedding space, then map
        # the fused embeddings into the GRU input space.
        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(
            hidden_size, hidden_size, batch_first=True, bidirectional=True
        )
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
        )

        self.hidden_size = hidden_size
        # Learnable initial hidden state, one per GRU direction; nn.Parameter
        # enables gradients by default, and the state is tiled across the
        # batch in forward().
        self.hidden = nn.Parameter(torch.randn(2, 1, self.hidden_size))

    def forward(self, word_embs: torch.Tensor, pos_onehot: torch.Tensor,
                cap_lens: torch.Tensor) -> torch.Tensor:
        num_samples = word_embs.shape[0]

        # Fuse word embeddings with the projected POS embeddings.
        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        # Pack the padded batch; pack_padded_sequence expects the captions to be
        # sorted by length in descending order (enforce_sorted defaults to True).
        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)

        # gru_last has shape (2, batch, hidden): the final hidden state of each
        # direction. Concatenate them and project to the output code.
        _, gru_last = self.gru(emb, hidden)
        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output_net(gru_last)
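

# A minimal usage sketch: the sizes below are hypothetical, and the captions
# are assumed to be pre-sorted by length in descending order, as
# pack_padded_sequence requires with enforce_sorted left at True.
if __name__ == "__main__":
    word_size, pos_size, hidden_size, output_size = 300, 15, 512, 512
    encoder = TextEncoderBiGRUCo(word_size, pos_size, hidden_size, output_size)

    batch, max_len = 4, 20
    word_embs = torch.randn(batch, max_len, word_size)
    pos_onehot = nn.functional.one_hot(
        torch.randint(0, pos_size, (batch, max_len)), pos_size
    ).float()
    cap_lens = torch.tensor([20, 18, 15, 12])  # already sorted, descending

    text_code = encoder(word_embs, pos_onehot, cap_lens)
    print(text_code.shape)  # torch.Size([4, 512])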