|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
from ..layers import deformable_conv, SE |
|
|
|
# NOTE(review): this fixes the global torch RNG state at import time, which
# affects every module that imports this file (e.g. parameter initialisation
# becomes deterministic for a given construction order).
torch.random.manual_seed(0)
|
|
|
|
|
|
|
class CNN_layer(nn.Module):
    """Conv2d -> BatchNorm2d -> Dropout block with "same" spatial padding.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: odd kernel size, either an int (square kernel) or a
            ``(kh, kw)`` pair.
        dropout: dropout probability applied after batch normalization.
        bias: whether the convolution has a learnable bias.

    With an odd kernel and padding ``(k - 1) // 2`` on each axis, the output
    has the same spatial size as the input.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 dropout,
                 bias=True):
        super(CNN_layer, self).__init__()
        self.kernel_size = kernel_size
        # Accept a bare int as well as a (kh, kw) pair.
        if isinstance(kernel_size, (tuple, list)):
            ks = tuple(kernel_size)
        else:
            ks = (kernel_size, kernel_size)
        # Only odd kernels give symmetric "same" padding.
        assert ks[0] % 2 == 1 and ks[1] % 2 == 1
        padding = ((ks[0] - 1) // 2, (ks[1] - 1) // 2)

        self.block1 = nn.Sequential(
            # `bias` was previously accepted but silently ignored.
            nn.Conv2d(in_ch, out_ch, kernel_size=ks, padding=padding, dilation=(1, 1), bias=bias),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout, inplace=True),
        )

    def forward(self, x):
        """Apply conv, batch norm and dropout; spatial size is preserved."""
        return self.block1(x)
|
|
|
|
|
class FPN(nn.Module):
    """Three parallel dilated conv branches plus a global-average branch, fused by a 1x1 conv.

    Branch ``i`` (i = 0, 1, 2) is Conv2d -> BatchNorm2d -> Dropout -> PReLU
    with dilation ``1 + i * p`` per axis, where ``p = (k - 1) // 2``. Each
    branch is padded with ``dilation * p`` so it preserves the input's
    spatial size. The input's global spatial average is broadcast back to
    full size. It is concatenated with the three branch outputs, and the
    result is compressed to ``out_ch`` channels.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the width of each branch).
        kernel: odd kernel size, an int or a ``(kh, kw)`` pair.
        dropout: dropout probability inside each branch.
        reduction: accepted for interface compatibility; not used.

    Raises:
        ValueError: if either kernel dimension is even (the branches could
            not preserve the spatial size, so ``forward`` could not concat).
    """

    def __init__(self, in_ch,
                 out_ch,
                 kernel,
                 dropout,
                 reduction,
                 ):
        super(FPN, self).__init__()
        kernel_size = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
        if kernel_size[0] % 2 == 0 or kernel_size[1] % 2 == 0:
            raise ValueError(f"FPN needs an odd kernel size, got {kernel_size!r}")
        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)

        # Dilations grow as 1, 1 + p, 1 + 2p. A dilated conv keeps the spatial
        # size only when padding == dilation * p, so each branch's padding is
        # derived from its own dilation. (The old 2p / 3p padding matched only
        # for 3x3 kernels.)
        dilations = [(1 + i * padding[0], 1 + i * padding[1]) for i in range(3)]
        pads = [(d[0] * padding[0], d[1] * padding[1]) for d in dilations]

        self.block1 = self._branch(in_ch, out_ch, kernel_size, pads[0], dilations[0], dropout)
        self.block2 = self._branch(in_ch, out_ch, kernel_size, pads[1], dilations[1], dropout)
        self.block3 = self._branch(in_ch, out_ch, kernel_size, pads[2], dilations[2], dropout)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.compress = nn.Conv2d(out_ch * 3 + in_ch,
                                  out_ch,
                                  kernel_size=(1, 1))

    @staticmethod
    def _branch(in_ch, out_ch, kernel_size, padding, dilation, dropout):
        """Build one dilated Conv2d -> BatchNorm2d -> Dropout -> PReLU branch."""
        return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, dilation=dilation),
                             nn.BatchNorm2d(out_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             )

    def forward(self, x):
        """Fuse the three dilated branches and the global context of ``x``."""
        _, _, rows, cols = x.shape
        # 1x1 global average broadcast back to the input's spatial size.
        global_action = self.pooling(x).expand(-1, -1, rows, cols)
        out = torch.cat((self.block1(x), self.block2(x), self.block3(x), global_action), dim=1)
        return self.compress(out)
|
|
|
|
|
def mish(x):
    """Elementwise Mish activation: ``x * tanh(softplus(x))``."""
    gate = F.softplus(x).tanh()
    return x * gate
|
|
|
|
|
class ConvTemporalGraphical(nn.Module):
    r"""Graph convolution implemented as an ``einsum`` with an adjacency tensor ``A``.

    Args:
        time_dim: number of frames :math:`T` in the input.
        joints_dim: number of graph nodes :math:`V` in the input.
        domain: ``"time"``: for every frame ``t`` the joints are mixed by
            ``A[t]`` (``A`` has shape :math:`(T, V, V)`). ``"space"``: for
            every joint ``v`` the frames are mixed by ``A[v]`` (``A`` has shape
            :math:`(V, T, T)`).
        interpratable: if ``False``, ``A`` is a learned ``nn.Parameter``
            shared by the whole batch and initialised uniformly in
            :math:`[-1/\sqrt{A.size(1)}, 1/\sqrt{A.size(1)}]`. If ``True``, no
            parameter is created. ``A`` must then be assigned a per-sample
            tensor (:math:`(N, T, V, V)` for ``"time"``, :math:`(N, V, T, T)`
            for ``"space"``) before :meth:`forward` is called.

    Shape:
        - Input: :math:`(N, C, T, V)`
        - Output: :math:`(N, C, T, V)`

    Raises:
        ValueError: if ``domain`` is not ``"time"`` or ``"space"``.
    """

    def __init__(self, time_dim, joints_dim, domain, interpratable):
        super(ConvTemporalGraphical, self).__init__()

        if domain == "time":
            # One (V, V) joint-mixing matrix per frame.
            shape = (time_dim, joints_dim, joints_dim)
            self.domain = 'nctv,ntvw->nctw' if interpratable else 'nctv,tvw->nctw'
        elif domain == "space":
            # One (T, T) frame-mixing matrix per joint.
            shape = (joints_dim, time_dim, time_dim)
            self.domain = 'nctv,nvtq->ncqv' if interpratable else 'nctv,vtq->ncqv'
        else:
            raise ValueError(f"domain must be 'time' or 'space', got {domain!r}")

        if interpratable:
            # Filled in externally with a per-sample adjacency before forward().
            self.A = None
        else:
            self.A = nn.Parameter(torch.empty(shape))
            stdv = 1. / math.sqrt(shape[1])
            self.A.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Contract ``x`` with the adjacency ``A`` according to ``self.domain``."""
        if self.A is None:
            raise RuntimeError("ConvTemporalGraphical: adjacency `A` must be assigned "
                               "before forward() when interpratable=True")
        x = torch.einsum(self.domain, (x, self.A))
        return x.contiguous()
|
|
|
|
|
class Map2Adj(nn.Module):
    """Predict a per-sample adjacency tensor from a feature map.

    ``time_compress`` collapses the time axis and ``joint_compress``
    collapses the joint axis (both assume the input has ``T == time_dim``
    and ``V == joints_dim``). This gives maps of shape ``(N, T, 1, V)`` and
    ``(N, V, T, 1)``. Their ``matmul`` forms an outer-product adjacency,
    which is refined by the 1x1-conv ``expansor``.

    Args:
        in_ch: channels of the input feature map.
        time_dim: number of frames ``T``.
        joints_dim: number of joints ``V``.
        domain: ``"space"`` -> output ``(N, V, T, T)``;
            ``"time"`` -> output ``(N, T, V, V)``.
        dropout: dropout probability.

    Raises:
        ValueError: if ``domain`` is not ``"space"`` or ``"time"``.
    """

    def __init__(self,
                 in_ch,
                 time_dim,
                 joints_dim,
                 domain,
                 dropout,
                 ):
        super(Map2Adj, self).__init__()
        if domain not in ("space", "time"):
            raise ValueError(f"domain must be 'space' or 'time', got {domain!r}")
        self.domain = domain
        # Clamp to 1 so in_ch == 1 does not build 0-channel convolutions
        # (which would yield an all-zero adjacency).
        inter_ch = max(in_ch // 2, 1)
        self.time_compress = nn.Sequential(nn.Conv2d(in_ch, inter_ch, kernel_size=1, bias=False),
                                           nn.BatchNorm2d(inter_ch),
                                           nn.PReLU(),
                                           nn.Conv2d(inter_ch, inter_ch, kernel_size=(time_dim, 1), bias=False),
                                           nn.BatchNorm2d(inter_ch),
                                           nn.Dropout(dropout, inplace=True),
                                           nn.Conv2d(inter_ch, time_dim, kernel_size=1, bias=False),
                                           )
        self.joint_compress = nn.Sequential(nn.Conv2d(in_ch, inter_ch, kernel_size=1, bias=False),
                                            nn.BatchNorm2d(inter_ch),
                                            nn.PReLU(),
                                            nn.Conv2d(inter_ch, inter_ch, kernel_size=(1, joints_dim), bias=False),
                                            nn.BatchNorm2d(inter_ch),
                                            nn.Dropout(dropout, inplace=True),
                                            nn.Conv2d(inter_ch, joints_dim, kernel_size=1, bias=False),
                                            )

        if domain == "space":
            # (N, V, T, 1) @ (N, V, 1, T) -> (N, V, T, T)
            ch = joints_dim
            self.perm1 = (0, 1, 2, 3)
            self.perm2 = (0, 3, 2, 1)
        else:
            # (N, T, V, 1) @ (N, T, 1, V) -> (N, T, V, V)
            ch = time_dim
            self.perm1 = (0, 2, 1, 3)
            self.perm2 = (0, 1, 2, 3)

        self.expansor = nn.Sequential(nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      nn.BatchNorm2d(ch),
                                      nn.Dropout(dropout, inplace=True),
                                      nn.PReLU(),
                                      nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      )
        self.time_compress.apply(self._init_weights)
        self.joint_compress.apply(self._init_weights)
        self.expansor.apply(self._init_weights)

    def _init_weights(self, m, gain=0.05):
        """Small-gain Xavier init for Linear/Conv weights; PReLU slopes reset to 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        if isinstance(m, (nn.Conv2d, nn.Conv1d)):
            torch.nn.init.xavier_normal_(m.weight, gain=gain)
        if isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        """Return the adjacency predicted from ``x`` of shape ``(N, in_ch, T, V)``."""
        dim_seq = self.time_compress(x)      # (N, T, 1, V)
        dim_space = self.joint_compress(x)   # (N, V, T, 1)
        o = torch.matmul(dim_space.permute(self.perm1), dim_seq.permute(self.perm2))
        return self.expansor(o)
|
|
|
|
|
class Domain_GCNN_layer(nn.Module):
    """Graph convolution followed by a conv/BN/dropout block, with a residual connection.

    Pipeline: ``prelu(tcn(gcn(x)) + residual(x))``, where ``gcn`` is a
    :class:`ConvTemporalGraphical` on the given ``domain``. When
    ``interpratable`` is true, a :class:`Map2Adj` predicts a per-sample
    adjacency from ``x`` and installs it as ``gcn.A`` on every forward pass.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel_size: odd ``(kh, kw)`` kernel of the ``tcn`` conv.
        stride: stride of the ``tcn`` conv (the residual path uses the same
            stride so both branches have matching shapes).
        time_dim: number of frames ``T`` the graph conv is built for.
        joints_dim: number of joints ``V`` the graph conv is built for.
        domain: ``"time"`` or ``"space"`` (see ConvTemporalGraphical).
        interpratable: predict the adjacency from the input instead of
            learning a fixed one.
        dropout: dropout probability.
        bias: whether the ``tcn`` convolution has a bias.

    Shape:
        - Input: ``(N, in_ch, T, V)`` with ``T == time_dim`` and ``V == joints_dim``.
        - Output: ``(N, out_ch, T_out, V_out)``; equal to ``(T, V)`` when ``stride == 1``.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 domain,
                 interpratable,
                 dropout,
                 bias=True):
        super(Domain_GCNN_layer, self).__init__()
        self.kernel_size = kernel_size
        assert self.kernel_size[0] % 2 == 1
        assert self.kernel_size[1] % 2 == 1
        padding = ((self.kernel_size[0] - 1) // 2, (self.kernel_size[1] - 1) // 2)
        self.interpratable = interpratable
        self.domain = domain

        self.gcn = ConvTemporalGraphical(time_dim, joints_dim, domain, interpratable)
        self.tcn = nn.Sequential(nn.Conv2d(in_ch,
                                           out_ch,
                                           kernel_size=(self.kernel_size[0], self.kernel_size[1]),
                                           stride=(stride, stride),
                                           padding=padding,
                                           bias=bias,
                                           ),
                                 nn.BatchNorm2d(out_ch),
                                 nn.Dropout(dropout, inplace=True),
                                 )

        if stride != 1 or in_ch != out_ch:
            # Same stride as `tcn`: with an odd kernel and "same" padding both
            # branches produce floor((H - 1) / stride) + 1 rows/cols, so the
            # residual sum is shape-compatible. (It used stride 1 before.)
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()

        if self.interpratable:
            self.map_to_adj = Map2Adj(in_ch,
                                      time_dim,
                                      joints_dim,
                                      domain,
                                      dropout,
                                      )
        else:
            self.map_to_adj = nn.Identity()
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply graph conv + temporal conv and add the residual path."""
        res = self.residual(x)
        # Kept as an attribute for inspection. With interpratable=False,
        # map_to_adj is Identity, so this is just the input tensor.
        self.Adj = self.map_to_adj(x)
        if self.interpratable:
            self.gcn.A = self.Adj
        out = self.tcn(self.gcn(x))
        return self.prelu(out + res)
|
|
|
|
|
|
|
class DSTD_GC(nn.Module):
    """Gated fusion of a "space" and a "time" :class:`Domain_GCNN_layer`.

    The input is batch-normalised (``global_norm``). Two per-sample,
    per-channel gates ``w1`` / ``w2`` are predicted, each from a pooled
    conv summary (``conv_s`` / ``conv_t``) concatenated with global
    mean/std statistics. The gates scale the outputs of the space and time
    graph layers. The two gated outputs are concatenated and compressed
    back to ``out_ch`` channels, and a residual of the normalised input is
    added.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        interpratable: forwarded to both graph layers.
        kernel_size: odd ``(kh, kw)`` kernel of the graph layers' temporal conv.
        stride: stride of the graph layers (the residual uses the same stride).
        time_dim: number of frames ``T``.
        joints_dim: number of joints ``V``.
        reduction: SE reduction ratio in ``compressor``.
        dropout: dropout probability.

    Shape:
        - Input: ``(N, in_ch, T, V)`` with ``T == time_dim`` and ``V == joints_dim``
          (``conv_s``/``conv_t`` must collapse the map to 1x1).
        - Output: ``(N, out_ch, T, V)`` when ``stride == 1`` (assuming
          ``SE.SELayer2d`` preserves shape — TODO confirm).

    After ``forward``, the predicted gates are stored as ``self.w1`` and
    ``self.w2`` with shape ``(N, out_ch)``.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 interpratable,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 reduction,
                 dropout):
        super(DSTD_GC, self).__init__()
        self.dsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "space", interpratable, dropout)
        self.tsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "time", interpratable, dropout)

        self.compressor = nn.Sequential(nn.Conv2d(out_ch * 2, out_ch, 1, bias=False),
                                        nn.BatchNorm2d(out_ch),
                                        nn.PReLU(),
                                        SE.SELayer2d(out_ch, reduction=reduction),
                                        )
        if stride != 1 or in_ch != out_ch:
            # Stride matches the graph layers so the residual sum lines up.
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()

        out_ch_c = out_ch // 2 if out_ch // 2 > 1 else 1
        self.global_norm = nn.BatchNorm2d(in_ch)
        # Both summaries collapse (T, V) -> (1, 1): first over time, then over joints.
        self.conv_s = nn.Sequential(nn.Conv2d(in_ch, out_ch_c, (time_dim, 1), bias=False),
                                    nn.BatchNorm2d(out_ch_c),
                                    nn.Dropout(dropout, inplace=True),
                                    nn.PReLU(),
                                    nn.Conv2d(out_ch_c, out_ch, (1, joints_dim), bias=False),
                                    nn.BatchNorm2d(out_ch),
                                    nn.Dropout(dropout, inplace=True),
                                    nn.PReLU(),
                                    )
        self.conv_t = nn.Sequential(nn.Conv2d(in_ch, out_ch_c, (time_dim, 1), bias=False),
                                    nn.BatchNorm2d(out_ch_c),
                                    nn.Dropout(dropout, inplace=True),
                                    nn.PReLU(),
                                    nn.Conv2d(out_ch_c, out_ch, (1, joints_dim), bias=False),
                                    nn.BatchNorm2d(out_ch),
                                    nn.Dropout(dropout, inplace=True),
                                    nn.PReLU(),
                                    )
        # Gate input width: out_ch conv features + (2 + 2 * T) statistics.
        self.map_s = nn.Sequential(nn.Linear(out_ch + 2 + time_dim * 2, out_ch, bias=False),
                                   nn.BatchNorm1d(out_ch),
                                   nn.Dropout(dropout, inplace=True),
                                   nn.PReLU(),
                                   nn.Linear(out_ch, out_ch, bias=False),
                                   )
        self.map_t = nn.Sequential(nn.Linear(out_ch + 2 + time_dim * 2, out_ch, bias=False),
                                   nn.BatchNorm1d(out_ch),
                                   nn.Dropout(dropout, inplace=True),
                                   nn.PReLU(),
                                   nn.Linear(out_ch, out_ch, bias=False),
                                   )
        self.prelu1 = nn.Sequential(nn.BatchNorm2d(out_ch),
                                    nn.PReLU(),
                                    )
        self.prelu2 = nn.Sequential(nn.BatchNorm2d(out_ch),
                                    nn.PReLU(),
                                    )

    def _get_stats_(self, x):
        """Return global and per-frame mean / std statistics of ``x``.

        For ``x`` of shape ``(N, C, T, V)``, the result has shape
        ``(N, 2 + 2 * T)``. Its columns are:
        - the global mean (1 column);
        - the per-frame mean over joints and channels (T columns);
        - the channel-std of per-channel stds (1 column);
        - the per-frame channel-std of joint stds (T columns).
        """
        global_avg_pool = x.mean((3, 2)).mean(1, keepdim=True)
        global_avg_pool_features = x.mean(3).mean(1)
        global_std_pool = x.std((3, 2)).std(1, keepdim=True)
        global_std_pool_features = x.std(3).std(1)
        return torch.cat((global_avg_pool,
                          global_avg_pool_features,
                          global_std_pool,
                          global_std_pool_features,
                          ),
                         dim=1)

    def forward(self, x):
        """Fuse gated space/time graph features and add the residual of the normalised input."""
        b = x.shape[0]
        xn = self.global_norm(x)

        # The statistics depend only on xn, so compute them once for both gates
        # (they were previously computed twice with identical results).
        stats = self._get_stats_(xn)
        w1 = torch.cat((self.conv_s(xn).view(b, -1), stats), dim=1)
        w2 = torch.cat((self.conv_t(xn).view(b, -1), stats), dim=1)
        self.w1 = self.map_s(w1)
        self.w2 = self.map_t(w2)
        # Broadcast the (N, out_ch) gates over the spatial axes.
        w1 = self.w1[..., None, None]
        w2 = self.w2[..., None, None]

        x1 = self.dsgn(xn)
        x2 = self.tsgn(xn)
        out = torch.cat((self.prelu1(w1 * x1), self.prelu2(w2 * x2)), dim=1)
        out = self.compressor(out)
        return out + self.residual(xn)
|
|
|
|
|
class ContextLayer(nn.Module):
    """Predict an additive context term from a single-channel trajectory map.

    Three pooled summaries of the input are each projected to ``output_seq``
    features:
    - the global max after a 1x1 conv;
    - the max over the last axis after a ``(input_seq, 1)`` conv;
    - the global mean after a 1x1 conv.

    The summaries are concatenated and mapped to a per-joint vector
    (``self.joints``) and a per-frame vector (``self.displacements``). Their
    outer product is refined by 1-D convs + SE, lifted to ``dims`` channels,
    and passed through a 2-D SE block.

    Args:
        in_ch: input channels.
        hidden_ch: width of the three context convolutions.
        output_seq: number of output frames.
        input_seq: height of the temporal conv kernel; assumes the input
            has exactly this many rows so that conv collapses time to 1.
        joints: number of joints in the output.
        dims: channels of the output (coordinates per joint).
        reduction: SE reduction ratio.
        dropout: dropout probability.

    Shape:
        - Input: ``(N, in_ch, input_seq, W)``
        - Output: ``(N, output_seq, joints, dims)`` (assuming
          ``SE.SELayer2d`` preserves shape — TODO confirm).
    """

    def __init__(self,
                 in_ch,
                 hidden_ch,
                 output_seq,
                 input_seq,
                 joints,
                 dims=3,
                 reduction=8,
                 dropout=0.1,
                 ):
        super(ContextLayer, self).__init__()
        self.n_output = output_seq
        self.n_joints = joints
        self.n_input = input_seq

        def conv_bn_prelu(kernel):
            # Conv2d(in_ch -> hidden_ch) -> BatchNorm2d -> PReLU
            return nn.Sequential(nn.Conv2d(in_ch, hidden_ch, kernel, bias=False),
                                 nn.BatchNorm2d(hidden_ch),
                                 nn.PReLU())

        def linear_drop_prelu():
            # Linear(hidden_ch -> output_seq) -> Dropout -> PReLU
            return nn.Sequential(nn.Linear(hidden_ch, output_seq, bias=False),
                                 nn.Dropout(dropout, inplace=True),
                                 nn.PReLU())

        def linear_bn_drop(width):
            # Linear(3 * output_seq -> width) -> BatchNorm1d -> Dropout
            return nn.Sequential(nn.Linear(output_seq * 3, width, bias=False),
                                 nn.BatchNorm1d(width),
                                 nn.Dropout(dropout, inplace=True))

        # Creation order is kept identical so RNG-driven initialisation and
        # parameter ordering do not change.
        self.context_conv1 = conv_bn_prelu(1)
        self.context_conv2 = conv_bn_prelu((input_seq, 1))
        self.context_conv3 = conv_bn_prelu(1)
        self.map1 = linear_drop_prelu()
        self.map2 = linear_drop_prelu()
        self.map3 = linear_drop_prelu()
        self.fmap_s = linear_bn_drop(joints)
        self.fmap_t = linear_bn_drop(output_seq)

        self.norm_map = nn.Sequential(
            nn.Conv1d(output_seq, output_seq, 1, bias=False),
            nn.BatchNorm1d(output_seq),
            nn.Dropout(dropout, inplace=True),
            nn.PReLU(),
            SE.SELayer1d(output_seq, reduction=reduction),
            nn.Conv1d(output_seq, output_seq, 1, bias=False),
            nn.BatchNorm1d(output_seq),
            nn.Dropout(dropout, inplace=True),
            nn.PReLU(),
        )
        self.fconv = nn.Sequential(
            nn.Conv2d(1, dims, 1, bias=False),
            nn.BatchNorm2d(dims),
            nn.PReLU(),
            nn.Conv2d(dims, dims, 1, bias=False),
            nn.BatchNorm2d(dims),
            nn.PReLU(),
        )
        self.SE = SE.SELayer2d(output_seq, reduction=reduction)

    def forward(self, x):
        """Compute the context term; intermediates are kept on ``self`` for inspection."""
        batch, _, _, joint_dim = x.shape

        peak = self.context_conv1(x).max(dim=-1).values.max(dim=-1).values
        temporal_peak = self.context_conv2(x).view(batch, -1, joint_dim).max(dim=-1).values
        average = self.context_conv3(x).mean((2, 3))
        summary = torch.cat((self.map1(peak), self.map2(temporal_peak), self.map3(average)), dim=1)

        self.joints = self.fmap_s(summary)
        self.displacements = self.fmap_t(summary)
        # Outer product: (N, n_output, 1) x (N, 1, n_joints) -> (N, n_output, n_joints)
        self.seq_joints = torch.bmm(self.displacements.unsqueeze(2), self.joints.unsqueeze(1))
        self.seq_joints_n = self.norm_map(self.seq_joints)

        grid = self.seq_joints_n.view(batch, 1, self.n_output, self.n_joints)
        self.seq_joints_dims = self.fconv(grid)
        # (N, dims, n_output, n_joints) -> (N, n_output, n_joints, dims)
        return self.SE(self.seq_joints_dims.permute(0, 2, 3, 1))
|
|
|
|
|
class MlpMixer_ext(nn.Module):
    """Motion predictor: GCN encoder -> FPN over time -> GCN decoder + context term.

    Shape:
        - Input: ``(N, T_in, V, 3)``. The 10 feature channels built in
          ``forward`` (position, acceleration, velocity, speed) require 3
          coordinates per joint.
        - Output: a 1-tuple holding a tensor of shape ``(N, T_out, V, 3)``.
          The tensor is the last observed frame plus the predicted offsets
          (assuming the last ``output_gcn`` width is 3).

    where ``N`` is the batch size, ``T_in`` / ``T_out`` are the input/output
    sequence lengths, and ``V`` is the number of joints.

    Args:
        arch: config whose ``model_params`` provides ``clipping``,
            ``input_n``, ``output_n``, ``joints``, ``n_txcnn_layers``,
            ``txc_kernel_size``, ``input_gcn``, ``output_gcn``,
            ``reduction`` and ``hidden_dim``.
        learn: config providing ``dropout``.

    The ``model_complexity`` lists in ``arch`` are read but no longer
    mutated. Previously each construction inserted/appended into them, so
    building a second model from the same config produced a different
    architecture.
    """

    def __init__(self, arch, learn):
        super(MlpMixer_ext, self).__init__()
        self.clipping = arch.model_params.clipping

        self.n_input = arch.model_params.input_n
        self.n_output = arch.model_params.output_n
        self.n_joints = arch.model_params.joints
        self.n_txcnn_layers = arch.model_params.n_txcnn_layers
        self.txc_kernel_size = [arch.model_params.txc_kernel_size] * 2
        self.input_gcn = arch.model_params.input_gcn
        self.output_gcn = arch.model_params.output_gcn
        self.reduction = arch.model_params.reduction
        self.hidden_dim = arch.model_params.hidden_dim

        # Registration order of these containers is kept as-is so that
        # parameters()/state_dict ordering is unchanged. `se`, `in_conv` and
        # `trans` stay empty; `context_layer` is a placeholder replaced below.
        self.st_gcnns = nn.ModuleList()
        self.txcnns = nn.ModuleList()
        self.se = nn.ModuleList()
        self.in_conv = nn.ModuleList()
        self.context_layer = nn.ModuleList()
        self.trans = nn.ModuleList()
        self.in_ch = 10
        self.model_tx = self.input_gcn.model_complexity.copy()
        self.model_tx.insert(0, 1)

        # Encoder widths: in_ch -> configured hidden widths -> in_ch.
        # Built as a new list so the shared config is not modified.
        in_widths = [self.in_ch] + list(self.input_gcn.model_complexity) + [self.in_ch]
        for i in range(len(in_widths) - 1):
            self.st_gcnns.append(DSTD_GC(in_widths[i],
                                         in_widths[i + 1],
                                         self.input_gcn.interpretable[i],
                                         [1, 1], 1, self.n_input, self.n_joints, self.reduction, learn.dropout))

        self.context_layer = ContextLayer(1, self.hidden_dim,
                                          self.n_output, self.n_output, self.n_joints,
                                          3, self.reduction, learn.dropout
                                          )

        # Time is used as the channel axis of the FPN stack (T_in -> T_out).
        self.txcnns.append(FPN(self.n_input, self.n_output, self.txc_kernel_size, 0., self.reduction))
        for _ in range(1, self.n_txcnn_layers):
            self.txcnns.append(FPN(self.n_output, self.n_output, self.txc_kernel_size, 0., self.reduction))

        self.prelus = nn.ModuleList(nn.PReLU() for _ in range(self.n_txcnn_layers))

        self.dim_conversor = nn.Sequential(nn.Conv2d(self.in_ch, 3, 1, bias=False),
                                           nn.BatchNorm2d(3),
                                           nn.PReLU(),
                                           nn.Conv2d(3, 3, 1, bias=False),
                                           nn.PReLU(3), )

        # Decoder widths: 3 coordinates -> configured widths (again without
        # mutating the config). The decoder runs on (N, 3, V, T_out), so its
        # "time" axis is the joints and its "joint" axis is the frames.
        self.st_gcnns_o = nn.ModuleList()
        out_widths = [3] + list(self.output_gcn.model_complexity)
        for i in range(len(out_widths) - 1):
            self.st_gcnns_o.append(DSTD_GC(out_widths[i],
                                           out_widths[i + 1],
                                           self.output_gcn.interpretable[i],
                                           [1, 1], 1, self.n_joints, self.n_output, self.reduction, learn.dropout))

        self.st_gcnns_o.apply(self._init_weights)
        self.st_gcnns.apply(self._init_weights)
        self.txcnns.apply(self._init_weights)

    def _init_weights(self, m, gain=0.1):
        """Xavier-uniform init for Linear weights; reset PReLU slopes to 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        if isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        """Predict ``T_out`` future frames from ``x`` of shape ``(N, T_in, V, 3)``."""
        b, _, joints, _ = x.shape

        # First differences along time. The last slot of `vel` holds the last
        # frame of `x` itself, and the last slot of `acc` holds the last slot
        # of `vel`. NOTE(review): confirm this fill is intended.
        vel = torch.zeros_like(x)
        vel[:, :-1] = torch.diff(x, dim=1)
        vel[:, -1] = x[:, -1]
        acc = torch.zeros_like(x)
        acc[:, :-1] = torch.diff(vel, dim=1)
        acc[:, -1] = vel[:, -1]
        # 3 + 3 + 3 + 1 = 10 == self.in_ch feature channels.
        feats = torch.cat((x, acc, vel, torch.norm(vel, dim=-1, keepdim=True)), dim=-1)

        h = feats.permute((0, 3, 1, 2))  # (N, 10, T_in, V)
        for layer in self.st_gcnns:
            h = layer(h)

        # Frames become channels for the FPN stack: (N, T_in, 10, V).
        h = h.permute(0, 2, 1, 3)
        h = self.prelus[0](self.txcnns[0](h))
        for i in range(1, self.n_txcnn_layers):
            h = self.prelus[i](self.txcnns[i](h)) + h

        # (N, T_out, 10, V) -> (N, 10, T_out, V) -> (N, 3, T_out, V) -> (N, T_out, V, 3)
        h = self.dim_conversor(h.permute(0, 2, 1, 3)).permute(0, 2, 3, 1)
        traj = h.cumsum(1)  # cumulative sum along the output time axis

        act = self.context_layer(traj.reshape(b, 1, self.n_output, joints * traj.shape[-1]))
        out = traj.permute(0, 3, 2, 1)  # (N, 3, V, T_out)
        for layer in self.st_gcnns_o:
            out = layer(out)
        out = out.permute(0, 3, 2, 1) + act

        # Add the last observed frame to every predicted step. The result is
        # returned as a 1-tuple (the original trailing comma), kept for callers.
        return (x[:, -1:] + out,)
|
|