itzRahul's picture
Update model.py
1dc1d17
import torch
from torch import nn
from collections import OrderedDict
from torchvision.transforms import InterpolationMode
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
def create_vit_instance(num_classes: int = 1000,
                        device: "torch.device | str" = 'cpu'):
    """Build a pretrained ViT-B/16 with a frozen backbone and a new head.

    The weights are ``ViT_B_16_Weights.DEFAULT``. Every pretrained parameter
    is frozen (``requires_grad = False``). The classifier is replaced by a
    single ``nn.Linear`` layer, which is trainable because it is created
    after the freeze.

    Args:
        num_classes: Number of outputs of the replacement classification head.
        device: Target device for the returned model. This may be a
            ``torch.device`` or a device string such as ``'cpu'``.

    Returns:
        A ``(model, transform)`` tuple. ``transform`` does the following:
        resize to 256 (bilinear), center-crop to 224, convert to 3-channel
        grayscale, convert to a tensor, and normalize with the standard
        ImageNet mean/std.
    """
    vit_weight = ViT_B_16_Weights.DEFAULT
    vit_model = vit_b_16(weights=vit_weight)

    # Freeze the pretrained backbone. The head built below is created
    # afterwards, so it keeps requires_grad=True.
    for param in vit_model.parameters():
        param.requires_grad = False

    # Size the head from the backbone's embedding width instead of a
    # hard-coded 768. Keep the 'head' key that torchvision itself uses.
    vit_model.heads = nn.Sequential(
        OrderedDict([
            ('head', nn.Linear(in_features=vit_model.hidden_dim,
                               out_features=num_classes)),
        ])
    )
    # Move once, after the head swap, so backbone and head end up on the
    # same device.
    vit_model = vit_model.to(device)

    transform = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        # Collapse the image to one luminance channel and replicate it 3x.
        # The pretrained model expects 3 input channels. Colour information
        # in RGB inputs is discarded here.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return (vit_model, transform)