Spaces:
Sleeping
Sleeping
from torchvision import models | |
from torch import nn | |
def create_model(num_classes: int = 3):
    """Build a pretrained Swin-V2-T with frozen weights and a new linear head.

    Args:
        num_classes: Number of output features of the replacement head.

    Returns:
        A ``(model, transforms)`` tuple: the Swin-V2-T model whose ``head``
        has been replaced by a new ``nn.Linear``, and the transforms object
        obtained from the default pretrained weights.
    """
    weights = models.Swin_V2_T_Weights.DEFAULT
    model_transforms = weights.transforms()
    model_ft = models.swin_v2_t(weights=weights)

    # Freeze every parameter loaded with the pretrained model. The head
    # assigned below is created after this loop, so it is not affected.
    for param in model_ft.parameters():
        param.requires_grad = False

    # Size the new head from the existing one instead of hard-coding the
    # feature width, and honour the requested number of classes.
    model_ft.head = nn.Linear(
        in_features=model_ft.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model_ft, model_transforms