Spaces:
Sleeping
Sleeping
File size: 421 Bytes
0636324 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from torchvision import models
from torch import nn
def create_model(num_classes: int = 3):
    """Build a Swin-V2-Tiny classifier with a frozen backbone and a new head.

    Loads torchvision's default pretrained Swin-V2-T weights, freezes every
    existing parameter, then replaces the classification head with a fresh
    ``nn.Linear`` sized for ``num_classes`` outputs. Because the head is
    created *after* the freezing loop, only the head's parameters are
    trainable.

    Args:
        num_classes: Number of output classes for the new classification
            head. Defaults to 3.

    Returns:
        A ``(model, transforms)`` tuple: the modified Swin-V2-T model and the
        preprocessing transforms that match the pretrained weights.
    """
    weights = models.Swin_V2_T_Weights.DEFAULT
    model_transforms = weights.transforms()
    model_ft = models.swin_v2_t(weights=weights)

    # Freeze the pretrained backbone; the head added below is left trainable.
    for param in model_ft.parameters():
        param.requires_grad = False

    # Reuse the existing head's input width rather than hard-coding it, and
    # honour the caller's requested class count (previously fixed at 3).
    in_features = model_ft.head.in_features
    model_ft.head = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
    return model_ft, model_transforms
|