import torch
import torch.nn as nn
import torch.nn.functional as F

from model.decoder.interaction.base_interaction import BaseInteraction


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Shared linear transformation W and single-layer attention mechanism a,
        # both Xavier-initialized.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        # Project node features: (B, N, in_features) -> (B, N, out_features).
        h = torch.matmul(input, self.W)
        B, N = h.size()[0], h.size()[1]

        # Build all pairwise concatenations [h_i || h_j] -> (B, N, N, 2 * out_features),
        # then score each pair with the attention vector a.
        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(
            B, N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        # Mask non-edges with a large negative value so they vanish after the softmax.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
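
# Illustrative shape check for GraphAttentionLayer (not part of the original module);
# the sizes below are arbitrary and chosen only for demonstration.
if __name__ == "__main__":
    _layer = GraphAttentionLayer(in_features=16, out_features=8, dropout=0.1, alpha=0.2)
    _h = torch.rand(2, 5, 16)   # (batch, num_nodes, in_features)
    _adj = torch.ones(2, 5, 5)  # fully connected graph for each sample
    print(_layer(_h, _adj).shape)  # torch.Size([2, 5, 8])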


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nlayers = nlayers
        self.nheads = nheads

        # First layer: nheads attention heads whose outputs are concatenated.
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
                           for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Optional middle layers (used when nlayers > 2), each again with nheads heads.
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                for j in range(self.nheads):
                    self.add_module('attention_{}_{}'.format(i + 1, j),
                                    GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha,
                                                        concat=True))

        # Output layer: a single head mapping the concatenated heads back to nclass features.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        input = x
        # Multi-head attention: concatenate the heads along the feature dimension.
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                temp = []
                x = F.dropout(x, self.dropout, training=self.training)
                cur_input = x
                for j in range(self.nheads):
                    temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
                # Residual connection around each middle layer.
                x = torch.cat(temp, dim=2) + cur_input
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        # Residual connection from the layer input; requires nclass == nfeat.
        return x + input
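
# Illustrative usage sketch for GAT (not part of the original module); the hyper-parameters
# are arbitrary assumptions. Note that nclass must equal nfeat for the final residual connection.
if __name__ == "__main__":
    _gat = GAT(nfeat=16, nhid=8, nclass=16, dropout=0.1, alpha=0.2, nheads=4, nlayers=2)
    _x = torch.rand(2, 5, 16)   # (batch, num_nodes, nfeat)
    _adj = torch.ones(2, 5, 5)  # dense adjacency, one graph per sample
    print(_gat(_x, _adj).shape)  # torch.Size([2, 5, 16])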


def normalize_adj(mx):
    """
    Row-normalize matrix D^{-1}A
    torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
    """
    mx = mx.float()
    # Row sums give each node's degree; guard against division by zero for isolated nodes.
    rowsum = mx.sum(2)
    r_inv = torch.pow(rowsum, -1)
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag_embed(r_inv, 0)
    mx = r_mat_inv.matmul(mx)
    return mx
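
# Illustrative example for normalize_adj (not part of the original module): every row of the
# normalized adjacency sums to 1, e.g. [1, 1, 0] becomes [0.5, 0.5, 0.0].
if __name__ == "__main__":
    _a = torch.tensor([[[1., 1., 0.],
                        [0., 1., 1.],
                        [1., 1., 1.]]])
    print(normalize_adj(_a))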


class AGIFInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        # Learnable embedding for each intent label, later appended as extra graph nodes.
        self.intent_embedding = nn.Parameter(
            torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"]))
        nn.init.normal_(self.intent_embedding.data)
        self.adj = None
        self.graph = GAT(
            config["output_dim"],
            config["hidden_dim"],
            config["output_dim"],
            config["dropout_rate"],
            config["alpha"],
            config["num_heads"],
            config["num_layers"])

    def generate_adj_gat(self, index, batch, intent_label_num):
        # Node 0 corresponds to the encoder hidden state; intent nodes are shifted by one.
        # `index` holds (sentence_id, intent_id) pairs of predicted intents.
        intent_idx_ = [[torch.tensor(0)] for _ in range(batch)]
        for item in index:
            intent_idx_[item[0]].append(item[1] + 1)
        intent_idx = intent_idx_
        # Start from self-loops, then fully connect the hidden-state node and the predicted intents.
        self.adj = torch.cat([torch.eye(intent_label_num + 1).unsqueeze(0) for _ in range(batch)])
        for i in range(batch):
            for j in intent_idx[i]:
                self.adj[i, j, intent_idx[i]] = 1.
        if self.config["row_normalized"]:
            self.adj = normalize_adj(self.adj)
        self.adj = self.adj.to(self.intent_embedding.device)

    def forward(self, encode_hidden, **interaction_args):
        # Build the per-sentence adjacency matrices once per batch; `sent_id` selects
        # the matrix for the current sentence.
        if self.adj is None or interaction_args["sent_id"] == 0:
            self.generate_adj_gat(interaction_args["intent_index"], interaction_args["batch_size"],
                                  interaction_args["intent_label_num"])
        # Append the intent embeddings as extra nodes behind the encoder hidden states.
        lstm_out = torch.cat((encode_hidden,
                              self.intent_embedding.unsqueeze(0).repeat(encode_hidden.shape[0], 1, 1)), dim=1)
        return self.graph(lstm_out, self.adj[interaction_args["sent_id"]])
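
# Illustrative usage sketch for AGIFInteraction (not part of the original module). It assumes
# that BaseInteraction stores the keyword arguments as self.config; all values below are
# arbitrary and chosen only to exercise the tensor shapes.
if __name__ == "__main__":
    _interaction = AGIFInteraction(intent_label_num=4, intent_embedding_dim=16, output_dim=16,
                                   hidden_dim=8, dropout_rate=0.1, alpha=0.2, num_heads=4,
                                   num_layers=2, row_normalized=True)
    _hidden = torch.rand(2, 1, 16)  # one decoding step for each of two sentences
    _out = _interaction(_hidden, sent_id=0, intent_index=[(0, 1), (1, 2)],
                        batch_size=2, intent_label_num=4)
    print(_out.shape)  # torch.Size([2, 5, 16])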