# pgps-demo/model/decoder/rnn_decoder.py
import torch
import torch.nn as nn
from torch.nn import functional as F

from model.module import *
from utils import *


class DecoderRNN(nn.Module):
    """GRU decoder with attention over the encoder outputs. Token ids below
    var_start use a learned embedding table; variable tokens at or above
    var_start are embedded with the encoder states at their text positions."""

    def __init__(self, cfg, tgt_lang):
        super(DecoderRNN, self).__init__()
        # token id layout: ids in [0, var_start) are non-variable tokens,
        # ids at or above var_start refer to problem variables
        self.var_start = tgt_lang.var_start  # = spe_num + midvar_num + const_num + op_num
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]
# Define layers
self.em_dropout = nn.Dropout(cfg.dropout_rate)
self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        self.gru = nn.GRU(input_size=cfg.decoder_hidden_size + cfg.decoder_embedding_size,
                          hidden_size=cfg.decoder_hidden_size,
                          num_layers=cfg.decoder_layers,
                          dropout=cfg.dropout_rate,
                          batch_first=True)
        # attention and candidate-scoring layers
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size + cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constants
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).to(self.device)  # 1 x no_var_size
        self.cfg = cfg

def get_var_encoder_outputs(self, encoder_outputs, var_pos):
"""
Arguments:
encoder_outputs: B x S1 x H
var_pos: B x S3
Returns:
var_embeddings: B x S3 x H
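        Example:
            var_pos[b] = [4, 7] gathers encoder_outputs[b, 4] and
            encoder_outputs[b, 7] as sample b's two variable embeddings.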
"""
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)  # B x S3 x H
        var_embeddings = encoder_outputs.gather(dim=1, index=expand_var_pos)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_src, var_pos, len_var,
                text_tgt=None, is_train=False):
"""
Arguments:
encoder_outputs: B x S1 x H
problem_output: layer_num x B x H
len_src: B
text_tgt: B x S2
var_pos: B x S3
len_var: B
Return:
training: logits, B x S x (no_var_size+var_size)
testing: exp_id, B x candi_size(beam_size) x exp_len
"""
self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_pos) # B x S3 x H
self.src_mask = sequence_mask(len_src) # B x S1
self.candi_mask = sequence_mask(self.var_start + len_var) # B x (no_var_size + var_size)
if is_train:
return self._forward_train(encoder_outputs, problem_output, text_tgt)
else:
return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_tgt):
all_seq_outputs = []
batch_size = encoder_outputs.size(0)
# initial hidden input of RNN
rnn_hidden = problem_output
# input embedding
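        # ids < var_start are looked up in the learned table, ids >= var_start
        # select the encoder state of the referenced variable; clamp keeps both
        # lookups in range and torch.where picks the right branch per token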
        tgt_novar_id = torch.clamp(text_tgt, max=self.var_start - 1)  # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id)  # B x S2 x H
        tgt_var_id = torch.clamp(text_tgt - self.var_start, min=0)  # B x S2
        var_embeddings = self.embedding_var.gather(
            dim=1, index=tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # B x S2 x H
        choose_mask = (text_tgt < self.var_start).unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)
        embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings)  # B x S2 x H
embedding_all_ = self.em_dropout(embedding_all)
        # candidate embeddings scored against at every step: fixed vocabulary
        # first, then this batch's variable embeddings
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(batch_size, 1))  # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1)  # B x (no_var_size + var_size) x H
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)
for t in range(text_tgt.size(1)-1):
# Calculate attention from current RNN state and all encoder outputs;
# apply to encoder outputs to get weighted average
current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1)) # B x 1 x H
            attn_weights = self.attn(current_hiddens, encoder_outputs, self.src_mask)  # B x S1
            context = attn_weights.unsqueeze(1).bmm(encoder_outputs)  # B x 1 x H
# Get current hidden state from input word and last hidden state
rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_[:, t:t+1, :], context), 2), rnn_hidden)
# rnn_output: B x 1 x H
# rnn_hidden: num_layers x B x H
            current_fusion_emb = torch.cat((rnn_output, context), 2)  # B x 1 x 2H
            current_fusion_emb_ = self.em_dropout(current_fusion_emb)
            candi_score = self.score(current_fusion_emb_, embedding_weight_all_,
                                     self.candi_mask)  # B x (no_var_size + var_size)
            all_seq_outputs.append(candi_score)
        all_seq_outputs = torch.stack(all_seq_outputs, dim=1)  # B x (S2-1) x (no_var_size + var_size)
        return all_seq_outputs


    def _forward_test(self, encoder_outputs, problem_output):
"""
Decode with beam search algorithm
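        Finished hypotheses keep competing with live beams by accumulated
        log-probability until no live beam remains or max_output_len is reached.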
"""
exp_outputs = []
batch_size = encoder_outputs.size(0)
for sample_id in range(batch_size):
            # expand this sample's tensors to beam_size copies
            rem_size = self.cfg.beam_size  # number of live (unfinished) beams
            encoder_output = encoder_outputs[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S1 x H
            src_mask = self.src_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1))  # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1)  # beam_size x (no_var_size + var_size) x H
            embedding_weight_all_ = self.em_dropout(embedding_weight_all)
            candi_mask = self.candi_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x (no_var_size + var_size)
candi_exp_output = []
candi_score_output = []
for i in range(self.cfg.max_output_len):
                # initialize / advance the beam state
                if i == 0:
                    input_token = torch.LongTensor([[self.sos_id]] * rem_size).to(self.device)  # rem_size x 1
                    rnn_hidden = problem_output[:, sample_id:sample_id+1].repeat(1, rem_size, 1)  # layer_num x rem_size x H
                    current_score = torch.FloatTensor([[0.0]] * rem_size).to(self.device)  # rem_size x 1
                    current_exp_list = [[] for _ in range(rem_size)]
                else:
                    input_token = torch.LongTensor(token_list).unsqueeze(1).to(self.device)  # rem_size x 1
                    rnn_hidden = rnn_hidden[:, cand_list]  # reorder hidden states to the surviving beams
                    rem_size = len(exp_list)
                    current_score = torch.FloatTensor(score_list).unsqueeze(1).to(self.device)  # rem_size x 1
                    current_exp_list = exp_list
                # input embedding (same lookup scheme as in training)
                tgt_novar_id = torch.clamp(input_token, max=self.var_start - 1)  # rem_size x 1
                novar_embedding = self.embedding_tgt(tgt_novar_id)  # rem_size x 1 x H
                tgt_var_id = torch.clamp(input_token - self.var_start, min=0)  # rem_size x 1
                var_embeddings = embedding_var[:rem_size].gather(
                    dim=1, index=tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # rem_size x 1 x H
                choose_mask = (input_token < self.var_start).unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)  # rem_size x 1 x H
                embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings)  # rem_size x 1 x H
                embedding_all_ = self.em_dropout(embedding_all)
# attention
current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1)) # rem_size x 1 x H
attn_weights = self.attn(current_hiddens, encoder_output[:rem_size], src_mask[:rem_size]) # rem_size x S1
context = attn_weights.unsqueeze(1).bmm(encoder_output[:rem_size]) # rem_size x 1 x H
# Get current hidden state from input word and last hidden state
rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_, context), 2), rnn_hidden)
# rnn_output: rem_size x 1 x H
# rnn_hidden: num_layers x rem_size x H
                current_fusion_emb = torch.cat((rnn_output, context), 2)  # rem_size x 1 x 2H
                current_fusion_emb_ = self.em_dropout(current_fusion_emb)
                candi_score = self.score(current_fusion_emb_, embedding_weight_all_[:rem_size],
                                         candi_mask[:rem_size])  # rem_size x (no_var_size + var_size)
                if i == 0:
                    # all beams are identical at step 0, so keep a single row
                    new_score = F.log_softmax(candi_score, dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score, dim=1) + current_score
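                # merge this step's expansions with already-finished hypotheses
                # (marked with index -1) and keep the overall top beam_size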
                cand_tup_list = [(score, idx) for idx, score in enumerate(new_score.view(-1).tolist())]
                cand_tup_list += [(score, -1) for score in candi_score_output]
                cand_tup_list.sort(key=lambda x: x[0], reverse=True)
token_list = []
cand_list = []
exp_list = []
score_list = []
                for tv, ti in cand_tup_list[:self.cfg.beam_size]:
                    if ti != -1:
                        x = ti // candi_score.size(-1)  # beam being extended
                        y = ti % candi_score.size(-1)   # predicted token id
                        if y != self.eos_id:
                            token_list.append(y)
                            cand_list.append(x)
                            exp_list.append(current_exp_list[x] + [y])
                            score_list.append(tv)
                        else:
                            # beam emitted [EOS]: record the finished hypothesis
                            candi_exp_output.append(current_exp_list[x])
                            candi_score_output.append(float(tv))
                if len(token_list) == 0:
                    # no live beams remain; stop early
                    break
            if len(candi_exp_output) > 0:
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output[:self.cfg.beam_size]))
            else:
                exp_outputs.append([])
        return exp_outputs
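

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative; not part of the original pipeline).
# The SimpleNamespace cfg and lang below are hypothetical stand-ins for the
# repo's real config and target Lang objects; Attn, Score and sequence_mask
# must still be importable from model.module / utils. The layers above only
# line up when encoder_hidden_size, decoder_hidden_size and
# decoder_embedding_size are equal, so a single size is used for all three.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(dropout_rate=0.1, decoder_embedding_size=64,
                          decoder_hidden_size=64, encoder_hidden_size=64,
                          decoder_layers=2, beam_size=3, max_output_len=10)
    lang = SimpleNamespace(var_start=20, word2index={"[SOS]": 1, "[EOS]": 2})
    decoder = DecoderRNN(cfg, lang)
    decoder.eval()  # disable dropout for deterministic beam search

    B, S1, S3 = 2, 12, 4  # batch size, source length, variables per sample
    encoder_outputs = torch.randn(B, S1, cfg.encoder_hidden_size).to(decoder.device)
    problem_output = torch.randn(cfg.decoder_layers, B, cfg.decoder_hidden_size).to(decoder.device)
    len_src = torch.tensor([12, 10]).to(decoder.device)
    var_pos = torch.randint(0, 10, (B, S3)).to(decoder.device)
    len_var = torch.tensor([4, 3]).to(decoder.device)

    # beam-search decoding path (is_train defaults to False)
    with torch.no_grad():
        exp_outputs = decoder(encoder_outputs, problem_output, len_src, var_pos, len_var)
    print([len(candidates) for candidates in exp_outputs])  # candidates per sample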