Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from collections import OrderedDict | |
from torchvision.transforms import InterpolationMode | |
from torchvision import transforms | |
from torchvision.models import vit_b_16, ViT_B_16_Weights | |
def create_vit_instance(
    num_classes: int = 1000,
    device: "torch.device | str" = "cpu",
) -> "tuple[nn.Module, transforms.Compose]":
    """Build a frozen, pretrained ViT-B/16 with a fresh classification head.

    The backbone is loaded with ``ViT_B_16_Weights.DEFAULT`` and every one of
    its parameters is frozen. The original ``heads`` module is then replaced
    by a single new ``nn.Linear`` layer, which is the only part of the
    returned model that still requires gradients.

    Args:
        num_classes: Number of output features of the new head.
        device: Device (or device string) the returned model is moved to.

    Returns:
        A ``(model, transform)`` tuple: the ViT model on ``device`` and the
        preprocessing pipeline to apply to input images.

    Note:
        Loading the default weights may download them if they are not
        already cached locally.
    """
    vit_model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    # Freeze the pretrained backbone. This runs *before* the head is swapped
    # so that the new head's parameters keep requires_grad=True.
    for param in vit_model.parameters():
        param.requires_grad = False

    # Read the embedding width from the model instead of hard-coding 768.
    vit_model.heads = nn.Sequential(
        OrderedDict([
            ("head", nn.Linear(in_features=vit_model.hidden_dim,
                               out_features=num_classes)),
        ])
    )

    # One move covers both the backbone and the freshly created head.
    vit_model = vit_model.to(device)

    transform = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        # Produce a 3-channel grayscale image: single-channel inputs gain the
        # 3 channels the model expects; colour inputs lose their colour.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    return (vit_model, transform)