ianpan's picture
Initial commit
231edce
raw
history blame
10.9 kB
import math
import numpy as np
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchvideo.models.x3d import create_x3d_stem
from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2
from . import backbones
from . import segmentation
from .pooling import create_pool2d_layer, create_pool3d_layer
from .sequence import Transformer, DualTransformer, DualTransformerV2
from .tools import change_initial_stride, change_num_input_channels
class Net2D(nn.Module):
    """2D classifier: backbone -> pooling -> optional feature reduction -> linear head.

    Args:
        backbone: Name forwarded to ``backbones.create_backbone``.
        pretrained: Forwarded to ``backbones.create_backbone``.
        num_classes: Number of outputs of the final linear layer.
        dropout: Dropout probability applied to features before the classifier.
        pool: If a string, the name forwarded to ``create_pool2d_layer``;
            any other value results in an identity pooling layer. The value
            ``"catavgmax"`` doubles the feature dimension seen by the head.
        in_channels: Number of input channels; when not 3, the backbone is
            passed through ``change_num_input_channels``.
        change_stride: If truthy, converted to a tuple and passed to
            ``change_initial_stride`` together with ``in_channels``.
        feature_reduction: If an int, a bias-free grouped 1x1 ``Conv1d`` maps
            the features to this many channels before the classifier.
        multisample_dropout: If true, the head is evaluated 5 times (each
            with its own dropout call) and the outputs are averaged.
        load_pretrained_backbone: Optional path to a checkpoint containing a
            ``'state_dict'`` entry from which backbone (and, if present,
            feature-reduction) weights are loaded.
        freeze_backbone: If true, ``requires_grad`` is disabled on all
            backbone parameters.
        backbone_params: Extra keyword arguments for ``create_backbone``.
            ``None`` (default) means no extra arguments.
        pool_layer_params: Extra keyword arguments for ``create_pool2d_layer``.
            ``None`` (default) means no extra arguments.
    """

    def __init__(self,
                 backbone,
                 pretrained,
                 num_classes,
                 dropout,
                 pool,
                 in_channels=3,
                 change_stride=None,
                 feature_reduction=None,
                 multisample_dropout=False,
                 load_pretrained_backbone=None,
                 freeze_backbone=False,
                 backbone_params=None,
                 pool_layer_params=None):
        super().__init__()
        # Avoid shared mutable default arguments.
        backbone_params = {} if backbone_params is None else backbone_params
        pool_layer_params = {} if pool_layer_params is None else pool_layer_params

        self.backbone, dim_feats = backbones.create_backbone(name=backbone, pretrained=pretrained, **backbone_params)
        if isinstance(pool, str):
            self.pool_layer = create_pool2d_layer(name=pool, **pool_layer_params)
        else:
            self.pool_layer = nn.Identity()
        if pool == "catavgmax":
            dim_feats *= 2
        self.msdo = multisample_dropout
        if in_channels != 3:
            self.backbone = change_num_input_channels(self.backbone, in_channels)
        if change_stride:
            self.backbone = change_initial_stride(self.backbone, tuple(change_stride), in_channels)
        self.dropout = nn.Dropout(p=dropout)
        if isinstance(feature_reduction, int):
            # Use 1D grouped convolution to reduce # of parameters
            groups = math.gcd(dim_feats, feature_reduction)
            self.feature_reduction = nn.Conv1d(dim_feats, feature_reduction, groups=groups, kernel_size=1,
                                               stride=1, bias=False)
            dim_feats = feature_reduction
        self.classifier = nn.Linear(dim_feats, num_classes)

        if load_pretrained_backbone:
            # Assumes that model has a `backbone` attribute
            # Note: if you want to load the entire pretrained model, this is done via the
            # builder.build_model function
            print(f"Loading pretrained backbone from {load_pretrained_backbone} ...")
            weights = torch.load(load_pretrained_backbone, map_location=lambda storage, loc: storage)['state_dict']
            weights, feat_reduce_weight = self._split_checkpoint_weights(weights)
            self.backbone.load_state_dict(weights)
            if len(feat_reduce_weight) > 0:
                print("Also loading feature reduction layer ...")
                self.feature_reduction.load_state_dict(feat_reduce_weight)

        if freeze_backbone:
            print("Freezing backbone ...")
            for param in self.backbone.parameters():
                param.requires_grad = False

    @staticmethod
    def _split_checkpoint_weights(state_dict):
        """Split a full-model state dict into backbone and feature-reduction parts.

        A leading ``model.`` prefix is removed from every key. Keys containing
        ``feature_reduction`` / ``backbone`` are then selected and their
        leading ``feature_reduction.`` / ``backbone.`` prefix is removed. The
        dots are matched literally.

        Returns:
            Tuple ``(backbone_weights, feature_reduction_weights)``.
        """
        stripped = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
        feat_reduce_weight = {re.sub(r"^feature_reduction\.", "", k): v
                              for k, v in stripped.items() if "feature_reduction" in k}
        backbone_weight = {re.sub(r"^backbone\.", "", k): v
                           for k, v in stripped.items() if "backbone" in k}
        return backbone_weight, feat_reduce_weight

    def extract_features(self, x):
        """Run backbone and pooling (plus feature reduction, if configured) on ``x``."""
        features = self.backbone(x)
        features = self.pool_layer(features)
        if isinstance(self.backbone, VisionTransformer):
            # Drop the prefix tokens, then average the remaining tokens.
            features = features[:, self.backbone.num_prefix_tokens:].mean(dim=1)
        if isinstance(self.backbone, SwinTransformerV2):
            features = features.mean(dim=1)
        if hasattr(self, "feature_reduction"):
            # Conv1d needs a trailing length axis; add it and remove it again.
            features = self.feature_reduction(features.unsqueeze(-1)).squeeze(-1)
        return features

    def forward(self, x):
        """Return classifier outputs for ``x``."""
        features = self.extract_features(x)
        if self.msdo:
            x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
        else:
            x = self.classifier(self.dropout(features))
        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        return x[:, 0] if self.classifier.out_features == 1 else x
class SeqNet2D(Net2D):
    """Net2D applied independently to each slice along dim 2, max-reduced over slices."""

    def forward(self, x):
        """Return classifier outputs for a stack of slices.

        ``x`` is indexed as (N, C, Z, H, W); each ``x[:, :, z]`` is passed
        through ``extract_features`` on its own.
        """
        # x.shape = (N, C, Z, H, W)
        num_slices = x.size(2)
        per_slice = [self.extract_features(x[:, :, z]) for z in range(num_slices)]
        # Stack the per-slice features on a new axis 2 and keep the max over it.
        features = torch.stack(per_slice, dim=2).max(2)[0]

        if self.msdo:
            # Multi-sample dropout: average 5 passes through dropout + classifier.
            samples = [self.classifier(self.dropout(features)) for _ in range(5)]
            logits = torch.stack(samples).mean(dim=0)
        else:
            logits = self.classifier(self.dropout(features))

        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        if self.classifier.out_features == 1:
            return logits[:, 0]
        return logits
class TDCNN(nn.Module):
    """Per-slice ``Net2D`` feature extractor followed by a ``Transformer`` over slices.

    Args:
        cnn_params: Keyword arguments used to build the ``Net2D`` feature
            extractor. Its ``dropout`` and ``classifier`` are deleted after
            construction, so only ``extract_features`` is used.
        transformer_params: Keyword arguments used to build the ``Transformer``.
        freeze_cnn: If true, disable ``requires_grad`` on all CNN parameters.
        freeze_transformer: If true, disable ``requires_grad`` on all
            transformer parameters.
    """

    def __init__(self, cnn_params, transformer_params, freeze_cnn=False, freeze_transformer=False):
        super().__init__()
        self.cnn = Net2D(**cnn_params)
        # Only the feature extractor of the CNN is used here.
        del self.cnn.dropout
        del self.cnn.classifier
        self.transformer = Transformer(**transformer_params)
        if freeze_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        if freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

    @staticmethod
    def _all_ones_mask(features):
        """Return an all-ones tensor shaped like the first two dims of ``features``.

        The tensor is allocated directly on ``features.device`` rather than
        being created on the CPU and copied over on every call.
        """
        return torch.ones(features.shape[:2], device=features.device)

    def extract_features(self, x):
        """Return ``self.transformer.extract_features`` for a single 5D input (batch size 1)."""
        N, C, Z, H, W = x.size()
        assert N == 1, "For feature extraction, batch size must be 1"
        # Drop the batch axis and swap the first two axes so that the slice
        # axis becomes the CNN's batch axis, then restore a batch axis of 1.
        features = self.cnn.extract_features(x.squeeze(0).transpose(0, 1)).unsqueeze(0)
        # features.shape = (1, Z, dim_feats)
        return self.transformer.extract_features((features, self._all_ones_mask(features)))

    def forward(self, x):
        """Return the transformer output for per-slice CNN features of ``x``."""
        # BCZHW
        features = torch.stack([self.cnn.extract_features(x[:, :, i]) for i in range(x.size(2))], dim=1)
        # B, seq_len, dim_feat
        return self.transformer((features, self._all_ones_mask(features)))
class Net2DWith3DStem(Net2D):
    """``Net2D`` preceded by an X3D stem created with ``create_x3d_stem``.

    Extra keyword arguments (all removed before calling ``Net2D.__init__``):
        stem_out_channels: Output channels of the stem (default 24). This is
            also passed to ``Net2D`` as its ``in_channels``.
        load_pretrained_stem: Optional path to a checkpoint containing a
            ``'state_dict'`` entry from which stem weights are loaded.
        conv_kernel_size: Stem convolution kernel size (default ``(5, 3, 3)``).
        conv_stride: Stem convolution stride (default ``(1, 2, 2)``).
        in_channels: Input channels of the stem (default 3).
    """

    def __init__(self, *args, **kwargs):
        stem_out_channels = kwargs.pop("stem_out_channels", 24)
        load_pretrained_stem = kwargs.pop("load_pretrained_stem", None)
        conv_kernel_size = tuple(kwargs.pop("conv_kernel_size", (5, 3, 3)))
        conv_stride = tuple(kwargs.pop("conv_stride", (1, 2, 2)))
        in_channels = kwargs.pop("in_channels", 3)
        # The 2D backbone consumes the stem's output, not the raw input.
        kwargs["in_channels"] = stem_out_channels
        super().__init__(*args, **kwargs)
        self.stem_layer = create_x3d_stem(in_channels=in_channels,
                                          out_channels=stem_out_channels,
                                          conv_kernel_size=conv_kernel_size,
                                          conv_stride=conv_stride)
        # `pretrained` is the second positional parameter of Net2D.__init__,
        # so it may arrive through either kwargs or args.
        pretrained = kwargs["pretrained"] if "pretrained" in kwargs else args[1]
        if pretrained:
            from pytorchvideo.models.hub import x3d_l
            self.stem_layer.load_state_dict(x3d_l(pretrained=True).blocks[0].state_dict())
        if load_pretrained_stem:
            print(f" Loading pretrained stem from {load_pretrained_stem} ...")
            weights = torch.load(load_pretrained_stem, map_location=lambda storage, loc: storage)['state_dict']
            stem_weights = {k.replace("model.backbone.blocks.0.", ""): v
                            for k, v in weights.items() if "backbone.blocks.0" in k}
            self.stem_layer.load_state_dict(stem_weights)

    def forward(self, x):
        """Apply the stem, average over dim 3, then run the 2D feature extractor and head."""
        x = self.stem_layer(x)
        # NOTE(review): this averages the stem output over dimension 3 to get
        # a 4D tensor for the 2D backbone - confirm this is the intended axis.
        x = x.mean(3)
        features = self.extract_features(x)
        if self.msdo:
            x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0)
        else:
            x = self.classifier(self.dropout(features))
        # Important nuance:
        # For binary classification, the model returns a tensor of shape (N,)
        # Otherwise, (N,C)
        return x[:, 0] if self.classifier.out_features == 1 else x
class Net3D(Net2D):
    """``Net2D`` whose pooling layer is replaced by one from ``create_pool3d_layer``."""

    def __init__(self, *args, **kwargs):
        # `z_strides` is removed so Net2D.__init__ never receives it; the
        # value itself is not used anywhere in this class.
        kwargs.pop("z_strides", [1, 1, 1, 1, 1])
        super().__init__(*args, **kwargs)
        # Swap the pooling layer built by the parent for its 3D counterpart,
        # reusing the same `pool` name and `pool_layer_params`.
        pool_name = kwargs["pool"]
        pool_params = kwargs.pop("pool_layer_params", {})
        self.pool_layer = create_pool3d_layer(name=pool_name, **pool_params)
class NetSegment2D(nn.Module):
    """ For now, this class essentially serves as a wrapper for the
    segmentation model which is mostly defined in the segmentation submodule,
    similar to the original segmentation_models.pytorch.
    It may be worth refactoring it in the future, such that you define this as
    a general class, then select your choice of encoder and decoder. The encoder
    is pretty much the same across all the segmentation models currently
    implemented (DeepLabV3+, FPN, Unet).

    Args:
        architecture: Name of the attribute looked up on the ``segmentation``
            module and called to build the model.
        encoder_name: Forwarded to the architecture constructor.
        encoder_params: Forwarded to the architecture constructor.
        decoder_params: Dict expanded into extra keyword arguments for the
            architecture constructor.
        num_classes: Forwarded to the architecture constructor as ``classes``.
        dropout: Forwarded to the architecture constructor.
        in_channels: Forwarded to the architecture constructor.
        load_pretrained_encoder: Optional path to a checkpoint containing a
            ``'state_dict'`` entry from which encoder weights are loaded.
        freeze_encoder: If true, disable ``requires_grad`` on all encoder
            parameters.
        deep_supervision: Forwarded to the architecture constructor.
        pool_layer_params: Accepted for interface compatibility; not used here.
        aux_head_params: Accepted for interface compatibility; not used here.
    """

    def __init__(self,
                 architecture,
                 encoder_name,
                 encoder_params,
                 decoder_params,
                 num_classes,
                 dropout,
                 in_channels,
                 load_pretrained_encoder=None,
                 freeze_encoder=False,
                 deep_supervision=False,
                 pool_layer_params=None,
                 aux_head_params=None):
        super().__init__()
        self.segmentation_model = getattr(segmentation, architecture)(
            encoder_name=encoder_name,
            encoder_params=encoder_params,
            dropout=dropout,
            classes=num_classes,
            deep_supervision=deep_supervision,
            in_channels=in_channels,
            **decoder_params
        )

        if load_pretrained_encoder:
            # Assumes that model has a `encoder` attribute
            # Note: if you want to load the entire pretrained model, this is done via the
            # builder.build_model function
            print(f"Loading pretrained encoder from {load_pretrained_encoder} ...")
            weights = torch.load(load_pretrained_encoder, map_location=lambda storage, loc: storage)['state_dict']
            # Get encoder only
            weights = self._extract_encoder_weights(weights)
            self.segmentation_model.encoder.load_state_dict(weights)

        if freeze_encoder:
            print("Freezing encoder ...")
            for param in self.segmentation_model.encoder.parameters():
                param.requires_grad = False

    @staticmethod
    def _extract_encoder_weights(state_dict):
        """Return only the encoder entries of ``state_dict`` with prefixes removed.

        A leading ``model.segmentation_model.`` (or ``segmentation_model.``)
        prefix is stripped from each key, including its trailing dot. Keys
        that then start with ``encoder.`` are kept with that prefix removed,
        so they line up with the keys of the encoder's own state dict.
        """
        prefix = "encoder."
        encoder_weights = {}
        for key, value in state_dict.items():
            key = re.sub(r"^(?:model\.)?segmentation_model\.", "", key)
            if key.startswith(prefix):
                encoder_weights[key[len(prefix):]] = value
        return encoder_weights

    def forward(self, x):
        """Return the output of the wrapped segmentation model for ``x``."""
        return self.segmentation_model(x)
class NetSegment3D(NetSegment2D):
    """Alias of ``NetSegment2D``; adds no behavior of its own."""