|
from torchvision import models |
|
from torch import nn |
|
|
|
# Registry of supported architectures.
# Each entry maps a user-facing name to a ``(builder, config)`` pair:
#   * ``builder`` is the torchvision constructor for the architecture;
#   * ``config["weights"]`` is passed to the builder as ``weights=``;
#   * ``config["family"]`` selects which classification layer ``Model``
#     reads its input width from and replaces.
model_mapping = {
    # DenseNet family: classification layer lives at ``.classifier``.
    "densenet121": (
        models.densenet121,
        dict(weights=models.DenseNet121_Weights.DEFAULT, family="densenet"),
    ),
    "densenet161": (
        models.densenet161,
        dict(weights=models.DenseNet161_Weights.DEFAULT, family="densenet"),
    ),
    "densenet169": (
        models.densenet169,
        dict(weights=models.DenseNet169_Weights.DEFAULT, family="densenet"),
    ),
    "densenet201": (
        models.densenet201,
        dict(weights=models.DenseNet201_Weights.DEFAULT, family="densenet"),
    ),
    # ResNet family: classification layer lives at ``.fc``.
    # These pin the IMAGENET1K_V2 weight set explicitly rather than DEFAULT.
    "resnet50": (
        models.resnet50,
        dict(weights=models.ResNet50_Weights.IMAGENET1K_V2, family="resnet"),
    ),
    "resnet101": (
        models.resnet101,
        dict(weights=models.ResNet101_Weights.IMAGENET1K_V2, family="resnet"),
    ),
    "resnet152": (
        models.resnet152,
        dict(weights=models.ResNet152_Weights.IMAGENET1K_V2, family="resnet"),
    ),
    # Vision Transformer family: classification head lives under ``.heads``.
    "vit-b-16": (
        models.vit_b_16,
        dict(weights=models.ViT_B_16_Weights.DEFAULT, family="vit"),
    ),
    "vit-b-32": (
        models.vit_b_32,
        dict(weights=models.ViT_B_32_Weights.DEFAULT, family="vit"),
    ),
}
|
|
|
|
|
class Model(nn.Module):
    """Model definition.

    Builds an architecture listed in ``model_mapping``, loading it with the
    weights given there. By default it freezes every pretrained parameter.
    It then replaces the architecture's final classification layer with a
    small trainable head:
    ``Linear(in, in // 2) -> ReLU -> Dropout -> Linear(in // 2, num_classes)``.
    """

    # Attribute on the wrapped model that holds the classification layer to
    # replace, keyed by the ``family`` value stored in ``model_mapping``.
    _HEAD_ATTRIBUTES = {
        "densenet": "classifier",
        "resnet": "fc",
        "vit": "heads",
    }

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        *,
        freeze_backbone: bool = True,
        dropout: float = 0.5,
    ):
        """
        Initialize Model instance.

        Args:
            model_name (str): Name of the model architecture (a key of
                ``model_mapping``).
            num_classes (int): Number of output classes.
            freeze_backbone (bool): If True (the default), set
                ``requires_grad = False`` on every pretrained parameter so only
                the newly created head is trainable.
            dropout (float): Dropout probability used inside the new head.

        Raises:
            KeyError: If ``model_name`` is not a key of ``model_mapping``.
            ValueError: If the mapping entry names an unsupported family.
        """
        super().__init__()

        model_class, model_config = model_mapping[model_name]
        family = model_config["family"]
        self.model = model_class(weights=model_config["weights"])

        if freeze_backbone:
            # Freeze before swapping the head: the replacement head is created
            # afterwards, so its parameters keep requires_grad=True.
            for param in self.model.parameters():
                param.requires_grad = False

        # Raises ValueError for an unknown family instead of silently keeping
        # the original classification layer.
        in_features = self._get_in_features(family)
        head = self._create_classifier(in_features, num_classes, dropout)
        setattr(self.model, self._HEAD_ATTRIBUTES[family], head)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def _get_in_features(self, family: str) -> int:
        """Return the input width of the wrapped model's classification layer.

        Args:
            family (str): Architecture family (``"densenet"``, ``"resnet"`` or
                ``"vit"``).

        Returns:
            int: ``in_features`` of the layer that the new head replaces.

        Raises:
            ValueError: If ``family`` is not a supported family.
        """
        if family == "densenet":
            return self.model.classifier.in_features
        if family == "resnet":
            return self.model.fc.in_features
        if family == "vit":
            # The width is read from ``heads.head``, but the whole ``heads``
            # container is what gets replaced.
            return self.model.heads.head.in_features
        supported = ", ".join(self._HEAD_ATTRIBUTES)
        raise ValueError(
            f"Unsupported model family: '{family}'. Supported families: {supported}"
        )

    def _create_classifier(
        self, in_features: int, num_classes: int, dropout: float = 0.5
    ) -> nn.Sequential:
        """Create the replacement classification head.

        Args:
            in_features (int): Width of the features fed into the head.
            num_classes (int): Number of output classes.
            dropout (float): Dropout probability applied after the hidden layer.

        Returns:
            nn.Sequential: ``Linear -> ReLU -> Dropout -> Linear`` with a hidden
            width of ``in_features // 2``.
        """
        hidden = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )
|
|
|
|
|
class ModelFactory:
    """
    Factory that builds a ``Model`` for a named architecture.

    The name is validated against ``model_mapping`` when the factory is
    called, not when it is constructed.

    Args:
        name (str): Architecture name; must be a key of ``model_mapping``.
        num_classes (int): The number of output classes.
    """

    def __init__(self, name: str, num_classes: int):
        """
        Initialize ModelFactory instance.

        Args:
            name (str): The name of the model architecture to build.
            num_classes (int): The number of output classes.
        """
        self.name = name
        self.num_classes = num_classes

    def __call__(self) -> "Model":
        """
        Create a model instance for the configured architecture name.

        Returns:
            Model: A new instance of the selected model with
            ``self.num_classes`` outputs.

        Raises:
            ValueError: If ``self.name`` is not a key of ``model_mapping``; the
                message lists the available names.
        """
        if self.name not in model_mapping:
            valid_options = ", ".join(model_mapping)
            raise ValueError(
                f"Invalid model name: '{self.name}'. Available options: {valid_options}"
            )

        return Model(self.name, self.num_classes)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a 5-class ResNet-50 through the factory.
    model = ModelFactory(name="resnet50", num_classes=5)()