File size: 781 Bytes
9d02e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import Callable, Tuple

import torch
import torch.nn as nn
import torchvision

def createVITModel(out_features: int) -> Tuple[nn.Module, Callable]:
    """Build a frozen, pretrained ViT-B/16 with a fresh linear classifier head.

    Every parameter of the pretrained network is frozen; only the newly
    created linear head is left trainable.

    Args:
        out_features: Number of output units (classes) for the new head.

    Returns:
        A ``(model, transforms)`` tuple: the ViT-B/16 model with its head
        replaced, and the preprocessing transforms bundled with the
        pretrained weights.

    Note:
        The new head is randomly initialised and this function does not set
        any random seed. Call ``torch.manual_seed(...)`` before invoking it
        if a reproducible head initialisation is required.
    """
    # 1. Get pretrained weights for ViT-Base
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # 2. Setup a ViT model instance with pretrained weights
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
    # 3. Freeze the base parameters (done before the head swap so that the
    #    new head below keeps requires_grad=True)
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False
    # 4. Change the classifier head. The input width is read from the
    #    existing head rather than hard-coded (768 for ViT-B/16).
    in_features = pretrained_vit.heads.head.in_features
    pretrained_vit.heads = nn.Linear(
        in_features=in_features, out_features=out_features
    ).to('cpu')
    # 5. Preprocessing pipeline that matches the pretrained weights
    vit_transforms = pretrained_vit_weights.transforms()
    return pretrained_vit, vit_transforms