import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential

from mmaction.models.utils import unit_tcn


class MSTCN(BaseModule):
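    """Multi-scale temporal convolution module.

    The input is processed by ``len(dilations) + 2`` parallel branches
    (dilated temporal convolutions, a max-pooling branch and a plain 1x1
    branch) whose outputs are concatenated back to ``out_channels``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or list): Temporal kernel size(s) of the dilated
            branches. Defaults to 3.
        stride (int): Temporal stride. Defaults to 1.
        dilations (list): Dilation of each temporal-conv branch.
            Defaults to ``[1, 2, 3, 4]``.
        residual (bool): Whether to use a residual connection.
            Defaults to True.
        act_cfg (dict): Config of the activation layer.
            Defaults to ``dict(type='ReLU')``.
        init_cfg (dict or list[dict]): Initialization config.
        tcn_dropout (float): Dropout probability applied to the output.
            Defaults to 0.
    """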
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=True,
                 act_cfg=dict(type='ReLU'),
                 init_cfg=[
                     dict(type='Constant', layer='BatchNorm2d', val=1),
                     dict(type='Kaiming', layer='Conv2d', mode='fan_out')
                 ],
                 tcn_dropout=0):
        super().__init__(init_cfg=init_cfg)
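        # Channels are split evenly across the branches; the final 1x1
        # branch absorbs the rounding remainder so the concatenated output
        # has exactly `out_channels` channels.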
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        branch_channels_rem = out_channels - branch_channels * (
            self.num_branches - 1)

        if isinstance(kernel_size, list):
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)
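        # Dilated branches: 1x1 channel reduction, then a temporal conv
        # (unit_tcn) with this branch's kernel size and dilation.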
        self.branches = ModuleList([
            Sequential(
                nn.Conv2d(
                    in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                build_activation_layer(act_cfg),
                unit_tcn(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation),
            ) for ks, dilation in zip(kernel_size, dilations)
        ])
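        # Extra branch 1: 1x1 reduction followed by temporal max pooling.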
        self.branches.append(
            Sequential(
                nn.Conv2d(
                    in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                build_activation_layer(act_cfg),
                nn.MaxPool2d(
                    kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
                nn.BatchNorm2d(branch_channels)))
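        # Extra branch 2: a plain (strided) 1x1 convolution.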
        self.branches.append(
            Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels_rem,
                    kernel_size=1,
                    padding=0,
                    stride=(stride, 1)), nn.BatchNorm2d(branch_channels_rem)))
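        # Residual path: identity when the shapes match, a strided 1x1
        # unit_tcn projection otherwise, or zero when `residual=False`.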
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(tcn_dropout)

    def forward(self, x):
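        """Defines the computation performed at every call."""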
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = torch.cat(branch_outs, dim=1)
        out += res
        out = self.act(out)
        out = self.drop(out)
        return out


class CTRGC(BaseModule):
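    """Channel-wise Topology Refinement Graph Convolution (CTR-GC).

    Pairwise joint features infer a channel-specific refinement of the
    adjacency, which is then applied as a graph convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        rel_reduction (int): Channel reduction ratio for the relation
            embeddings. Defaults to 8.
        init_cfg (dict or list[dict]): Initialization config.
    """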
    def __init__(self,
                 in_channels,
                 out_channels,
                 rel_reduction=8,
                 init_cfg=[
                     dict(type='Constant', layer='BatchNorm2d', val=1),
                     dict(type='Kaiming', layer='Conv2d', mode='fan_out')
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
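        # Keep at least 8 relation channels for very thin inputs.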
        if in_channels <= 16:
            self.rel_channels = 8
        else:
            self.rel_channels = in_channels // rel_reduction
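        # conv1/conv2 embed the joints into `rel_channels` relation
        # features, conv3 transforms the input features, and conv4 maps
        # the relation features to the per-channel topology refinement.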
        self.conv1 = nn.Conv2d(
            self.in_channels, self.rel_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(
            self.in_channels, self.rel_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size=1)
        self.conv4 = nn.Conv2d(
            self.rel_channels, self.out_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def forward(self, x, A=None, alpha=1):
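        """Defines the computation performed at every call."""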
        # x1, x2: (N, R, V) relation embeddings pooled over time;
        # x3: (N, C_out, T, V) transformed features.
        x1, x2, x3 = self.conv1(x).mean(-2), self.conv2(x).mean(
            -2), self.conv3(x)
        # Pairwise joint differences -> (N, R, V, V) relation map.
        x1 = self.tanh(x1.unsqueeze(-1) - x2.unsqueeze(-2))
        # Channel-specific refinement (scaled by alpha) added to the
        # shared topology A, broadcast over batch and channel.
        x1 = self.conv4(x1) * alpha + (A[None, None] if A is not None else 0)
        x1 = torch.einsum('ncuv,nctu->nctv', x1, x3)
        return x1


class unit_ctrgcn(BaseModule):
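    """Graph convolution unit built from CTRGC blocks.

    One CTRGC is applied per adjacency subset in ``A``; the subset outputs
    are summed, batch-normalized and merged with a residual branch.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): Adjacency tensor of shape
            ``(num_subset, V, V)``.
        init_cfg (dict or list[dict]): Initialization config. The ``bn``
            override initializes the output BatchNorm near zero, so the
            residual path dominates early in training.
    """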
    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 init_cfg=[
                     dict(
                         type='Constant',
                         layer='BatchNorm2d',
                         val=1,
                         override=dict(type='Constant', name='bn', val=1e-6)),
                     dict(type='Kaiming', layer='Conv2d', mode='fan_out')
                 ]):
        super().__init__(init_cfg=init_cfg)
        inter_channels = out_channels // 4
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels

        # One CTRGC block per adjacency subset.
        self.num_subset = A.shape[0]
        self.convs = ModuleList()

        for i in range(self.num_subset):
            self.convs.append(CTRGC(in_channels, out_channels))
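        # 1x1 projection for the residual when the channel counts differ.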
        if in_channels != out_channels:
            self.down = Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels))
        else:
            self.down = lambda x: x

        # The shared topology is a learnable parameter, refined per sample
        # and per channel by the CTRGC blocks at forward time.
        self.A = nn.Parameter(A.clone())

        self.alpha = nn.Parameter(torch.zeros(1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
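        """Defines the computation performed at every call."""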
        y = None

        # Sum the graph convolutions over the adjacency subsets.
        for i in range(self.num_subset):
            z = self.convs[i](x, self.A[i], self.alpha)
            y = z + y if y is not None else z

        y = self.bn(y)
        y += self.down(x)
        return self.relu(y)
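

if __name__ == '__main__':
    # Minimal smoke test, illustrative only: the (N, C, T, V) layout
    # matches the modules above, but the 3-subset, 25-joint adjacency is
    # a random stand-in, not a real skeleton graph.
    N, C, T, V = 2, 64, 16, 25
    x = torch.randn(N, C, T, V)
    A = torch.rand(3, V, V)

    gcn = unit_ctrgcn(C, C, A)
    tcn = MSTCN(C, C, stride=2)

    y = tcn(gcn(x))
    print(y.shape)  # torch.Size([2, 64, 8, 25]): stride 2 halves T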