import torch import torchvision from torch import nn import torchvision.models as models def ResNet18_model(num_classes:int=3): # Create ResNet18 model model_0 = models.resnet18(pretrained=True) # Get the length of class_names (one output unit for each class) output_shape = num_classes num_ftrs = model_0.fc.in_features # Define the number of output classes for your task num_classes = output_shape # Replace the last linear layer with a new one that has the right number of output units model_0.fc = torch.nn.Sequential( torch.nn.Linear(num_ftrs, num_classes), torch.nn.Dropout(p=0.2) ) return model_0