Spaces:
Sleeping
Sleeping
File size: 781 Bytes
9d02e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
import torchvision
def createVITModel(
    out_features: int, seed: "int | None" = None
) -> "tuple[nn.Module, nn.Module]":
    """Build a frozen pretrained ViT-B/16 with a fresh linear classifier head.

    Every parameter of the pretrained model is frozen; only the newly created
    head (which replaces ``pretrained_vit.heads``) is trainable.

    Args:
        out_features: Number of outputs of the new classifier head.
        seed: Optional value passed to ``torch.manual_seed`` immediately before
            the head is created, so repeated calls produce an identically
            initialized head. This reseeds torch's global RNG. ``None`` (the
            default) leaves the RNG untouched, matching the previous behavior.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT and the preprocessing
        transform bundled with the pretrained weights.
    """
    # 1. Get torchvision's DEFAULT pretrained weights for ViT-Base/16.
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # 2. Set up a ViT model instance with those pretrained weights.
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
    # 3. Freeze the base parameters so only the new head receives gradients.
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    # 4. Replace the classifier head. Seeding (when requested) happens right
    #    before construction so the head's random init is reproducible.
    if seed is not None:
        torch.manual_seed(seed)
    # Take the embedding width from the model itself (768 for ViT-B/16)
    # instead of hard-coding it; the head is kept on CPU as before.
    pretrained_vit.heads = nn.Linear(
        in_features=pretrained_vit.hidden_dim, out_features=out_features
    ).to("cpu")
    # Preprocessing transform that ships with the chosen weights.
    vit_transforms = pretrained_vit_weights.transforms()
    return pretrained_vit, vit_transforms
|