"""
Copyright 2023 LINE Corporation

LINE Corporation licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce


def import_class(name):
    """Resolve a dotted path such as 'package.module.Class' to the object it names."""
    components = name.split(".")
    mod = __import__(components[0])
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod
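

# Illustrative use of import_class; torch.nn.Conv2d is chosen only because it
# is certain to be importable in this file's environment. In practice Model
# receives a dotted path to a skeleton-graph class, whose exact location
# depends on your project layout:
#
#     conv_cls = import_class("torch.nn.Conv2d")
#     assert conv_cls is nn.Conv2d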


def conv_branch_init(conv, branches):
    # Scale the usual He-style initialization down by the number of parallel
    # branches so that summing the branch outputs preserves variance.
    weight = conv.weight
    n = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2.0 / (n * k1 * k2 * branches)))
    nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    nn.init.kaiming_normal_(conv.weight, mode="fan_out")
    nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    nn.init.constant_(bn.weight, scale)
    nn.init.constant_(bn.bias, 0)


class unit_tcn(nn.Module):
    """Temporal convolution along T, applied independently to each joint."""

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        # No activation here; TCN_GCN_unit applies ReLU after the residual add.
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Module):
    """Adaptive graph convolution combining the fixed skeleton adjacency A,
    a learned global offset PA, and a data-dependent attention term."""

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        # PA shares A's shape but starts near zero and is learned.
        self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        nn.init.constant_(self.PA, 1e-6)
        # Fixed adjacency, kept as a plain tensor so it is never updated.
        self.A = torch.from_numpy(A.astype(np.float32))
        self.num_subset = num_subset

        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        self.conv_d = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1), nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(dim=-2)
        self.relu = nn.ReLU()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        N, C, T, V = x.size()
        A = self.A.to(x.device) + self.PA

        y = None
        for i in range(self.num_subset):
            # Data-dependent affinity between joints: (N, V, V).
            A1 = (
                self.conv_a[i](x)
                .permute(0, 3, 1, 2)
                .contiguous()
                .view(N, V, self.inter_c * T)
            )
            A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
            A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))
            A1 = A1 + A[i]
            A2 = x.view(N, C * T, V)
            z = self.conv_d[i](torch.matmul(A2, A1).view(N, C, T, V))
            y = z + y if y is not None else z

        y = self.bn(y)
        y += self.down(x)
        return self.relu(y)


class TCN_GCN_unit(nn.Module):
    """One ST-GCN block: spatial graph convolution followed by a temporal
    convolution, with an optional residual connection."""

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A)
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU()
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            # Shapes differ, so project the skip path with a 1x1 temporal conv.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride
            )

    def forward(self, x):
        x = self.tcn1(self.gcn1(x)) + self.residual(x)
        return self.relu(x)


class Classifier(nn.Module):
    """Frame-level scoring head with multi-temperature attention pooling.

    Returns a sigmoid video-level prediction and the raw per-frame class
    scores.
    """

    def __init__(self, num_class=60, scale_factor=5.0, temperature=[1.0, 2.0, 5.0]):
        super(Classifier, self).__init__()
        # (num_class + 1) class centers of width 256; 256 matches the output
        # channels of the final TCN_GCN_unit in Model.
        self.ac_center = nn.Parameter(torch.zeros(num_class + 1, 256))
        nn.init.xavier_uniform_(self.ac_center)

        self.temperature = temperature
        self.scale_factor = scale_factor

    def forward(self, x):
        N = x.size(0)

        # Average over the joint axis: (N, C, T, V) -> (N, T, C).
        x_emb = reduce(x, "(n m) c t v -> n t c", "mean", n=N)

        # Scaled cosine similarity between frame embeddings and class centers.
        norms_emb = F.normalize(x_emb, dim=2)
        norms_ac = F.normalize(self.ac_center)
        frm_scrs = (
            torch.einsum("ntd,cd->ntc", [norms_emb, norms_ac]) * self.scale_factor
        )

        # Class-wise attention over time, one distribution per temperature.
        class_wise_atts = [F.softmax(frm_scrs * t, dim=1) for t in self.temperature]

        # Attention-weighted video-level scores, averaged across temperatures.
        mid_vid_scrs = [
            torch.einsum("ntc,ntc->nc", [frm_scrs, att]) for att in class_wise_atts
        ]
        mil_vid_scr = torch.stack(mid_vid_scrs, -1).mean(-1) * 2.0
        mil_vid_pred = torch.sigmoid(mil_vid_scr)

        return mil_vid_pred, frm_scrs


class Model(nn.Module):
    def __init__(
        self,
        num_class=60,
        num_point=25,
        num_person=1,
        graph=None,
        graph_args=dict(),
        in_channels=2,
        scale_factor=5.0,
        temperature=[1.0, 2.0, 5.0],
    ):
        super(Model, self).__init__()

        if graph is None:
            raise ValueError("graph must be the dotted path of a graph class")
        else:
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        self.l1 = TCN_GCN_unit(in_channels, 64, A, residual=False)
        self.l2 = TCN_GCN_unit(64, 64, A, stride=2)
        self.l3 = TCN_GCN_unit(64, 64, A)
        self.l4 = TCN_GCN_unit(64, 64, A)
        self.l5 = TCN_GCN_unit(64, 128, A, stride=2)
        self.l6 = TCN_GCN_unit(128, 128, A)
        self.l7 = TCN_GCN_unit(128, 128, A)
        self.l8 = TCN_GCN_unit(128, 256, A, stride=2)
        self.l9 = TCN_GCN_unit(256, 256, A)
        self.l10 = TCN_GCN_unit(256, 256, A)

        bn_init(self.data_bn, 1)

        self.classifier_1 = Classifier(num_class, scale_factor, temperature)
        self.classifier_2 = Classifier(num_class, scale_factor, temperature)

    def forward(self, x, mask):
        # `mask` is part of the calling interface but is not used here.
        N, C, T, V, M = x.size()

        # Flatten persons/joints/channels for BatchNorm1d, then unflatten so
        # each person becomes an independent sample along the batch axis.
        x = rearrange(x, "n c t v m -> n (m v c) t")
        x = self.data_bn(x)
        x = rearrange(x, "n (m v c) t -> (n m) c t v", m=M, v=V, c=C)

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        # The second head sees detached features, so its loss does not
        # backpropagate into the backbone.
        mil_vid_pred_1, frm_scrs_1 = self.classifier_1(x)
        mil_vid_pred_2, frm_scrs_2 = self.classifier_2(x.detach())

        # Three stride-2 stages shrank T by 8x; upsample the frame scores back
        # to the input length.
        frm_scrs_1 = rearrange(frm_scrs_1, "n t c -> n c t")
        frm_scrs_1 = F.interpolate(frm_scrs_1, size=T, mode="linear", align_corners=True)
        frm_scrs_1 = rearrange(frm_scrs_1, "n c t -> n t c")

        frm_scrs_2 = rearrange(frm_scrs_2, "n t c -> n c t")
        frm_scrs_2 = F.interpolate(frm_scrs_2, size=T, mode="linear", align_corners=True)
        frm_scrs_2 = rearrange(frm_scrs_2, "n c t -> n t c")

        return mil_vid_pred_1, frm_scrs_1, mil_vid_pred_2, frm_scrs_2
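

# ---------------------------------------------------------------------------
# Minimal shape check: a sketch, not part of the original training code.
# Model itself needs a skeleton-graph class reachable by dotted path, so this
# exercises the building blocks directly; the 25-joint, 3-subset identity
# adjacency below is a stand-in with the NTU RGB+D-style shape the defaults
# assume.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    A = np.stack([np.eye(25, dtype=np.float32)] * 3)  # (num_subset, V, V)

    block = TCN_GCN_unit(2, 64, A)
    x = torch.randn(4, 2, 64, 25)  # (N * M, C, T, V)
    print(block(x).shape)  # torch.Size([4, 64, 64, 25])

    head = Classifier(num_class=60)
    feats = torch.randn(4, 256, 8, 25)  # backbone output: (N * M, 256, T/8, V)
    vid_pred, frm_scrs = head(feats)
    print(vid_pred.shape, frm_scrs.shape)  # torch.Size([4, 61]) torch.Size([4, 8, 61])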