import os

import torch
import torchvision
from torch import nn
from torchvision import transforms
from PIL import Image


class CFG:
    DEVICE = 'cpu'
    NUM_DEVICES = torch.cuda.device_count()
    NUM_WORKERS = os.cpu_count()
    NUM_CLASSES = 4
    EPOCHS = 16
    BATCH_SIZE = 32
    LR = 0.001
    APPLY_SHUFFLE = True
    SEED = 768
    HEIGHT = 224
    WIDTH = 224
    CHANNELS = 3
    IMAGE_SIZE = (224, 224, 3)


class VisionTransformerModel(nn.Module):
    def __init__(self, backbone_model, name='vision-transformer',
                 num_classes=CFG.NUM_CLASSES, device=CFG.DEVICE):
        super(VisionTransformerModel, self).__init__()
        self.backbone_model = backbone_model
        self.device = device
        self.num_classes = num_classes
        self.name = name

        # The frozen backbone emits 1000 ImageNet logits; this head maps them
        # down to the task's num_classes.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1000, out_features=256, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=256, out_features=num_classes, bias=False)
        ).to(device)

    def forward(self, image):
        vit_output = self.backbone_model(image)
        return self.classifier(vit_output)


def get_vit_l32_model(device: torch.device = CFG.DEVICE) -> nn.Module:
    # Set the manual seeds
    torch.manual_seed(CFG.SEED)
    torch.cuda.manual_seed(CFG.SEED)

    # Get model weights
    model_weights = (
        torchvision
        .models
        .ViT_L_32_Weights
        .DEFAULT
    )

    # Get model and push to device
    model = (
        torchvision.models.vit_l_32(
            weights=model_weights
        )
    ).to(device)

    # Freeze model parameters
    for param in model.parameters():
        param.requires_grad = False

    return model


# Get ViT-L/32 backbone
vit_backbone = get_vit_l32_model(CFG.DEVICE)

vit_params = {
    'backbone_model': vit_backbone,
    'name': 'ViT-L-32',
    'device': CFG.DEVICE
}

# Build the model and load the fine-tuned weights
vit_model = VisionTransformerModel(**vit_params)
vit_model.load_state_dict(
    torch.load('vit_model.pth', map_location=torch.device('cpu'))
)
vit_model.eval()  # disable dropout for inference

# Define the image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


def predict(image_path):
    # Force 3 channels so grayscale/RGBA inputs match the expected input shape
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image)
    input_batch = input_tensor.unsqueeze(0).to(CFG.DEVICE)  # Add batch dimension

    # Perform inference
    with torch.no_grad():
        output = vit_model(input_batch)

    # 'output' holds raw logits of shape (1, NUM_CLASSES); post-process as needed
    return output
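

# Usage sketch: turning the raw logits from predict() into a class prediction.
# Assumptions (not from the original script): 'test_image.jpg' is a hypothetical
# local file, and the label list below is illustrative -- substitute the classes
# the model was actually fine-tuned on.
if __name__ == '__main__':
    class_names = ['class_0', 'class_1', 'class_2', 'class_3']  # placeholder labels

    logits = predict('test_image.jpg')                  # shape: (1, NUM_CLASSES)
    probs = torch.softmax(logits, dim=1)                # logits -> probabilities
    pred_idx = int(torch.argmax(probs, dim=1).item())   # most likely class index

    print(f'Predicted: {class_names[pred_idx]} '
          f'(confidence {probs[0, pred_idx].item():.3f})')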