"""Auto-generated model definition for mobilenet_v2-kather100k.

This file is self-contained and includes the definitions for
_get_architecture, CNNBackbone, and CNNModel.
"""

import torch
import torch.nn as nn
import torchvision.models as torch_models
import timm


def _get_architecture(arch_name, weights="DEFAULT", **kwargs):
    """Return the feature-extraction layers of a supported torchvision backbone."""
    backbone_dict = {
        "alexnet": torch_models.alexnet,
        "resnet18": torch_models.resnet18,
        "resnet34": torch_models.resnet34,
        "resnet50": torch_models.resnet50,
        "resnet101": torch_models.resnet101,
        "resnext50_32x4d": torch_models.resnext50_32x4d,
        "resnext101_32x8d": torch_models.resnext101_32x8d,
        "wide_resnet50_2": torch_models.wide_resnet50_2,
        "wide_resnet101_2": torch_models.wide_resnet101_2,
        "densenet121": torch_models.densenet121,
        "densenet161": torch_models.densenet161,
        "densenet169": torch_models.densenet169,
        "densenet201": torch_models.densenet201,
        "inception_v3": torch_models.inception_v3,
        "googlenet": torch_models.googlenet,
        "mobilenet_v2": torch_models.mobilenet_v2,
        "mobilenet_v3_large": torch_models.mobilenet_v3_large,
        "mobilenet_v3_small": torch_models.mobilenet_v3_small,
    }
    if arch_name not in backbone_dict:
        raise ValueError(f"Backbone {arch_name} is not supported.")
    creator = backbone_dict[arch_name]
    model = creator(weights=weights, **kwargs)

    # Strip the classification head so only the convolutional feature
    # extractor is returned; the cut point depends on the architecture family.
    if "resnet" in arch_name or "resnext" in arch_name:
        return nn.Sequential(*list(model.children())[:-2])
    if "densenet" in arch_name or "alexnet" in arch_name:
        return model.features
    if "inception_v3" in arch_name or "googlenet" in arch_name:
        return nn.Sequential(*list(model.children())[:-3])
    return model.features
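
# Illustrative example (not executed at import time): for the "mobilenet_v2"
# backbone used below, _get_architecture returns torchvision's
# ``model.features`` block, so a 96 x 96 RGB batch yields a feature map of
# shape (N, 1280, 3, 3) before global average pooling.
#
#     feat_extract = _get_architecture("mobilenet_v2")
#     feat_extract(torch.rand(2, 3, 96, 96)).shape  # torch.Size([2, 1280, 3, 3])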


class CNNBackbone(nn.Module):
    """Feature extractor: truncated backbone followed by global average pooling."""

    def __init__(self, backbone: str):
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        return torch.flatten(gap_feat, 1)
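
# A minimal usage sketch (assuming torchvision backbone weights can be
# resolved locally or downloaded): CNNBackbone("mobilenet_v2") maps a batch
# of images to flat 1280-dimensional feature vectors. This is what the
# registered entry point below returns when features_only=True.
#
#     extractor = CNNBackbone("mobilenet_v2")
#     embeddings = extractor(torch.rand(4, 3, 96, 96))  # shape (4, 1280)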


class CNNModel(nn.Module):
    """Classifier: truncated backbone, global average pooling, and a linear head."""

    def __init__(self, backbone: str, num_classes: int = 1):
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Infer the backbone's feature dimension by passing a dummy batch through it.
        self.num_features = self.feat_extract(torch.rand([2, 3, 96, 96])).shape[1]
        self.classifier = nn.Linear(self.num_features, num_classes)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        gap_feat = torch.flatten(gap_feat, 1)
        logit = self.classifier(gap_feat)
        # Note: the model returns class probabilities (softmax), not raw logits.
        return torch.softmax(logit, -1)


@timm.models.register_model
def mobilenet_v2_kather100k(pretrained=False, features_only=False, **kwargs):
    """Create the mobilenet_v2-kather100k model (mobilenet_v2 backbone, 9 output classes)."""
    backbone = "mobilenet_v2"
    num_classes = 9
    # CNNBackbone has no classification head, so it takes no num_classes argument.
    if features_only:
        model = CNNBackbone(backbone=backbone)
    else:
        model = CNNModel(backbone=backbone, num_classes=num_classes)
    if pretrained:
        state_dict = torch.load("pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state_dict)
    model.eval()
    return model
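

# A minimal usage sketch, assuming the "pytorch_model.bin" checkpoint sits next
# to this file when pretrained=True; the entry point is registered with timm
# above, so timm.create_model can build the model by name.
if __name__ == "__main__":
    model = timm.create_model("mobilenet_v2_kather100k", pretrained=False)
    with torch.no_grad():
        probs = model(torch.rand(1, 3, 224, 224))
    print(probs.shape)  # expected: torch.Size([1, 9])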
|