# horsevision-mini / model.py
# Committed by lamont-granquist in 0636324 ("initial commit").
from torchvision import models
from torch import nn
def create_model(num_classes: int = 3):
    """Build a Swin-V2-Tiny classifier with a frozen pretrained backbone.

    The backbone is loaded with torchvision's default pretrained weights for
    ``swin_v2_t``. All of its parameters are frozen. The classification head
    is then replaced with a new, trainable ``nn.Linear`` sized for
    ``num_classes`` outputs, so only the head is trained.

    Args:
        num_classes: Number of output classes for the new classification head.

    Returns:
        A ``(model, transforms)`` tuple:

        - ``model`` is the modified ``swin_v2_t`` network.
        - ``transforms`` is the preprocessing pipeline that matches the
          pretrained weights (``weights.transforms()``).

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Validate before loading (and possibly downloading) pretrained weights.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = models.Swin_V2_T_Weights.DEFAULT
    model_transforms = weights.transforms()
    model_ft = models.swin_v2_t(weights=weights)

    # Freeze every pretrained parameter. The head created below is a new
    # module whose parameters default to requires_grad=True, so it remains
    # the only trainable part of the network.
    model_ft.requires_grad_(False)

    # Size the new head from the backbone's existing head rather than
    # hard-coding the feature width. Honour num_classes; the original code
    # ignored it and always used 3.
    model_ft.head = nn.Linear(
        in_features=model_ft.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model_ft, model_transforms