import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from .utils import generate_length_mask, init, PositionalEncoding


class BaseDecoder(nn.Module):
    """
    Takes word/audio embeddings and outputs next-word probabilities.
    Abstract base class that cannot be called directly; all decoders
    should inherit from it.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim,
                 attn_emb_dim, dropout=0.2):
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim
        self.word_embedding = nn.Embedding(vocab_size, emb_dim)
        self.in_dropout = nn.Dropout(dropout)

    def forward(self, x):
        raise NotImplementedError

    def load_word_embedding(self, weight, freeze=True):
        embedding = np.load(weight)
        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
        # `from_pretrained` expects a tensor, so convert the loaded array first
        embedding = torch.as_tensor(embedding)
        self.word_embedding = nn.Embedding.from_pretrained(embedding,
                                                           freeze=freeze)


class RnnDecoder(BaseDecoder):
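    """
    Common base for the RNN decoders below: reads the RNN configuration
    (type, number of layers, bidirectionality) from kwargs and builds
    the output classifier.
    """
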
    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout)
        self.d_model = d_model
        self.num_layers = kwargs.get('num_layers', 1)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.classifier = nn.Linear(
            self.d_model * (self.bidirectional + 1), vocab_size)

    def forward(self, x):
        raise NotImplementedError

    def init_hidden(self, bs, device):
        num_dire = self.bidirectional + 1
        n_layer = self.num_layers
        hid_dim = self.d_model
        if self.rnn_type == "LSTM":
            return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
                    torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
        else:
            return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)


class RnnFcDecoder(RnnDecoder):
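    """
    RNN decoder conditioned on the global (fc) embedding only: the
    projected fc embedding is concatenated to the word embedding at
    every decoding step.
    """
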
    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 2,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        # repeat the projected fc embedding along the time axis so it can
        # be concatenated with the word embedding at every step
        p_fc_emb = self.fc_proj(fc_emb)
        p_fc_emb = p_fc_emb.unsqueeze(1).expand(-1, embed.size(1), -1)
        embed = torch.cat((embed, p_fc_emb), dim=-1)

        out, state = self.model(embed, state)

        logits = self.classifier(out)
        output = {
            "state": state,
            "embeds": out,
            "logits": logits
        }

        return output


class Seq2SeqAttention(nn.Module):
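    """
    Additive (Bahdanau) attention: the decoder hidden state is
    concatenated with every encoder timestep, passed through a tanh
    layer and scored against a learned vector v.
    """
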
    def __init__(self, hs_enc, hs_dec, attn_size):
        """
        Args:
            hs_enc: encoder hidden size
            hs_dec: decoder hidden size
            attn_size: attention vector size
        """
        super().__init__()
        self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
        self.v = nn.Parameter(torch.randn(attn_size))
        self.apply(init)

    def forward(self, h_dec, h_enc, src_lens):
        """
        Args:
            h_dec: decoder hidden (query), [N, hs_dec]
            h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
            src_lens: source (encoder memory) lengths, [N,]
        """
        N = h_enc.size(0)
        src_max_len = h_enc.size(1)
        h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1)

        attn_input = torch.cat((h_dec, h_enc), dim=-1)
        attn_out = torch.tanh(self.h2attn(attn_input))

        v = self.v.repeat(N, 1).unsqueeze(1)
        score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1)

        # build the length mask on the same device as the encoder memory
        # to avoid CPU/GPU mismatches when src_lens lives on the GPU
        idxs = torch.arange(src_max_len, device=h_enc.device).unsqueeze(0).expand(N, -1)
        mask = idxs < src_lens.to(h_enc.device).view(-1, 1)

        # padded positions get a large negative score before the softmax
        score = score.masked_fill(mask == 0, -1e10)
        weights = torch.softmax(score, dim=-1)
        ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1)

        return ctx, weights


class AttentionProj(nn.Module):
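    """
    Additive attention variant that first projects the query and the
    key/value memory into a shared embedding space before scoring.
    """
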
    def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
        super().__init__()  # must be called before registering submodules
        self.q_proj = nn.Linear(hs_dec, embed_dim)
        self.kv_proj = nn.Linear(hs_enc, embed_dim)
        self.h2attn = nn.Linear(embed_dim * 2, attn_size)
        self.v = nn.Parameter(torch.randn(attn_size))
        self.apply(init)

    # note: self.apply(init) in __init__ resolves to the init imported
    # from .utils, not to this method
    def init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, h_dec, h_enc, src_lens):
        """
        Args:
            h_dec: decoder hidden (query), [N, hs_dec]
            h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
            src_lens: source (encoder memory) lengths, [N,]
        """
        h_enc = self.kv_proj(h_enc)
        h_dec = self.q_proj(h_dec)
        N = h_enc.size(0)
        src_max_len = h_enc.size(1)
        h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1)

        attn_input = torch.cat((h_dec, h_enc), dim=-1)
        attn_out = torch.tanh(self.h2attn(attn_input))

        v = self.v.repeat(N, 1).unsqueeze(1)
        score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1)

        # same device-safe length mask as in Seq2SeqAttention
        idxs = torch.arange(src_max_len, device=h_enc.device).unsqueeze(0).expand(N, -1)
        mask = idxs < src_lens.to(h_enc.device).view(-1, 1)

        score = score.masked_fill(mask == 0, -1e10)
        weights = torch.softmax(score, dim=-1)
        ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1)

        return ctx, weights


class BahAttnDecoder(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Concatenates the word embedding, the attended context and the
        fc embedding to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        # the attention query is the decoder hidden state of all layers and
        # directions, flattened to [N, d_model * num_directions * num_layers]
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        # query the attention with the current decoder hidden state
        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_fc_emb = self.fc_proj(fc_emb)
        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
                              dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


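# A minimal single-step usage sketch for the attention decoders above
# (the shapes and hyperparameters are illustrative assumptions, not
# values taken from the training configuration):
#
#   decoder = BahAttnDecoder(emb_dim=256, vocab_size=5000, fc_emb_dim=512,
#                            attn_emb_dim=512, dropout=0.2, d_model=512)
#   output = decoder({
#       "word": torch.full((4, 1), 1, dtype=torch.long),  # previous token ids
#       "fc_emb": torch.randn(4, 512),                    # global audio embedding
#       "attn_emb": torch.randn(4, 100, 512),             # per-frame embeddings
#       "attn_emb_len": torch.full((4,), 100),            # valid frame counts
#   })
#   # output["logit"]: [4, 1, 5000]; output["state"] feeds the next step

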
class BahAttnDecoder2(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Adds the word embedding, the attended context and the fc
        embedding together to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        # the memory is projected to emb_dim first, so the attention here
        # operates on emb_dim-sized keys/values
        self.attn = Seq2SeqAttention(self.emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
        self.apply(partial(init, method="xavier"))

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))
        p_attn_emb = self.attn_proj(attn_emb)

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)

        p_fc_emb = self.fc_proj(fc_emb)
        rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class ConditionalBahAttnDecoder(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Concatenates the word embedding, the attended context and a
        condition embedding to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
        # two condition classes; a soft condition in [0, 1] interpolates
        # between their embeddings
        self.condition_embedding = nn.Embedding(2, emb_dim)
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        condition = input_dict["condition"]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        # interpolate between the two condition embeddings with
        # weights [1 - c, c]
        condition = torch.as_tensor(condition, dtype=torch.float,
                                    device=fc_emb.device)
        condition = torch.stack((1 - condition, condition), dim=1)
        condition_emb = torch.matmul(condition, self.condition_embedding.weight)

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
                              dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class StructBahAttnDecoder(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
                 attn_emb_dim, dropout, d_model, **kwargs):
        """
        Concatenates the word embedding, the attended context and a
        structure embedding to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
        self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        structure = input_dict["structure"]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        structure = structure.to(fc_emb.device)
        struct_emb = self.struct_embedding(structure)

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)),
                              dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class StyleBahAttnDecoder(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Concatenates the word embedding, the attended context and a
        style embedding to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        style = input_dict["style"]  # expected shape: [N, emb_dim]

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
                              dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class BahAttnDecoder3(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Concatenates the word embedding and the (unprojected) attended
        context to form the RNN input; the fc embedding can be passed
        in place of a word as the first-step input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim + attn_emb_dim,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.ctx_proj = nn.Identity()  # context is used without projection
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]

        # the first decoding step may receive the fc embedding itself
        # instead of a word index
        if word.size(-1) == self.fc_emb_dim:
            embed = word.unsqueeze(1)
        elif word.size(-1) == 1:
            word = word.to(fc_emb.device)
            embed = self.in_dropout(self.word_embedding(word))
        else:
            raise Exception(f"problem with word input size {word.size()}")

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class SpecificityBahAttnDecoder(RnnDecoder):

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        Concatenates the word embedding, the (unprojected) attended
        context and a scalar specificity condition to form the RNN input.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim + attn_emb_dim + 1,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        self.attn = Seq2SeqAttention(self.attn_emb_dim,
                                     self.d_model * (self.bidirectional + 1) *
                                     self.num_layers,
                                     attn_size)
        self.ctx_proj = nn.Identity()  # context is used without projection
        self.apply(init)

    def forward(self, input_dict):
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        condition = input_dict["condition"]  # one scalar per sample

        word = word.to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_ctx = self.ctx_proj(c)
        rnn_input = torch.cat(
            (embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
            dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output


class TransformerDecoder(BaseDecoder):
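    """
    Standard transformer decoder over the encoder memory. The transformer
    modules used here expect sequence-first inputs, so embeddings are
    transposed to [T, N, E] before decoding.
    """
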
    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout)
        self.d_model = emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size)
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_square_subsequent_mask(self, max_length):
        # lower-triangular causal mask: 0.0 on allowed positions,
        # -inf on future positions
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask
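
    # e.g. generate_square_subsequent_mask(3) yields
    # [[0., -inf, -inf],
    #  [0.,   0., -inf],
    #  [0.,   0.,   0.]]
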
    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        # project the encoder memory and switch to sequence-first layout
        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        # scale the embedding by sqrt(d), as in the original transformer
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)
        embed = embed.transpose(0, 1)
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        # key padding masks mark positions to ignore with True, hence the inversion
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class EventTransformerDecoder(TransformerDecoder):
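    """
    TransformerDecoder variant that adds an event embedding to the
    (sequence-first) word embeddings before positional encoding.
    """
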
    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        event_emb = input_dict["event"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)

        embed = embed.transpose(0, 1)
        # inject the event embedding before the positional encoding
        embed = embed + event_emb
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output


class KeywordProbTransformerDecoder(TransformerDecoder):
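    """
    TransformerDecoder variant that projects a keyword probability vector
    into the model dimension and fuses it with the word embeddings.
    """
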
    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, keyword_classes_num, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, **kwargs)
        self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
        self.word_keyword_norm = nn.LayerNorm(self.d_model)

    def forward(self, input_dict):
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]
        keyword = input_dict["keyword"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)

        embed = embed.transpose(0, 1)
        # fuse the projected keyword probabilities with the word embedding
        embed = embed + self.keyword_proj(keyword)
        embed = self.word_keyword_norm(embed)

        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(
            attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output