# foodvisiontest / model.py
# ramirjf
# initial commit
# 9d02e76
import torch
import torch.nn as nn
import torchvision
def createVITModel(
    out_features: int, *, seed: "int | None" = None
) -> "tuple[nn.Module, nn.Module]":
    """Build a frozen pretrained ViT-B/16 with a fresh, trainable linear head.

    Loads torchvision's default pretrained ViT-B/16 weights, freezes every
    backbone parameter, and replaces the classifier head with a new
    ``nn.Linear`` mapping the encoder's hidden size to ``out_features``.

    Args:
        out_features: Number of outputs (classes) for the new head.
        seed: If given, ``torch.manual_seed(seed)`` is called immediately
            before the head is created so its initial weights are
            reproducible. Note that this reseeds torch's *global* RNG.
            ``None`` (the default) leaves the RNG untouched.

    Returns:
        ``(model, transforms)``: the modified ViT and the preprocessing
        transform bundled with the pretrained weights.
    """
    # 1. Default pretrained weights for ViT-Base/16.
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # 2. Model instance initialised from those weights.
    model = torchvision.models.vit_b_16(weights=weights)

    # 3. Freeze the whole backbone; only the head created below is trainable.
    model.requires_grad_(False)

    # 4. Replace the classifier head. Seed first (if requested) so the head's
    #    initial weights are the same on every run.
    if seed is not None:
        torch.manual_seed(seed)
    # Take the embedding size from the model instead of hard-coding 768.
    # The head is assigned as a bare Linear (as before), so state_dict keys
    # remain "heads.weight"/"heads.bias". It is created on torch's default
    # device, the same device the backbone was built on, instead of being
    # forced onto the CPU.
    model.heads = nn.Linear(in_features=model.hidden_dim, out_features=out_features)

    transforms = weights.transforms()
    return model, transforms