import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class Graph(): """ The Graph to model the skeletons Args: strategy (string): must be one of the follow candidates - uniform: Uniform Labeling - distance: Distance Partitioning - spatial: Spatial Configuration max_hop (int): the maximal distance between two connected nodes dilation (int): controls the spacing between the kernel points """ def __init__(self, strategy='spatial', max_hop=1, dilation=1): self.max_hop = max_hop self.dilation = dilation self.get_edge() self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop) self.get_adjacency(strategy) def __str__(self): return self.A def get_edge(self): # edge is a list of [child, parent] paris self.num_node = 22 self_link = [(i, i) for i in range(self.num_node)] neighbor_link = [(1,0), (2,1), (3,2), (4,3), (5,0), (6,5), (7,6), (8,7), (9,0), (10,9), (11,10), (12,11), \ (13,12), (14,11), (15,14), (16,15), (17,16), (18,11), (19,18), (20,19), (21,20)] self.edge = self_link + neighbor_link self.center = 0 def get_adjacency(self, strategy): valid_hop = range(0, self.max_hop + 1, self.dilation) adjacency = np.zeros((self.num_node, self.num_node)) for hop in valid_hop: adjacency[self.hop_dis == hop] = 1 normalize_adjacency = normalize_digraph(adjacency) if strategy == 'uniform': A = np.zeros((1, self.num_node, self.num_node)) A[0] = normalize_adjacency self.A = A elif strategy == 'distance': A = np.zeros((len(valid_hop), self.num_node, self.num_node)) for i, hop in enumerate(valid_hop): A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] self.A = A elif strategy == 'spatial': A = [] for hop in valid_hop: a_root = np.zeros((self.num_node, self.num_node)) a_close = np.zeros((self.num_node, self.num_node)) a_further = np.zeros((self.num_node, self.num_node)) for i in range(self.num_node): for j in range(self.num_node): if self.hop_dis[j, i] == hop: if self.hop_dis[j, self.center] == self.hop_dis[ i, self.center]: a_root[j, i] = normalize_adjacency[j, i] elif self.hop_dis[j, self.center] > self.hop_dis[ i, self.center]: a_close[j, i] = normalize_adjacency[j, i] else: a_further[j, i] = normalize_adjacency[j, i] if hop == 0: A.append(a_root) else: A.append(a_root + a_close) A.append(a_further) A = np.stack(A) self.A = A else: raise ValueError("Do Not Exist This Strategy") def get_hop_distance(num_node, edge, max_hop=1): A = np.zeros((num_node, num_node)) for i, j in edge: A[j, i] = 1 A[i, j] = 1 # compute hop steps hop_dis = np.zeros((num_node, num_node)) + np.inf transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] arrive_mat = (np.stack(transfer_mat) > 0) for d in range(max_hop, -1, -1): hop_dis[arrive_mat[d]] = d return hop_dis def normalize_digraph(A): Dl = np.sum(A, 0) num_node = A.shape[0] Dn = np.zeros((num_node, num_node)) for i in range(num_node): if Dl[i] > 0: Dn[i, i] = Dl[i]**(-1) AD = np.dot(A, Dn) return AD def normalize_undigraph(A): Dl = np.sum(A, 0) num_node = A.shape[0] Dn = np.zeros((num_node, num_node)) for i in range(num_node): if Dl[i] > 0: Dn[i, i] = Dl[i]**(-0.5) DAD = np.dot(np.dot(Dn, A), Dn) return DAD def zero(x): return 0 def iden(x): return x class ConvTemporalGraphical(nn.Module): r"""The basic module for applying a graph convolution. Args: in_channels (int): Number of channels in the input sequence data out_channels (int): Number of channels produced by the convolution kernel_size (int): Size of the graph convolving kernel t_kernel_size (int): Size of the temporal convolving kernel t_stride (int, optional): Stride of the temporal convolution. Default: 1 t_padding (int, optional): Temporal zero-padding added to both sides of the input. Default: 0 t_dilation (int, optional): Spacing between temporal kernel elements. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` Shape: - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format where :math:`N` is a batch size, :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, :math:`T_{in}/T_{out}` is a length of input/output sequence, :math:`V` is the number of graph nodes. """ def __init__(self, in_channels, out_channels, kernel_size, t_kernel_size=1, t_stride=1, t_padding=0, t_dilation=1, bias=True): super().__init__() self.kernel_size = kernel_size self.conv = nn.Conv2d(in_channels, out_channels * kernel_size, kernel_size=(t_kernel_size, 1), padding=(t_padding, 0), stride=(t_stride, 1), dilation=(t_dilation, 1), bias=bias) def forward(self, x, A): assert A.size(0) == self.kernel_size x = self.conv(x) n, kc, t, v = x.size() x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A)) return x.contiguous(), A class st_gcn_block(nn.Module): r"""Applies a spatial temporal graph convolution over an input graph sequence. Args: in_channels (int): Number of channels in the input sequence data out_channels (int): Number of channels produced by the convolution kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel stride (int, optional): Stride of the temporal convolution. Default: 1 dropout (int, optional): Dropout rate of the final output. Default: 0 residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True`` Shape: - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format where :math:`N` is a batch size, :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, :math:`T_{in}/T_{out}` is a length of input/output sequence, :math:`V` is the number of graph nodes. """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dropout=0, residual=True): super().__init__() assert len(kernel_size) == 2 assert kernel_size[0] % 2 == 1 padding = ((kernel_size[0] - 1) // 2, 0) self.gcn = ConvTemporalGraphical(in_channels, out_channels, kernel_size[1]) self.tcn = nn.Sequential( nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d( out_channels, out_channels, (kernel_size[0], 1), (stride, 1), padding, ), nn.BatchNorm2d(out_channels), nn.Dropout(dropout, inplace=True), ) if not residual: self.residual = zero elif (in_channels == out_channels) and (stride == 1): self.residual = iden else: self.residual = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)), nn.BatchNorm2d(out_channels), ) self.relu = nn.ReLU(inplace=True) def forward(self, x, A): res = self.residual(x) x, A = self.gcn(x, A) x = self.tcn(x) + res return self.relu(x), A class ST_GCN_18(nn.Module): r"""Spatial temporal graph convolutional networks. Args: in_channels (int): Number of channels in the input data num_class (int): Number of classes for the classification task graph_cfg (dict): The arguments for building the graph edge_importance_weighting (bool): If ``True``, adds a learnable importance weighting to the edges of the graph **kwargs (optional): Other parameters for graph convolution units Shape: - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` - Output: :math:`(N, num_class)` where :math:`N` is a batch size, :math:`T_{in}` is a length of input sequence, :math:`V_{in}` is the number of graph nodes, :math:`M_{in}` is the number of instance in a frame. """ def __init__(self, in_channels, edge_importance_weighting=True, data_bn=True, **kwargs): super().__init__() # load graph self.graph = Graph() A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) self.register_buffer('A', A) # build networks spatial_kernel_size = A.size(0) temporal_kernel_size = 9 kernel_size = (temporal_kernel_size, spatial_kernel_size) self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) if data_bn else iden kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} self.st_gcn_networks = nn.ModuleList(( st_gcn_block(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), st_gcn_block(64, 64, kernel_size, 1, **kwargs), st_gcn_block(64, 64, kernel_size, 1, **kwargs), st_gcn_block(64, 64, kernel_size, 1, **kwargs), st_gcn_block(64, 128, kernel_size, 2, **kwargs), st_gcn_block(128, 128, kernel_size, 1, **kwargs), st_gcn_block(128, 128, kernel_size, 1, **kwargs), st_gcn_block(128, 256, kernel_size, 2, **kwargs), st_gcn_block(256, 256, kernel_size, 1, **kwargs), st_gcn_block(256, 512, kernel_size, 1, **kwargs), )) # initialize parameters for edge importance weighting if edge_importance_weighting: self.edge_importance = nn.ParameterList([ nn.Parameter(torch.ones(self.A.size())) for i in self.st_gcn_networks ]) else: self.edge_importance = [1] * len(self.st_gcn_networks) def forward(self, x): # data normalization N, C, T, V, M = x.size() x = x.permute(0, 4, 3, 1, 2).contiguous() x = x.view(N * M, V * C, T) x = self.data_bn(x) x = x.view(N, M, V, C, T) x = x.permute(0, 1, 3, 4, 2).contiguous() x = x.view(N * M, C, T, V) # forward for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): x, _ = gcn(x, self.A * importance) # global pooling x = F.avg_pool2d(x, x.size()[2:]) # (b, 512, t, joint) x = x.view(N, M, -1, 1, 1).mean(dim=1) return x