# model.py: last change "Update model.py" by pjdevelop (commit 75ea8f5, 1.6 kB).
import torch
import torchvision.transforms as T
from PIL import Image
import joblib
import json
# Preprocessing pipeline applied to every input image before feature
# extraction: scale the shorter side to 224 px, take the central 224x224
# crop, convert to a float tensor in [0, 1], then map it to [-1, 1] via
# (x - 0.5) / 0.5. The single mean/std value is broadcast across all three
# RGB channels.
# NOTE(review): these statistics must match the preprocessing used when the
# SVM's training features were extracted (DINO's own reference pipeline uses
# ImageNet mean/std) -- confirm before changing them.
transform_image = T.Compose(
    [
        T.Resize(size=224),
        T.CenterCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
def load_image(img_path: str) -> torch.Tensor:
    """Load an image file and return a batched tensor ready for the model.

    The image is decoded, converted to 3-channel RGB, passed through
    ``transform_image`` and given a leading batch dimension of size 1,
    giving a float tensor of shape ``(1, 3, 224, 224)``.

    Args:
        img_path: Path to an image file that PIL can open.

    Returns:
        The preprocessed image as a ``(1, 3, 224, 224)`` tensor.

    Raises:
        Propagates errors from ``PIL.Image.open`` (e.g. ``FileNotFoundError``,
        ``PIL.UnidentifiedImageError``).
    """
    # A context manager releases the file handle deterministically. Without
    # it, Pillow keeps multi-frame files (GIF, TIFF, ...) open until the image
    # object is garbage-collected.
    with Image.open(img_path) as img:
        # convert() returns a new, fully loaded image, so it remains usable
        # after the source image (and its file) has been closed.
        rgb_img = img.convert("RGB")
    return transform_image(rgb_img).unsqueeze(0)
# Feature extractor: the self-supervised DINO ViT-S/16 backbone, fetched via
# torch.hub (downloaded on first use, then served from the hub cache).
# NOTE(review): despite the variable name, 'facebookresearch/dino' /
# 'dino_vits16' is the original DINO (v1) model, not DINOv2.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dinov2_model = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
# Move the weights to the chosen device and put layers such as dropout into
# evaluation behavior. Both calls modify the module in place.
dinov2_model.to(device).eval()

# Pre-trained SVM that maps the backbone's embeddings to class labels.
# The path is resolved relative to the current working directory.
# NOTE: joblib.load unpickles the file -- only load model files you trust.
clf = joblib.load("svm_model.joblib")
def predict(image_path: str):
    """Return the SVM's predicted class for the image stored at *image_path*.

    The image is preprocessed with ``load_image``, embedded by the DINO
    backbone with gradient tracking disabled, flattened into a single
    feature row and classified by ``clf``.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        The single label produced by ``clf.predict``. Its type depends on the
        labels the SVM was trained with (presumably str or int -- confirm).
    """
    batch = load_image(image_path).to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        feats = dinov2_model(batch)

    # The SVM takes a 2-D (n_samples, n_features) array; the one image in
    # the batch becomes a single row.
    row = feats.cpu().numpy().reshape(1, -1)
    labels = clf.predict(row)
    return labels[0]
# Command-line entry point: run this file with an image path as the first
# argument to print the predicted class.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Exit with a usage message (status 1) instead of an IndexError
        # traceback when no image path is supplied.
        sys.exit("usage: " + sys.argv[0] + " IMAGE_PATH")
    image_path = sys.argv[1]  # any additional arguments are ignored
    predicted_class = predict(image_path)
    print("Predicted class:", predicted_class)