from collections import OrderedDict
from typing import Tuple, Union

import torch
from torch import nn
from torchvision import transforms
from torchvision.models import ViT_B_16_Weights, vit_b_16
from torchvision.transforms import InterpolationMode


def create_vit_instance(
    num_classes: int = 1000,
    device: Union[torch.device, str] = "cpu",
) -> Tuple[nn.Module, transforms.Compose]:
    """Build a frozen, pretrained ViT-B/16 with a fresh classification head.

    The backbone is created with ``ViT_B_16_Weights.DEFAULT`` (torchvision may
    download the weights on first use) and every pretrained parameter is
    frozen. The ``heads`` module is then replaced by a single new linear layer
    named ``head``; it is created after the freeze, so it is the only part of
    the model left trainable.

    Args:
        num_classes: Number of output units of the new classification head.
        device: Device (or device string) the returned model is moved to.

    Returns:
        A ``(model, transform)`` tuple. ``transform`` is the preprocessing
        pipeline built below: resize to 256, center-crop to 224, convert to
        3-channel grayscale, convert to tensor, and normalize.
    """
    vit_model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    # Freeze the pretrained weights. This must happen BEFORE the head is
    # replaced so that the new head keeps requires_grad=True.
    for param in vit_model.parameters():
        param.requires_grad = False

    # Read the embedding width from the existing head instead of hard-coding
    # it (768 for ViT-B/16), so the new layer always matches the backbone.
    in_features = vit_model.heads.head.in_features
    vit_model.heads = nn.Sequential(
        OrderedDict(
            [
                (
                    "head",
                    nn.Linear(in_features=in_features, out_features=num_classes),
                ),
            ]
        )
    )

    # Move backbone and new head to the target device in one step.
    vit_model = vit_model.to(device)

    transform = transforms.Compose(
        [
            transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            # Converts any input to grayscale, replicated across 3 channels so
            # the tensor still matches the 3-channel input the model expects.
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            # Per-channel mean/std conventionally used with ImageNet-pretrained
            # torchvision models.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    return (vit_model, transform)