"""Classify a single image with a frozen DINO ViT feature extractor + SVM.

Usage: python <this_script> IMAGE_PATH
"""
import json  # NOTE(review): currently unused in this file; kept in case other code relies on it.
import sys

import joblib
import torch
import torchvision.transforms as T
from PIL import Image

# Preprocessing pipeline: resize the shorter side to 224, center-crop to
# 224x224, convert to a float tensor, then normalize every channel with
# mean 0.5 / std 0.5.
# NOTE(review): DINO backbones are commonly fed ImageNet mean/std; this
# pipeline must stay identical to whatever was used when the SVM in
# 'svm_model.joblib' was trained, so do not change it in isolation.
transform_image = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])


def load_image(img_path: str) -> torch.Tensor:
    """Load an image from disk and return a batched tensor for the model.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    # Use the context manager so the underlying file handle is released;
    # convert() returns a new image, so it stays valid after the file closes.
    with Image.open(img_path) as img:
        rgb_img = img.convert("RGB")
    return transform_image(rgb_img).unsqueeze(0)


# Load the feature extractor. Despite the variable name, this is the original
# DINO ViT-S/16 ('dino_vits16' from facebookresearch/dino), not DINOv2. The
# name is kept unchanged because other code may import it.
dinov2_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
dinov2_model.to(device)
dinov2_model.eval()  # Inference mode for layers such as dropout/normalization.

# Load the pre-trained SVM classifier.
# SECURITY: joblib.load unpickles the file and can execute arbitrary code;
# only load model files from a trusted source.
clf = joblib.load('svm_model.joblib')


def predict(image_path: str):
    """Predict the class label for the image at *image_path*.

    Args:
        image_path: Path to the image to classify.

    Returns:
        The first (and only) element of ``clf.predict`` for this image.
    """
    new_image = load_image(image_path).to(device)

    # Gradients are not needed for feature extraction.
    with torch.no_grad():
        features = dinov2_model(new_image)

    # Flatten to a single row (1, n_features), the 2-D shape the classifier's
    # predict() is called with. Assumes the model output is one embedding per
    # image -- TODO confirm against the SVM's training features.
    embedding = features.cpu().numpy().reshape(1, -1)

    prediction = clf.predict(embedding)
    return prediction[0]


def main(argv=None) -> int:
    """Command-line entry point: classify the image named by the first argument.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 2 when no image path was given.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        # Previously this crashed with an IndexError; report usage instead.
        print(f"usage: {sys.argv[0]} IMAGE_PATH", file=sys.stderr)
        return 2
    predicted_class = predict(args[0])
    print("Predicted class:", predicted_class)
    return 0


if __name__ == "__main__":
    sys.exit(main())