import torch import torchvision from torch import nn def create_effnet( pretrained_weights: torchvision.models.Weights, model: torchvision.models, in_features: int, dropout: int, out_features: int, device: torch.device, ): # Get the weights and setup the model model = model(weights=pretrained_weights).to(device) transforms = pretrained_weights.transforms() # Freeze the base model layers for param in model.features.parameters(): param.requires_grad = False # Change the classifier head model.classifier = nn.Sequential( nn.Dropout(p=dropout, inplace=True), nn.Linear(in_features=in_features, out_features=out_features), ).to(device) return model, transforms