Spaces:
Sleeping
Sleeping
import torch | |
import torchvision | |
from torch import nn | |
import torchvision.models as models | |
def ResNet18_model(num_classes: int = 3, *, pretrained: bool = True, dropout: float = 0.2):
    """Build a torchvision ResNet-18 whose final layer outputs ``num_classes`` units.

    The backbone is ``torchvision.models.resnet18``. Its ``fc`` layer is replaced
    with ``Sequential(Linear(in_features, num_classes), Dropout(p=dropout))``.

    Args:
        num_classes: Number of output units (one per class). Defaults to 3.
        pretrained: If True (the default, matching the original behaviour), load
            torchvision's default ImageNet weights for the backbone, which may
            trigger a download. Pass False to get randomly initialised backbone
            weights, e.g. when a fine-tuned checkpoint is loaded right afterwards.
        dropout: Probability for the head's ``Dropout`` layer. Defaults to 0.2.

    Returns:
        The modified ResNet-18 ``torch.nn.Module``.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the legacy boolean flag.
        model_0 = models.resnet18(pretrained=pretrained)
    else:
        # torchvision >= 0.13 deprecates ``pretrained=``; ``DEFAULT`` resolves to
        # the same IMAGENET1K_V1 weights that ``pretrained=True`` used to load.
        model_0 = models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    # Width of the features feeding the original classifier layer.
    num_ftrs = model_0.fc.in_features

    # NOTE(review): Dropout placed *after* the Linear layer drops logits during
    # training; the conventional arrangement is Dropout -> Linear. The order is
    # kept so the head's state_dict keys (``fc.0.weight`` / ``fc.0.bias``) still
    # match any existing checkpoints. In eval mode Dropout is a no-op either way.
    model_0.fc = torch.nn.Sequential(
        torch.nn.Linear(num_ftrs, num_classes),
        torch.nn.Dropout(p=dropout),
    )
    return model_0