import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

|
class Graph:
    """The graph that models the skeleton.

    Args:
        strategy (string): must be one of the following candidates
            - uniform: Uniform Labeling
            - distance: Distance Partitioning
            - spatial: Spatial Configuration
        max_hop (int): the maximal distance between two connected nodes
        dilation (int): controls the spacing between the kernel points
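
    Example (illustrative; the default ``spatial`` strategy with
    ``max_hop=1`` yields one root partition plus centripetal/centrifugal
    partitions for hop 1)::

        >>> graph = Graph(strategy='spatial', max_hop=1)
        >>> graph.A.shape
        (3, 22, 22)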
|
    """
|
    def __init__(self,
                 strategy='spatial',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        self.get_edge()
        self.hop_dis = get_hop_distance(self.num_node,
                                        self.edge,
                                        max_hop=max_hop)
        self.get_adjacency(strategy)
|

    def __str__(self):
        # __str__ must return a str, so stringify the adjacency array.
        return str(self.A)
|

    def get_edge(self):
        # A fixed 22-joint skeleton; joint 0 serves as the center joint
        # for the spatial partitioning strategy.
        self.num_node = 22
        self_link = [(i, i) for i in range(self.num_node)]
        neighbor_link = [(1, 0), (2, 1), (3, 2), (4, 3), (5, 0), (6, 5),
                         (7, 6), (8, 7), (9, 0), (10, 9), (11, 10), (12, 11),
                         (13, 12), (14, 11), (15, 14), (16, 15), (17, 16),
                         (18, 11), (19, 18), (20, 19), (21, 20)]
        self.edge = self_link + neighbor_link
        self.center = 0
|

    def get_adjacency(self, strategy):
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = \
                    normalize_adjacency[self.hop_dis == hop]
            self.A = A
        elif strategy == 'spatial':
            # Spatial configuration partitioning: for each hop, split edges
            # into subsets by comparing the endpoints' distances to the
            # center joint.
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            d_j = self.hop_dis[j, self.center]
                            d_i = self.hop_dis[i, self.center]
                            if d_j == d_i:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif d_j > d_i:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
        else:
            raise ValueError(f"Unknown partition strategy: {strategy}")
|


def get_hop_distance(num_node, edge, max_hop=1):
    """Compute the hop (shortest-path) distance between every pair of nodes,
    up to max_hop, via boolean powers of the adjacency matrix."""
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[j, i] = 1
        A[i, j] = 1

    # arrive_mat[d] marks pairs connected by some walk of length d; writing
    # d from largest to smallest leaves the shortest distance in hop_dis.
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
    arrive_mat = (np.stack(transfer_mat) > 0)
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis
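
# Worked example (hypothetical 3-node path graph 0-1-2, not part of the
# skeleton above):
#
#     >>> get_hop_distance(3, [(0, 1), (1, 2)], max_hop=2)
#     array([[0., 1., 2.],
#            [1., 0., 1.],
#            [2., 1., 0.]])
#
# Pairs farther apart than max_hop keep the initial value of np.inf.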
|


def normalize_digraph(A):
    # Scale each column of A by the inverse of its column sum, so the
    # weights flowing into each node sum to one.
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1)
    AD = np.dot(A, Dn)
    return AD
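
# Worked example (hypothetical 2x2 adjacency): the column sums are [1., 2.],
# so each column is scaled by the inverse of its sum:
#
#     >>> normalize_digraph(np.array([[0., 1.], [1., 1.]]))
#     array([[0. , 0.5],
#            [1. , 0.5]])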
|


def normalize_undigraph(A):
    # Symmetric normalization D^(-1/2) A D^(-1/2) for undirected graphs.
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-0.5)
    DAD = np.dot(np.dot(Dn, A), Dn)
    return DAD
|


def zero(x):
    # Stand-in residual branch used when the residual connection is disabled.
    return 0


def iden(x):
    # Identity residual branch.
    return x
|


class ConvTemporalGraphical(nn.Module):
    r"""The basic module for applying a graph convolution.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
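
    Example (a minimal shape check; the random ``A`` below merely stands in
    for a real adjacency stack with :math:`K = 3` partitions)::

        >>> gcn = ConvTemporalGraphical(3, 64, kernel_size=3)
        >>> x = torch.randn(8, 3, 50, 22)    # (N, C, T, V)
        >>> A = torch.rand(3, 22, 22)        # (K, V, V)
        >>> out, _ = gcn(x, A)
        >>> out.shape
        torch.Size([8, 64, 50, 22])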
|
""" |
|
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        # One 2D convolution produces kernel_size * out_channels maps; the
        # temporal axis is convolved while the node axis is left untouched.
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels,
                              out_channels * kernel_size,
                              kernel_size=(t_kernel_size, 1),
                              padding=(t_padding, 0),
                              stride=(t_stride, 1),
                              dilation=(t_dilation, 1),
                              bias=bias)
|

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size

        x = self.conv(x)

        # Split the K * C output channels into K groups and contract each
        # group with its own adjacency matrix over the node dimension.
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        x = torch.einsum('nkctv,kvw->nctw', (x, A))

        return x.contiguous(), A
|


class st_gcn_block(nn.Module):
    r"""Applies a spatial temporal graph convolution over an input graph sequence.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (float, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
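
    Example (a minimal shape check; ``kernel_size=(9, 3)`` pairs a 9-frame
    temporal window with a 3-partition spatial kernel)::

        >>> block = st_gcn_block(3, 64, kernel_size=(9, 3))
        >>> x = torch.randn(8, 3, 50, 22)    # (N, C, T, V)
        >>> A = torch.rand(3, 22, 22)        # (K, V, V)
        >>> out, _ = block(x, A)
        >>> out.shape
        torch.Size([8, 64, 50, 22])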
|
    """
|
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        # The temporal kernel must be odd so that 'same' padding preserves
        # the sequence length (up to the stride).
        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        # Residual branch: disabled, identity, or a 1x1 convolution when the
        # channel count or temporal resolution changes.
        if not residual:
            self.residual = zero
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = iden
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)
|

    def forward(self, x, A):
        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res

        return self.relu(x), A
|


class ST_GCN_18(nn.Module):
    r"""Spatial temporal graph convolutional network backbone.

    Args:
        in_channels (int): Number of channels in the input data
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        data_bn (bool): If ``True``, applies batch normalization to the
            input data before the first block
        **kwargs (optional): Other parameters for graph convolution units

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, 512, 1, 1)`, the pooled feature map, where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instances in a frame.
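
    Example (a minimal shape check; a batch of 4 sequences, 50 frames,
    22 joints, and M=2 skeletons per frame)::

        >>> model = ST_GCN_18(in_channels=3)
        >>> x = torch.randn(4, 3, 50, 22, 2)    # (N, C, T, V, M)
        >>> feat = model(x)
        >>> feat.shape
        torch.Size([4, 512, 1, 1])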
|
""" |
|
    def __init__(self,
                 in_channels,
                 edge_importance_weighting=True,
                 data_bn=True,
                 **kwargs):
        super().__init__()

        # Build the skeleton graph and register its adjacency stack as a
        # buffer so it moves with the module but is never trained.
        self.graph = Graph()
        A = torch.tensor(self.graph.A,
                         dtype=torch.float32,
                         requires_grad=False)
        self.register_buffer('A', A)

        # The spatial kernel size K equals the number of adjacency partitions.
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels *
                                      A.size(1)) if data_bn else iden
        # The first block skips dropout.
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            st_gcn_block(in_channels,
                         64,
                         kernel_size,
                         1,
                         residual=False,
                         **kwargs0),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 128, kernel_size, 2, **kwargs),
            st_gcn_block(128, 128, kernel_size, 1, **kwargs),
            st_gcn_block(128, 128, kernel_size, 1, **kwargs),
            st_gcn_block(128, 256, kernel_size, 2, **kwargs),
            st_gcn_block(256, 256, kernel_size, 1, **kwargs),
            st_gcn_block(256, 512, kernel_size, 1, **kwargs),
        ))

        # Optionally learn one edge-importance mask per block.
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)
|

    def forward(self, x):
        # Fold the M skeletons into the batch and flatten (V, C) so the
        # 1D batch norm normalizes each (joint, channel) pair over time.
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # Alternate spatial graph convolution and temporal convolution,
        # re-weighting the adjacency with the learned edge importance.
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # Global average pooling over time and joints, then average the
        # features of the M skeletons in each sample.
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        return x
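

if __name__ == '__main__':
    # Minimal smoke test (illustrative values): push a random batch through
    # the backbone and check the pooled feature shape. The input layout is
    # (N, C, T, V, M) with V=22 matching the skeleton in Graph.get_edge.
    model = ST_GCN_18(in_channels=3)
    x = torch.randn(4, 3, 50, 22, 2)
    feat = model(x)
    print(feat.shape)  # expected: torch.Size([4, 512, 1, 1])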