from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torchvision


def createVITModel(
    out_features: int, seed: Optional[int] = None
) -> Tuple[nn.Module, Callable]:
    """Build a frozen, pretrained ViT-B/16 with a fresh trainable linear head.

    Args:
        out_features: Number of outputs of the new classifier head.
        seed: If given, ``torch.manual_seed(seed)`` is called right before the
            head is created, so the head's initial weights are reproducible
            across calls. ``None`` (the default) leaves the RNG untouched.

    Returns:
        A ``(model, transforms)`` tuple. ``model`` is the ViT with every
        backbone parameter frozen and ``heads`` replaced by a trainable
        ``nn.Linear``. ``transforms`` is the preprocessing pipeline that
        matches the pretrained weights.
    """
    # 1. Get pretrained weights for ViT-Base (16x16 patches).
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

    # 2. Set up a ViT model instance with the pretrained weights
    #    (may download them on first use).
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)

    # 3. Freeze all base parameters. This is done before the new head is
    #    attached, so only the head remains trainable.
    pretrained_vit.requires_grad_(False)

    # 4. Replace the classifier head. Seed first (if requested) so the head's
    #    random initialization is reproducible.
    if seed is not None:
        torch.manual_seed(seed)
    # Derive the embedding size from the backbone instead of hard-coding 768.
    # Place the head on the same device as the backbone rather than forcing CPU.
    backbone_device = next(pretrained_vit.parameters()).device
    pretrained_vit.heads = nn.Linear(
        in_features=pretrained_vit.hidden_dim, out_features=out_features
    ).to(backbone_device)

    vit_transforms = pretrained_vit_weights.transforms()
    return pretrained_vit, vit_transforms