|
import torch |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
import joblib |
|
import json |
|
|
|
|
|
# Preprocessing pipeline applied to every input image before feature extraction.
# NOTE(review): mean/std of 0.5 (a single value, broadcast across all channels)
# is not the ImageNet statistics DINO was pretrained with. Presumably it matches
# how the SVM's training embeddings were produced, so change it only together
# with retraining the classifier — confirm.
transform_image = T.Compose(
    transforms=[
        # An int size scales the shorter image side to 224 px, keeping aspect ratio.
        T.Resize(224),
        # Then cut out the central 224x224 square.
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
|
|
|
def load_image(img_path: str) -> torch.Tensor:
    """Load an image file and turn it into a model-ready input batch.

    Args:
        img_path: Path to an image file readable by Pillow.

    Returns:
        The image converted to RGB, passed through ``transform_image`` and
        given a leading batch dimension of size 1.
    """
    # Use a context manager so the file handle is closed deterministically;
    # Pillow otherwise keeps it open for lazily loaded / multi-frame images.
    # ``convert`` returns a new, fully loaded image, so it remains usable
    # after the file has been closed.
    with Image.open(img_path) as img:
        rgb_img = img.convert("RGB")

    # unsqueeze(0) prepends a batch axis of size 1.
    return transform_image(rgb_img).unsqueeze(0)
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# NOTE: despite the variable name, this is the original DINO ViT-S/16
# ('dino_vits16' from facebookresearch/dino), not a DINOv2 checkpoint.
# The hub load runs at import time. Module.to() and Module.eval() both
# return the module itself, so chaining binds the same object as calling
# them one after another.
dinov2_model = (
    torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
    .to(device)
    .eval()
)

# Classifier applied to the backbone embeddings in predict().
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# only ever load a trusted svm_model.joblib. The relative path is resolved
# against the current working directory, not this script's directory.
clf = joblib.load(filename='svm_model.joblib')
|
|
|
|
|
def predict(image_path: str):
    """Classify the image stored at ``image_path``.

    The image is preprocessed by ``load_image``, moved to ``device``, embedded
    by ``dinov2_model`` without gradient tracking, flattened into a single row
    and handed to ``clf.predict``.

    Returns:
        The first (and only) element of ``clf.predict``'s output for this image.
    """
    batch = load_image(image_path).to(device)

    with torch.no_grad():
        feats = dinov2_model(batch)

    # Flatten everything into one row. Presumably clf is a scikit-learn-style
    # estimator expecting (n_samples, n_features) — confirm.
    row = feats.cpu().numpy().reshape(1, -1)
    return clf.predict(row)[0]
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Without an image path, sys.argv[1] would raise a bare IndexError.
    # Report proper usage and exit with the conventional status 2 instead.
    # Extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} IMAGE_PATH", file=sys.stderr)
        sys.exit(2)

    image_path = sys.argv[1]

    predicted_class = predict(image_path)

    print("Predicted class:", predicted_class)
|
|