# Source: uploaded by RickOWO12344 ("Upload 7 files", commit 233e156).
import torch
import torchvision
from torch import nn
import torchvision.models as models
def ResNet18_model(
    num_classes: int = 3,
    *,
    pretrained: bool = True,
    dropout: float = 0.2,
) -> torch.nn.Module:
    """Build a ResNet-18 whose classification head has ``num_classes`` outputs.

    The torchvision ResNet-18 backbone is created and its final ``fc`` layer is
    replaced by ``Sequential(Linear(in_features, num_classes), Dropout(p=dropout))``,
    where ``in_features`` is read from the original ``fc`` layer.

    Args:
        num_classes: Number of output units (one per class) of the new head.
        pretrained: If True (the default, matching the original behavior), load
            torchvision's default pretrained ResNet-18 weights (this may require
            a download); if False, the backbone is randomly initialized.
        dropout: Probability passed to the ``Dropout`` module in the new head.

    Returns:
        The modified torchvision ResNet-18 model.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: ``weights=`` replaces the deprecated ``pretrained=``
        # flag. For ResNet-18, DEFAULT is the same checkpoint pretrained=True loaded.
        model = models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)
    else:
        # Older torchvision only understands the boolean flag.
        model = models.resnet18(pretrained=pretrained)

    num_ftrs = model.fc.in_features

    # NOTE(review): Dropout sits *after* the Linear layer, so in train mode it
    # zeroes/rescales the logits themselves; the usual arrangement is
    # Dropout -> Linear. The order is kept on purpose: Sequential names its
    # children by position, so swapping them would rename the state_dict keys
    # (fc.0.* -> fc.1.*) and break loading weights saved from this layout.
    # In eval mode Dropout is an identity, so inference is unaffected.
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(num_ftrs, num_classes),
        torch.nn.Dropout(p=dropout),
    )
    return model