# convnext_food101_1 / model.py
# Uploaded by repositivator ("Upload 8 files", commit eeb0a2f).
import torch
import torchvision
from torch import nn
def create_model(num_classes=101, seed=42):
    """Build a ConvNeXt-Tiny classifier set up for transfer learning.

    Loads torchvision's default pretrained ConvNeXt-Tiny weights, freezes every
    parameter of the ``features`` backbone, and replaces the last layer of
    ``model.classifier`` with a freshly initialised ``nn.Linear`` that has
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output units for the new final layer.
        seed: Passed to ``torch.manual_seed`` immediately before the new final
            layer is created, so its random initialisation is reproducible.
            Note: this reseeds PyTorch's global RNG as a side effect.

    Returns:
        A ``(model, transforms)`` tuple: the modified ConvNeXt-Tiny model and
        the preprocessing transform that ships with the pretrained weights.
    """
    # ``.DEFAULT`` resolves to the best pretrained ImageNet weights torchvision offers.
    pretrained = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT
    preprocess = pretrained.transforms()
    model = torchvision.models.convnext_tiny(weights=pretrained)

    # Freeze the backbone: disable gradient tracking ("requires_grad") for all
    # parameters under ``model.features``; parameters outside it stay trainable.
    model.features.requires_grad_(False)

    old_head = model.classifier[-1]
    # Seed right before building the replacement layer so its initial weights
    # are the same on every call with the same ``seed``.
    torch.manual_seed(seed)
    model.classifier[-1] = nn.Linear(
        in_features=old_head.in_features,
        out_features=num_classes,
    )

    return model, preprocess