import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint


class MultiHeadAttention(nn.Module):
    """Multi-head attention over keys/values that are passed in already
    split per head, with shape (heads, batch, value_num, head_dim); only
    the query projection and the output projection live in this module."""

    def __init__(self, heads, hidden_dim):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % heads == 0
        self.heads = heads
        head_dim = hidden_dim // heads
        self.alpha = 1 / math.sqrt(head_dim)  # scaled dot-product factor
        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, K, V, mask):
        batch_size, query_num, hidden_dim = q.size()
        size = (self.heads, batch_size, query_num, -1)
        # Project queries for all heads in one batched matmul:
        # (batch*query, hidden) x (heads, hidden, head_dim)
        #   -> (heads, batch, query, head_dim)
        q = q.reshape(-1, hidden_dim)
        Q = torch.matmul(q, self.nn_Q).view(size)
        value_num = V.size(2)
        heads_batch = self.heads * batch_size
        Q = Q.view(heads_batch, query_num, -1)
        K = K.view(heads_batch, value_num, -1).transpose(1, 2)
        # Start from a tensor that is 0 at allowed positions and -inf at
        # masked positions, then add the scaled scores in place so that
        # masked entries stay at -inf.
        S = masked_tensor(mask, self.heads)
        S = S.view(heads_batch, query_num, value_num)
        S.baddbmm_(Q, K, alpha=self.alpha)
        S = S.view(self.heads, batch_size, query_num, value_num)
        S = F.softmax(S, dim=-1)
        # Weighted sum of values, then concatenate heads and project out.
        x = torch.matmul(S, V).permute(1, 2, 0, 3)
        x = x.reshape(batch_size, query_num, -1)
        x = torch.matmul(x, self.nn_O)
        return x


class Decode(nn.Module):
    """Single decoding step: an RNN cell update, optional multi-head
    attention over K/V, attention logits over the candidates X, and
    selection of one index per row."""

    def __init__(self, nn_args):
        super(Decode, self).__init__()
        self.nn_args = nn_args
        heads = nn_args['decode_atten_heads']
        hidden_dim = nn_args['decode_hidden_dim']
        self.heads = heads
        self.alpha = 1 / math.sqrt(hidden_dim)
        if heads > 0:
            assert hidden_dim % heads == 0
            head_dim = hidden_dim // heads
            # Per-head key/value projections; the query projection lives in
            # the MultiHeadAttention submodule.
            self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_mha = MultiHeadAttention(heads, hidden_dim)
        decode_rnn = nn_args.setdefault('decode_rnn', 'LSTM')
        assert decode_rnn in ('GRU', 'LSTM', 'NONE')
        if decode_rnn == 'GRU':
            self.nn_rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)
        elif decode_rnn == 'LSTM':
            self.nn_rnn_cell = nn.LSTMCell(hidden_dim, hidden_dim)
        else:
            self.nn_rnn_cell = None
        self.vars_dim = sum(nn_args['variable_dim'].values())
        if self.vars_dim > 0:
            atten_type = nn_args.setdefault('decode_atten_type', 'add')
            assert atten_type == 'add', \
                "must be additive attention when vars_dim > 0, got {}".format(atten_type)
            self.nn_A = nn.Parameter(torch.Tensor(self.vars_dim, hidden_dim))
            self.nn_B = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            atten_type = nn_args.setdefault('decode_atten_type', 'prod')
        if atten_type == 'add':
            self.nn_W = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            self.nn_W = None
        # Note: this re-initializes every parameter uniformly, including those
        # of the submodules (MultiHeadAttention and the RNN cell).
        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, X, K, V, query, state1, state2, varfeat, mask,
                chosen, sample_p, clip, mode, memopt=0):
        # Optionally trade compute for memory with gradient checkpointing;
        # higher memopt levels checkpoint more of the decode step.
        if self.training and memopt > 2:
            state1, state2 = checkpoint(self.rnn_step, query, state1, state2)
        else:
            state1, state2 = self.rnn_step(query, state1, state2)
        query = state1
        NP = X.size(0)
        NR = query.size(0) // NP
        batch_size = query.size(0)
        if self.heads > 0:
            query = query.view(NP, NR, -1)
            if self.training and memopt > 1:
                query = checkpoint(self.nn_mha, query, K, V, mask)
            else:
                query = self.nn_mha(query, K, V, mask)
            query = query.view(batch_size, -1)
        if self.nn_W is None:
            # Dot-product attention: batched query @ X^T, with masked
            # positions pre-filled with -inf.
            query = query.view(NP, NR, -1)
            logit = masked_tensor(mask, 1)
            logit = logit.view(NP, NR, -1)
            X = X.permute(0, 2, 1)
            logit.baddbmm_(query, X, alpha=self.alpha)
            logit = logit.view(batch_size, -1)
        else:
            # Additive attention, optionally with per-variable features.
            if self.training and self.vars_dim > 0 and memopt > 0:
                logit = checkpoint(self.atten, query, X, varfeat, mask)
            else:
                logit = self.atten(query, X, varfeat, mask)
        chosen_p = choose(logit, chosen, sample_p, clip, mode)
        return state1, state2, chosen_p

    def rnn_step(self, query, state1, state2):
        if isinstance(self.nn_rnn_cell, nn.GRUCell):
            state1 = self.nn_rnn_cell(query, state1)
        elif isinstance(self.nn_rnn_cell, nn.LSTMCell):
            state1, state2 = self.nn_rnn_cell(query, (state1, state2))
        return state1, state2

    def atten(self, query, keyvalue, varfeat, mask):
        if self.vars_dim > 0:
            varfeat = vfaddmm(varfeat, mask, self.nn_A, self.nn_B)
        # Delegates to the module-level additive attention function below.
        return atten(query, keyvalue, varfeat, mask, self.nn_W)


def choose(logit, chosen, sample_p, clip, mode):
    """Pick one index per row and return its log-probability.

    mode 0 keeps the provided `chosen` indices, mode 1 is greedy argmax,
    mode 2 samples from the clipped logits. `sample_p` is not used here.
    """
    mask = logit == -math.inf
    logit = torch.tanh(logit) * clip
    logit[mask] = -math.inf
    if mode == 0:
        pass
    elif mode == 1:
        chosen[:] = logit.argmax(1)
    elif mode == 2:
        p = logit.exp()
        chosen[:] = torch.multinomial(p, 1).squeeze(1)
    else:
        raise Exception('unknown mode: {}'.format(mode))
    logp = logit.log_softmax(1)
    logp = logp.gather(1, chosen[:, None])
    logp = logp.squeeze(1)
    return logp


def atten(query, keyvalue, varfeat, mask, weight):
    """Additive (Bahdanau-style) attention logits: w . tanh(K + vf + q)."""
    batch_size = query.size(0)
    NP, NK, ND = keyvalue.size()
    query = query.view(NP, -1, 1, ND)
    varfeat = varfeat.view(NP, -1, NK, ND)
    keyvalue = keyvalue[:, None, :, :]
    keyvalue = keyvalue + varfeat + query
    keyvalue = torch.tanh(keyvalue)
    keyvalue = keyvalue.view(-1, ND)
    logit = masked_tensor(mask, 1).view(-1)
    logit.addmv_(keyvalue, weight)
    return logit.view(batch_size, -1)


def masked_tensor(mask, heads):
    """Expand a boolean mask to `heads` copies and turn it into a float
    tensor that is 0 where attention is allowed and -inf where it is not."""
    size = list(mask.size())
    size.insert(0, heads)
    mask = mask[None].expand(size)
    result = mask.new_zeros(size, dtype=torch.float32)
    result[mask] = -math.inf
    return result


def vfaddmm(varfeat, mask, A, B):
    """Project per-variable features to hidden_dim (mask is unused here)."""
    varfeat = varfeat.permute(0, 2, 1)
    return F.linear(varfeat, A.permute(1, 0), B)
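

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The shapes below are
# assumptions inferred from the code above: MultiHeadAttention takes K and V
# already split per head with shape (heads, batch, value_num, head_dim), and a
# boolean mask of shape (batch, query_num, value_num) that is True where
# attention is forbidden. `choose` is shown with mode=1 (greedy argmax).
if __name__ == '__main__':
    heads, hidden_dim = 4, 32
    batch_size, query_num, value_num = 2, 3, 5
    head_dim = hidden_dim // heads

    mha = MultiHeadAttention(heads, hidden_dim)
    q = torch.randn(batch_size, query_num, hidden_dim)
    K = torch.randn(heads, batch_size, value_num, head_dim)
    V = torch.randn(heads, batch_size, value_num, head_dim)
    mask = torch.zeros(batch_size, query_num, value_num, dtype=torch.bool)

    out = mha(q, K, V, mask)
    print(out.shape)  # -> torch.Size([2, 3, 32])

    # Greedy selection over some logits; `sample_p` is unused by `choose`.
    logit = torch.randn(batch_size, value_num)
    chosen = torch.zeros(batch_size, dtype=torch.long)
    logp = choose(logit, chosen, sample_p=None, clip=10.0, mode=1)
    print(chosen, logp)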