VIT_Demo / vit_model_test.py
benjaminStreltzin's picture
Update vit_model_test.py
152bbff verified
raw
history blame
1.47 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
class CustomModel:
    """Two-class image classifier built on a fine-tuned ViT backbone.

    The constructor loads the ``google/vit-base-patch16-224`` checkpoint,
    replaces its classification head with a 2-output linear layer, and then
    restores the fine-tuned weights from a local state-dict file.
    """

    def __init__(self, weights_path: str = 'trained_model.pth'):
        """Build the model, load fine-tuned weights and set up preprocessing.

        Args:
            weights_path: Path to the state dict produced by fine-tuning.
                Defaults to ``'trained_model.pth'`` (the original hard-coded
                location).
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pre-trained ViT and swap in a 2-class classification head
        # so the module layout matches the fine-tuned state dict.
        model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        model.classifier = nn.Linear(model.config.hidden_size, 2)

        # map_location remaps the stored tensors onto the selected device;
        # without it a checkpoint saved from a CUDA run cannot be loaded on a
        # CPU-only machine.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=self.device)
        model.load_state_dict(state_dict)

        self.model = model.to(self.device)
        self.model.eval()

        # Image preprocessing pipeline: resize to 224x224, then convert to a
        # float tensor.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def predict(self, image: "Image.Image"):
        """Classify a single image.

        Args:
            image: The PIL image to classify. Images whose mode is not
                ``'RGB'`` are converted to RGB before preprocessing.

        Returns:
            A ``(predicted_label, confidence)`` tuple: the index of the
            highest-probability class as an ``int`` and its softmax
            probability expressed as a percentage (``float``).
        """
        # Grayscale, palette or RGBA inputs would otherwise yield 1- or
        # 4-channel tensors, which the 3-channel patch embedding rejects.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add the batch dimension.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(batch).logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)

        predicted_label = predicted.item()
        confidence = confidences.item() * 100  # Convert to percentage
        return predicted_label, confidence