|
import torch |
|
from torch import nn |
|
|
|
from logger.logger import get_logger |
|
from models import c3d |
|
|
|
# Module-level logger, named after this module per project convention.
logger = get_logger(__name__)
|
|
|
|
|
def _construct_depth_model(base_model): |
|
|
|
modules = list(base_model.modules()) |
|
|
|
first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv3d), |
|
list(range(len(modules)))))[0] |
|
conv_layer = modules[first_conv_idx] |
|
container = modules[first_conv_idx - 1] |
|
|
|
|
|
motion_length = 1 |
|
params = [x.clone() for x in conv_layer.parameters()] |
|
kernel_size = params[0].size() |
|
new_kernel_size = kernel_size[:1] + (1 * motion_length,) + kernel_size[2:] |
|
new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() |
|
|
|
new_conv = nn.Conv3d(1, conv_layer.out_channels, conv_layer.kernel_size, conv_layer.stride, |
|
conv_layer.padding, bias=True if len(params) == 2 else False) |
|
new_conv.weight.data = new_kernels |
|
if len(params) == 2: |
|
new_conv.bias.data = params[1].data |
|
layer_name = list(container.state_dict().keys())[0][:-7] |
|
|
|
|
|
setattr(container, layer_name, new_conv) |
|
|
|
return base_model |
|
|
|
|
|
def _construct_rgbdepth_model(base_model): |
|
|
|
modules = list(base_model.modules()) |
|
|
|
first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv3d), |
|
list(range(len(modules)))))[0] |
|
conv_layer = modules[first_conv_idx] |
|
container = modules[first_conv_idx - 1] |
|
|
|
motion_length = 1 |
|
params = [x.clone() for x in conv_layer.parameters()] |
|
kernel_size = params[0].size() |
|
new_kernel_size = kernel_size[:1] + (1 * motion_length,) + kernel_size[2:] |
|
new_kernels = torch.mul(torch.cat((params[0].data, |
|
params[0].data.mean(dim=1, keepdim=True) |
|
.expand(new_kernel_size) |
|
.contiguous()), 1), 0.6) |
|
new_kernel_size = kernel_size[:1] + (3 + 1 * motion_length,) + kernel_size[2:] |
|
new_conv = nn.Conv3d(4, conv_layer.out_channels, conv_layer.kernel_size, conv_layer.stride, |
|
conv_layer.padding, bias=True if len(params) == 2 else False) |
|
new_conv.weight.data = new_kernels |
|
if len(params) == 2: |
|
new_conv.bias.data = params[1].data |
|
layer_name = list(container.state_dict().keys())[0][:-7] |
|
|
|
|
|
setattr(container, layer_name, new_conv) |
|
return base_model |
|
|
|
|
|
def _modify_first_conv_layer(base_model, new_kernel_size1, new_filter_num): |
|
modules = list(base_model.modules()) |
|
first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv3d), |
|
list(range(len(modules)))))[0] |
|
conv_layer = modules[first_conv_idx] |
|
container = modules[first_conv_idx - 1] |
|
|
|
new_conv = nn.Conv3d(new_filter_num, conv_layer.out_channels, kernel_size=(new_kernel_size1, 7, 7), |
|
stride=(1, 2, 2), padding=(1, 3, 3), bias=False) |
|
layer_name = list(container.state_dict().keys())[0][:-7] |
|
|
|
setattr(container, layer_name, new_conv) |
|
return base_model |
|
|
|
|
|
def modify_kernels(model, modality):
    """Adapt the first conv layer of *model* to the given input modality.

    Args:
        model: the network to adapt (modified in place by the helpers).
        modality: one of 'RGB', 'Depth', 'RGB-D'; any other value leaves the
            model unchanged.

    Returns:
        The (possibly modified) model.
    """
    # NOTE(review): `model` is an nn.Module here, so `model not in ['c3d']`
    # is always True (a module never equals the string 'c3d'). This was
    # probably meant to test a model-*name* argument — confirm with callers.
    # The comparison is kept as-is to preserve current behavior.
    if modality == 'RGB' and model not in ['c3d']:
        logger.info(" RGB model is used for init model")
        model = _modify_first_conv_layer(model, 3, 3)
    elif modality == 'Depth':
        logger.info(" Converting the pretrained model to Depth init model")
        model = _construct_depth_model(model)
        logger.info(" Done. Flow model ready.")
    elif modality == 'RGB-D':
        logger.info(" Converting the pretrained model to RGB+D init model")
        model = _construct_rgbdepth_model(model)
        logger.info(" Done. RGB-D model ready.")
    return model
|
|
|
|
|
def generate_model(n_classes, sample_size, ft_portion, no_cuda=False, modality="RGB-D", sample_duration=8):
    """Build a C3D model and prepare it for training.

    Constructs the C3D network, moves it to GPU with ``DataParallel`` unless
    *no_cuda* is set, adapts its first conv layer to *modality*, and collects
    the parameters to fine-tune.

    Args:
        n_classes: number of output classes.
        sample_size: spatial size of the input clips.
        ft_portion: which portion of the network to fine-tune (forwarded to
            ``get_fine_tuning_parameters``).
        no_cuda: if True, keep the model on CPU (no ``DataParallel`` wrap).
        modality: input modality ('RGB', 'Depth' or 'RGB-D').
        sample_duration: temporal length of the input clips.

    Returns:
        A ``(model, parameters)`` tuple.
    """
    logger.info(f"Torch version: {torch.__version__}")
    logger.info(f"Is CUDA enabled? {torch.cuda.is_available()}")
    from models.c3d import get_fine_tuning_parameters

    model = c3d.get_model(
        num_classes=n_classes,
        sample_size=sample_size,
        sample_duration=sample_duration)

    if not no_cuda:
        # Move to GPU and wrap for multi-GPU; device_ids=None uses all devices.
        model = nn.DataParallel(model.cuda(), device_ids=None)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total number of trainable parameters: {trainable}")

    model = modify_kernels(model, modality)
    parameters = get_fine_tuning_parameters(model, ft_portion)
    return model, parameters
|
|