import torch
import torchvision
from torch import nn


def create_densenet121_model(num_classes: int = 3, seed: int = 42):
    """Build a DenseNet121 with a frozen backbone and a fresh, trainable head.

    The network is constructed with torchvision's default pretrained
    DenseNet121 weights. Every pretrained parameter is frozen. The classifier
    is then replaced by a new Dropout + Linear head, whose parameters are
    created after the freeze and therefore remain trainable.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Value passed to ``torch.manual_seed`` right before the head is
            created, making its random initialisation reproducible. Note that
            this reseeds torch's global RNG as a side effect.

    Returns:
        A ``(model, transforms)`` tuple: the modified DenseNet121 model and
        the preprocessing transforms returned by ``weights.transforms()`` for
        the pretrained weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # Pretrained weights, their matching preprocessing, and the model itself.
    weights = torchvision.models.DenseNet121_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.densenet121(weights=weights)

    # Freeze the whole pretrained network; the head assigned below is built
    # afterwards, so only its parameters will receive gradients.
    for param in model.parameters():
        param.requires_grad = False

    # Take the feature width from the existing head rather than hard-coding it.
    in_features = model.classifier.in_features

    # Seed immediately before creating the head so its init is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model, transforms