Spaces:
Sleeping
Sleeping
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torchvision
def createVITModel(
    out_features: int,
    seed: Optional[int] = None,
) -> Tuple[nn.Module, Callable]:
    """Build a ViT-B/16 classifier with a frozen backbone and a new linear head.

    The backbone is torchvision's ``vit_b_16`` loaded with
    ``ViT_B_16_Weights.DEFAULT`` (torchvision downloads the weights on first
    use if they are not already cached). All pretrained parameters are frozen,
    and ``heads`` is replaced by a single trainable ``nn.Linear``.

    Args:
        out_features: Number of outputs (classes) of the new linear head.
        seed: If given, ``torch.manual_seed(seed)`` is called immediately
            before the head is created, so repeated calls produce the same
            head initialisation. This reseeds torch's *global* RNG. ``None``
            (the default) leaves the RNG untouched.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT, and the
        preprocessing transforms associated with the pretrained weights.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze every pretrained parameter. This must happen BEFORE the head is
    # replaced, so the new head's parameters keep requires_grad=True and are
    # the only trainable ones.
    model.requires_grad_(False)

    if seed is not None:
        torch.manual_seed(seed)

    # Take the embedding width from the model instead of hard-coding 768.
    # ``heads`` is deliberately a bare nn.Linear (not a Sequential), so the
    # state_dict keys are ``heads.weight`` / ``heads.bias``. Keep that layout
    # so checkpoints saved from models built here still load.
    model.heads = nn.Linear(
        in_features=model.hidden_dim,
        out_features=out_features,
    ).to("cpu")

    return model, weights.transforms()