from torch import nn
from torchvision import models


def create_model(num_classes: int = 3):
    """Build a Swin-V2-Tiny classifier for transfer learning.

    The backbone is loaded with torchvision's default pretrained weights for
    ``swin_v2_t``. All of its parameters are frozen. The classification head
    is then replaced by a fresh ``nn.Linear`` sized for ``num_classes``, so
    only the new head is trainable.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.

    Returns:
        A ``(model, transforms)`` tuple:
        - ``model`` is the Swin-V2-T network with the replaced head.
        - ``transforms`` is the preprocessing pipeline bundled with the chosen
          weights; apply it to inputs so they match what the backbone saw
          during pretraining.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = models.Swin_V2_T_Weights.DEFAULT
    model_transforms = weights.transforms()
    model_ft = models.swin_v2_t(weights=weights)

    # Freeze the pretrained backbone. This happens BEFORE the head is swapped,
    # so the new head's parameters keep their default requires_grad=True.
    for param in model_ft.parameters():
        param.requires_grad = False

    # Read the feature width from the existing head rather than hard-coding
    # it, so the head stays consistent with the backbone's output size.
    in_features = model_ft.head.in_features
    model_ft.head = nn.Linear(
        in_features=in_features, out_features=num_classes, bias=True
    )

    return model_ft, model_transforms