# Hugging Face Hub upload by mositemp.
# Commit bc6b141 (verified): "Add model.py for mobilenet_v2-kather100k".
"""
Auto-generated model definition for mobilenet_v2-kather100k.
This file is self-contained and includes the definitions for _get_architecture,
CNNBackbone, and CNNModel.
"""
import torch
import torch.nn as nn
import torchvision.models as torch_models
import timm
def _get_architecture(arch_name, weights="DEFAULT", **kwargs):
    """Create a torchvision network by name and return only its convolutional trunk.

    Args:
        arch_name: Name of a supported torchvision constructor (e.g. ``"mobilenet_v2"``).
        weights: Forwarded unchanged to the torchvision constructor. ``"DEFAULT"``
            presumably selects torchvision's default pretrained weights, which may
            require a download (confirm against the installed torchvision version).
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        nn.Module: For names containing ``resnet``/``resnext``, an ``nn.Sequential`` of
        every child module except the last two. For ``inception_v3``/``googlenet``,
        every child except the last three. For every other supported name, the
        model's ``.features`` submodule.

    Raises:
        ValueError: If ``arch_name`` is not a supported name.
    """
    supported = (
        "alexnet",
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnext50_32x4d",
        "resnext101_32x8d",
        "wide_resnet50_2",
        "wide_resnet101_2",
        "densenet121",
        "densenet161",
        "densenet169",
        "densenet201",
        "inception_v3",
        "googlenet",
        "mobilenet_v2",
        "mobilenet_v3_large",
        "mobilenet_v3_small",
    )
    # Each supported name is also the name of the torchvision factory function.
    factories = {name: getattr(torch_models, name) for name in supported}

    factory = factories.get(arch_name)
    if factory is None:
        raise ValueError(f"Backbone {arch_name} is not supported.")
    network = factory(weights=weights, **kwargs)

    if any(tag in arch_name for tag in ("resnet", "resnext")):
        # Trim the trailing children (expected to be the global pool + fully-connected head).
        return nn.Sequential(*[*network.children()][:-2])
    if any(tag in arch_name for tag in ("inception_v3", "googlenet")):
        return nn.Sequential(*[*network.children()][:-3])
    # All remaining names (densenet*, alexnet, mobilenet_*) expose their trunk as `.features`.
    return network.features
class CNNBackbone(nn.Module):
    """Feature extractor: a torchvision CNN trunk followed by global average pooling.

    Args:
        backbone: Architecture name accepted by ``_get_architecture``.
    """

    def __init__(self, backbone: str):
        super().__init__()
        # Same attribute names as CNNModel, so this module's state_dict keys
        # (``feat_extract.*``) match the corresponding keys of a CNNModel checkpoint.
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return pooled trunk features flattened from dimension 1 onward.

        Assuming the trunk emits ``(N, C, H, W)`` maps, the result is ``(N, C)``.
        """
        pooled = self.pool(self.feat_extract(imgs))
        return pooled.flatten(start_dim=1)
class CNNModel(nn.Module):
    """Torchvision CNN trunk + global average pooling + linear head.

    ``forward`` returns softmax probabilities (not raw logits).

    Args:
        backbone: Architecture name accepted by ``_get_architecture``.
        num_classes: Number of outputs of the linear classification head.
    """

    def __init__(self, backbone: str, num_classes: int = 1):
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Channel width of the trunk output, measured with a side-effect-free dummy pass.
        self.num_features = self._infer_num_features(self.feat_extract)
        self.classifier = nn.Linear(self.num_features, num_classes)

    @staticmethod
    def _infer_num_features(module: nn.Module, input_size=(2, 3, 96, 96)) -> int:
        """Return ``module(x).shape[1]`` for a random dummy batch ``x`` of ``input_size``.

        The probe runs in eval mode under ``torch.no_grad()``. In training mode, BatchNorm
        layers would fold the random input into their running statistics, overwriting
        the (possibly pretrained) values. Every submodule's original train/eval flag
        is restored afterwards, even if the forward pass raises.
        """
        modes = [(sub, sub.training) for sub in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                # torch.rand (not zeros) keeps the global RNG draw count unchanged, so the
                # classifier initialisation that follows is the same for a given seed.
                out = module(torch.rand(input_size))
        finally:
            for sub, was_training in modes:
                sub.training = was_training
        return out.shape[1]

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return class probabilities: softmax over the last dim of the head output."""
        feat = self.feat_extract(imgs)
        gap_feat = torch.flatten(self.pool(feat), 1)
        logit = self.classifier(gap_feat)
        return torch.softmax(logit, -1)
@timm.models.register_model
def mobilenet_v2_kather100k(
    pretrained=False, features_only=False, weights_path="pytorch_model.bin", **kwargs
):
    """Build the MobileNetV2 Kather100k model (fixed 9-class head) and register it with timm.

    The previous name ``mobilenet_v2-kather100k`` is not a valid Python identifier
    (the ``def`` could not be parsed), so the entry point is ``mobilenet_v2_kather100k``.

    Args:
        pretrained: If True, load the state dict at ``weights_path`` (on CPU) and switch
            the model to eval mode.
        features_only: If True, return a ``CNNBackbone`` (pooled features) instead of a
            ``CNNModel`` (class probabilities).
        weights_path: Checkpoint file to load when ``pretrained`` is True. It is resolved
            relative to the current working directory.
        **kwargs: Accepted for compatibility with callers such as ``timm.create_model``;
            otherwise ignored.

    Returns:
        nn.Module: The constructed ``CNNBackbone`` or ``CNNModel``.
    """
    backbone = "mobilenet_v2"
    num_classes = 9
    if features_only:
        # CNNBackbone has no classifier, so its constructor takes no num_classes.
        model = CNNBackbone(backbone=backbone)
    else:
        model = CNNModel(backbone=backbone, num_classes=num_classes)
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a
        # trusted source (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(weights_path, map_location="cpu")
        if features_only:
            # Drop the head weights so strict loading still checks every backbone tensor.
            state_dict = {
                key: value
                for key, value in state_dict.items()
                if not key.startswith("classifier.")
            }
        model.load_state_dict(state_dict)
        model.eval()
    return model