mshukor
init
26fd00c
raw
history blame
No virus
6.6 kB
# https://github.com/kenshohara/video-classification-3d-cnn-pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
# Names exported by ``from <module> import *``. Every entry must be defined in
# this file, otherwise the star-import raises AttributeError.
__all__ = ['ResNeXt3D', 'resnet50', 'resnet101', 'resnet152']
def conv3x3x3(in_planes, out_planes, stride=1):
    """Create a bias-free 3x3x3 ``nn.Conv3d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride passed to the convolution (default 1).

    Returns:
        The constructed ``nn.Conv3d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)
def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut: subsample ``x`` and zero-pad its channels.

    ``x`` is subsampled with a kernel-size-1 average pool using ``stride``,
    then extra all-zero channels are concatenated along dim 1 so that the
    result has ``planes`` channels.

    The padding is allocated with the same dtype and device as the pooled
    tensor, so the concatenation works for any floating dtype and on any
    device. The pooled tensor is concatenated directly (not via ``.data``),
    so gradients flow back to ``x`` through this shortcut.

    Args:
        x: 5-D input tensor (indexed up to ``size(4)``).
        planes: number of channels of the returned tensor; expected to be
            at least the channel count of ``x``.
        stride: stride of the subsampling pool.

    Returns:
        Tensor whose first ``x.size(1)`` channels are the subsampled input
        and whose remaining channels are zeros.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    pad_channels = planes - out.size(1)
    # new_zeros inherits dtype and device from ``out``.
    zero_pads = out.new_zeros(
        (out.size(0), pad_channels, out.size(2), out.size(3), out.size(4))
    )
    return torch.cat([out, zero_pads], dim=1)
class ResNeXtBottleneck(nn.Module):
    """Grouped-convolution bottleneck block built from three 3D convolutions.

    The block applies a 1x1x1 convolution, a grouped 3x3x3 convolution
    (``groups=cardinality``, carrying the block's stride) and a final 1x1x1
    convolution, each followed by a norm layer. The shortcut (``x`` itself, or
    ``downsample(x)`` when a downsample callable is given) is added before the
    last ReLU.

    Args:
        inplanes: channels of the input tensor.
        planes: base width; the block outputs ``planes * expansion`` channels.
        cardinality: number of groups of the 3x3x3 convolution.
        stride: stride of the 3x3x3 convolution.
        downsample: optional callable applied to the input to produce the
            shortcut; when ``None`` the input is used unchanged.
        norm_layer: callable that builds a norm layer from a channel count.
    """

    # Output channels are ``planes * expansion``.
    expansion = 2

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None, norm_layer=nn.BatchNorm3d):
        super().__init__()
        width = cardinality * int(planes / 32)
        out_planes = planes * self.expansion

        # Registration order is kept as conv1/bn1/conv2/bn2/conv3/bn3 so that
        # parameter names and ordering in the state dict do not change.
        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv3d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=cardinality,
            bias=False,
        )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv3d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/norm stages and add the shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class ResNeXt3D(nn.Module):
    """3D ResNeXt backbone that returns the last stage's feature map.

    The network is a 7x7x7 stem convolution, norm, ReLU and max-pool followed
    by three or four residual stages. ``forward`` returns the feature tensor
    twice as a tuple ``(x, x)``; the average-pool / classifier head is not
    applied.

    Args:
        block: block class (or callable with an ``expansion`` attribute),
            called as ``block(inplanes, planes, cardinality, stride,
            downsample, norm_layer=...)``.
        layers: number of blocks per stage. The first three entries build
            ``layer1``..``layer3``; if more than three entries are given the
            fourth builds ``layer4``.
        sample_size: only used to size ``self.avgpool``.
        sample_duration: only used to size ``self.avgpool``.
        shortcut_type: ``'A'`` uses the parameter-free
            ``downsample_basic_block`` shortcut; any other value uses a
            1x1x1 convolution followed by a norm layer.
        cardinality: forwarded to every block.
        num_classes: accepted for interface compatibility; unused because no
            classifier layer is created.
        last_fc: stored as ``self.last_fc``; not read by ``forward``.
        norm_layer: callable building a norm layer from a channel count
            (a class or a factory such as ``functools.partial``). Defaults to
            ``nn.BatchNorm3d``.

    NOTE(review): the defaults ``sample_size=16`` / ``sample_duration=112``
    look swapped - confirm. They only influence ``self.avgpool``, which
    ``forward`` does not call.
    """

    def __init__(self, block, layers, sample_size=16, sample_duration=112, shortcut_type='B', cardinality=32, num_classes=400, last_fc=True, norm_layer=None):
        super(ResNeXt3D, self).__init__()
        self.last_fc = last_fc
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64

        self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2),
                               padding=(3, 3, 3), bias=False)
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        print("use bn:", norm_layer)
        self.bn1 = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type, cardinality, norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 256, layers[1], shortcut_type, cardinality, stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 512, layers[2], shortcut_type, cardinality, stride=2, norm_layer=norm_layer)
        # The fourth stage is optional; forward() checks this flag.
        if len(layers) > 3:
            self.layer4 = self._make_layer(block, 1024, layers[3], shortcut_type, cardinality, stride=2, norm_layer=norm_layer)
            self.all_layers = True
        else:
            self.all_layers = False

        last_duration = math.ceil(sample_duration / 16)
        last_size = math.ceil(sample_size / 32)
        # Constructed for completeness; forward() does not apply it.
        self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1)

        # ``norm_layer`` may be a factory (e.g. functools.partial) instead of
        # a class, which isinstance() rejects with TypeError. In that case use
        # the concrete type the factory produced for ``bn1``.
        norm_type = norm_layer if isinstance(norm_layer, type) else type(self.bn1)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # NOTE(review): fan-in uses only two of the three kernel
                # dimensions; kept as-is so initialisation is unchanged.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_type):
                # Non-affine norm layers expose weight/bias as None.
                if getattr(m, 'weight', None) is not None:
                    m.weight.data.fill_(1)
                if getattr(m, 'bias', None) is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, cardinality, stride=1, norm_layer=nn.BatchNorm3d):
        """Build one stage of ``blocks`` blocks and update ``self.inplanes``.

        Only the first block receives ``stride`` and the downsample shortcut;
        a shortcut is created when the stride is not 1 or the channel count
        changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    norm_layer(planes * block.expansion)
                )

        layers = []
        layers.append(block(self.inplanes, planes, cardinality, stride, downsample, norm_layer=norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, cardinality, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(features, features)`` from the last residual stage."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.all_layers:
            x = self.layer4(x)

        # The pooling / flatten / fc head is disabled:
        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        # if self.last_fc:
        #     x = self.fc(x)
        return x, x
def get_fine_tuning_parameters(model, ft_begin_index):
    """Build optimizer parameter groups that fine-tune only the later layers.

    Args:
        model: object exposing ``parameters()`` and ``named_parameters()``.
        ft_begin_index: index of the first ``layerN`` to fine-tune. ``0``
            means "train everything".

    Returns:
        ``model.parameters()`` when ``ft_begin_index`` is 0. Otherwise a list
        with one dict per parameter: ``{'params': p}`` for parameters whose
        name contains ``layer<i>`` for any ``i`` in
        ``ft_begin_index..4`` or contains ``fc``, and
        ``{'params': p, 'lr': 0.0}`` for every other parameter.
    """
    if ft_begin_index == 0:
        return model.parameters()

    # Use the loop index so that every layer from ft_begin_index up to
    # layer4 is selected, not just layer<ft_begin_index>.
    ft_module_names = ['layer{}'.format(i) for i in range(ft_begin_index, 5)]
    ft_module_names.append('fc')

    parameters = []
    for name, param in model.named_parameters():
        if any(ft_module in name for ft_module in ft_module_names):
            parameters.append({'params': param})
        else:
            # Frozen in effect: a zero learning rate for this parameter.
            parameters.append({'params': param, 'lr': 0.0})
    return parameters
def resnet50(**kwargs):
    """Build a ``ResNeXt3D`` with stage depths [3, 4, 6, 3] (50-layer layout).

    All keyword arguments are forwarded to the ``ResNeXt3D`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNeXt3D(ResNeXtBottleneck, stage_depths, **kwargs)
def resnet101(**kwargs):
    """Build a ``ResNeXt3D`` with stage depths [3, 4, 23, 3] (101-layer layout).

    All keyword arguments are forwarded to the ``ResNeXt3D`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNeXt3D(ResNeXtBottleneck, stage_depths, **kwargs)
def resnet152(**kwargs):
    """Build a ``ResNeXt3D`` with stage depths [3, 8, 36, 3] (152-layer layout).

    All keyword arguments are forwarded to the ``ResNeXt3D`` constructor.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNeXt3D(ResNeXtBottleneck, stage_depths, **kwargs)