from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import paddle.nn.functional as F
|
|
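# SDMGR head for key information extraction. The architecture appears to
# follow "Spatial Dual-Modality Graph Reasoning for Key Information
# Extraction" (arXiv:2103.14470); the shape conventions noted in the comments
# below are inferred from the code itself.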
class SDMGRHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 num_chars=92,
                 visual_dim=16,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 bidirectional=False):
        super().__init__()

        # Bilinear fusion of visual RoI features with textual node features.
        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        # Character embedding; index 0 serves as the padding index.
        self.node_embed = nn.Embedding(num_chars, node_input, 0)
        # Halve the hidden size for a bidirectional RNN so the output width
        # stays node_embed either way; the direction flag is passed through so
        # that bidirectional=True actually builds a bidirectional LSTM.
        hidden = node_embed // 2 if bidirectional else node_embed
        self.rnn = nn.LSTM(
            input_size=node_input,
            hidden_size=hidden,
            num_layers=1,
            direction='bidirectional' if bidirectional else 'forward')
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        self.gnn_layers = nn.LayerList(
            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)
|
    def forward(self, input, targets):
        # `targets` is accepted for interface compatibility but unused here.
        relations, texts, x = input

        # Count boxes per sample and valid (non -1-padded) characters per box.
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.shape[0])
            char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))

        # Pad every text to the longest character count in the batch, then
        # embed the characters and run them through the RNN.
        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = paddle.concat([
            paddle.concat(
                [text, paddle.zeros(
                    (text.shape[0], max_num - text.shape[1]))], -1)
            for text in texts
        ])
        temp = paddle.clip(all_nodes, min=0).astype(int)
        embed_nodes = self.node_embed(temp)
        rnn_nodes, _ = self.rnn(embed_nodes)

        # For each box with at least one valid character, pick the RNN output
        # at its last valid position (a gather implemented as one-hot multiply
        # and sum), and scatter it into the per-box node feature matrix.
        b, h, w = rnn_nodes.shape
        nodes = paddle.zeros([b, w])
        all_nums = paddle.concat(char_nums)
        valid = paddle.nonzero((all_nums > 0).astype(int))
        temp_all_nums = (
            paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
        temp_all_nums = paddle.expand(temp_all_nums, [
            temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
        ])
        temp_all_nodes = paddle.gather(rnn_nodes, valid)
        N, C, A = temp_all_nodes.shape
        one_hot = F.one_hot(
            temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
        one_hot = paddle.multiply(
            temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
        t = one_hot.expand([N, 1, A]).squeeze(1)
        nodes = paddle.scatter(nodes, valid.squeeze(1), t)

        # Fuse visual features into the textual node features when provided.
        if x is not None:
            nodes = self.fusion([x, nodes])

        # Flatten per-sample relation tensors into one [sum(n_i^2), edge_input]
        # matrix and embed the edges.
        all_edges = paddle.concat(
            [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
        embed_edges = self.edge_embed(all_edges.astype('float32'))
        embed_edges = F.normalize(embed_edges)

        # Graph reasoning over node and edge features.
        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls
|
|
class GNNLayer(nn.Layer):
    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        # Build all (source, target) node-pair features within each sample.
        start, cat_nodes = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            cat_nodes.append(
                paddle.concat([
                    paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
                    paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
                ], -1).reshape([num**2, -1]))
            start += num
        # Append edge features, project, and compute attention coefficients.
        cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        # Attention-weighted aggregation over neighbours; the -1e9 term on
        # the diagonal masks self-loops before the softmax.
        start, residuals = 0, []
        for num in nums:
            residual = F.softmax(
                -paddle.eye(num).unsqueeze(-1) * 1e9 +
                coefs[start:start + num**2].reshape([num, num, -1]), 1)
            residuals.append((residual * cat_nodes[start:start + num**2]
                              .reshape([num, num, -1])).sum(1))
            start += num**2

        # Residual update of node features; the pair features are also
        # returned for edge classification.
        nodes += self.relu(self.out_fc(paddle.concat(residuals)))
        return [nodes, cat_nodes]
|
|
class Block(nn.Layer):
    """Bilinear fusion in the style of BLOCK (Ben-Younes et al., AAAI 2019):
    project both inputs to a shared space, fuse them chunk-wise with low-rank
    bilinear interactions, then project to the output dimension."""

    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert pos_norm in ['before_cat', 'after_cat']
        self.pos_norm = pos_norm

        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        self.linear1 = (self.linear0
                        if shared else nn.Linear(input_dims[1], mm_dim))
        self.merge_linears0 = nn.LayerList()
        self.merge_linears1 = nn.LayerList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for size in self.chunks:
            ml0 = nn.Linear(size, size * rank)
            self.merge_linears0.append(ml0)
            ml1 = ml0 if shared else nn.Linear(size, size * rank)
            self.merge_linears1.append(ml1)
        self.linear_out = nn.Linear(mm_dim, output_dim)

    def forward(self, x):
        x0 = self.linear0(x[0])
        x1 = self.linear1(x[1])
        bs = x1.shape[0]
        if self.dropout_input > 0:
            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
        x0_chunks = paddle.split(x0, self.chunks, -1)
        x1_chunks = paddle.split(x1, self.chunks, -1)
        zs = []
        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
                                      self.merge_linears0,
                                      self.merge_linears1):
            # Low-rank bilinear interaction: the elementwise product of the
            # two rank-expanded projections is summed over the rank axis.
            m = m0(x0_c) * m1(x1_c)
            m = m.reshape([bs, self.rank, -1])
            z = paddle.sum(m, 1)
            if self.pos_norm == 'before_cat':
                # Signed square root followed by L2 normalization.
                z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
                z = F.normalize(z)
            zs.append(z)
        z = paddle.concat(zs, 1)
        if self.pos_norm == 'after_cat':
            z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
            z = F.normalize(z)

        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z

    def chunk_sizes(self, dim, chunks):
        # Split `dim` into `chunks` nearly equal sizes whose sum is exactly
        # `dim` (the last chunk absorbs the remainder).
        split_size = (dim + chunks - 1) // chunks
        sizes_list = [split_size] * chunks
        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
        return sizes_list
|
|
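# A minimal smoke test (not part of the original module), assuming the input
# conventions used above: `texts` holds per-sample [num_boxes, max_chars]
# float tensors padded with -1, `relations` holds [num_boxes, num_boxes, 5]
# spatial relation tensors, and the visual feature `x` may be None. Note that
# the padding scheme assumes the widest sample contains at least one fully
# unpadded text row.
if __name__ == "__main__":
    paddle.seed(0)
    texts = [
        paddle.to_tensor(
            [[5, 9, 2, -1], [7, 1, -1, -1], [3, -1, -1, -1]],
            dtype="float32"),
        paddle.to_tensor(
            [[4, 8, -1, -1], [6, 2, 9, 1]], dtype="float32"),
    ]
    relations = [paddle.rand([3, 3, 5]), paddle.rand([2, 2, 5])]
    head = SDMGRHead(in_channels=16)
    node_cls, edge_cls = head((relations, texts, None), targets=None)
    # Expect per-box class logits [5, 26] and per-pair edge logits [13, 2].
    print(node_cls.shape, edge_cls.shape)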