|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from functools import partial |
|
|
|
|
|
class C3D(nn.Module): |
|
def __init__(self, |
|
sample_size, |
|
sample_duration, |
|
num_classes=600): |
|
super(C3D, self).__init__() |
|
self.group1 = nn.Sequential( |
|
nn.Conv3d(3, 64, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(64), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))) |
|
self.group2 = nn.Sequential( |
|
nn.Conv3d(64, 128, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(128), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) |
|
self.group3 = nn.Sequential( |
|
nn.Conv3d(128, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(256), |
|
nn.ReLU(), |
|
nn.Conv3d(256, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(256), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) |
|
self.group4 = nn.Sequential( |
|
nn.Conv3d(256, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.Conv3d(512, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) |
|
self.group5 = nn.Sequential( |
|
nn.Conv3d(512, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.Conv3d(512, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))) |
|
|
|
last_duration = int(math.floor(sample_duration / 16)) |
|
last_size = int(math.ceil(sample_size / 32)) |
|
self.fc1 = nn.Sequential( |
|
nn.Linear((512 * last_duration * last_size * last_size), 2048), |
|
nn.ReLU(), |
|
nn.Dropout(0.5)) |
|
self.fc2 = nn.Sequential( |
|
nn.Linear(2048, 2048), |
|
nn.ReLU(), |
|
nn.Dropout(0.5)) |
|
self.fc = nn.Sequential( |
|
nn.Linear(2048, num_classes)) |
|
|
|
def forward(self, x): |
|
out = self.group1(x) |
|
out = self.group2(out) |
|
out = self.group3(out) |
|
out = self.group4(out) |
|
out = self.group5(out) |
|
out = out.view(out.size(0), -1) |
|
out = self.fc1(out) |
|
out = self.fc2(out) |
|
out = self.fc(out) |
|
return out |
|
|
|
|
|
def get_fine_tuning_parameters(model, ft_portion): |
|
if ft_portion == "complete": |
|
return model.parameters() |
|
|
|
elif ft_portion == "last_layer": |
|
ft_module_names = [] |
|
ft_module_names.append('fc') |
|
|
|
parameters = [] |
|
for k, v in model.named_parameters(): |
|
for ft_module in ft_module_names: |
|
if ft_module in k: |
|
parameters.append({'params': v}) |
|
break |
|
else: |
|
parameters.append({'params': v, 'lr': 0.0}) |
|
return parameters |
|
|
|
else: |
|
raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected") |
|
|
|
|
|
def get_model(**kwargs): |
|
""" |
|
Returns the model. |
|
""" |
|
model = C3D(**kwargs) |
|
return model |
|
|
|
|
|
if __name__ == '__main__': |
|
model = get_model(sample_size=112, sample_duration=16, num_classes=600) |
|
model = model.cuda() |
|
model = nn.DataParallel(model, device_ids=None) |
|
print(model) |
|
|
|
input_var = Variable(torch.randn(8, 3, 16, 112, 112)) |
|
output = model(input_var) |
|
print(output.shape) |
|
|