File size: 421 Bytes
0636324
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torchvision import models
from torch import nn

def create_model(num_classes: int = 3):
    """Build a Swin-V2-Tiny classifier with a frozen backbone and a new head.

    Loads ``models.swin_v2_t`` with ``Swin_V2_T_Weights.DEFAULT`` pretrained
    weights, freezes every existing parameter, and replaces ``head`` with a
    freshly initialised ``nn.Linear`` producing ``num_classes`` outputs. Only
    the new head is left trainable.

    Args:
        num_classes: Number of output classes for the replacement head.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is the preprocessing
        pipeline bundled with the pretrained weights
        (``weights.transforms()``), intended to be applied to inputs before
        they are fed to the model.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = models.Swin_V2_T_Weights.DEFAULT
    model_transforms = weights.transforms()
    model_ft = models.swin_v2_t(weights=weights)

    # Freeze the pretrained backbone so only the new head gets gradients.
    for param in model_ft.parameters():
        param.requires_grad = False

    # Take the feature width from the existing head rather than hard-coding
    # it, and honour num_classes (the original always used out_features=3).
    # A newly constructed nn.Linear has requires_grad=True, so this head
    # stays trainable despite the freeze above.
    model_ft.head = nn.Linear(
        in_features=model_ft.head.in_features,
        out_features=num_classes,
        bias=True,
    )

    return model_ft, model_transforms