Update model.py
Browse files
model.py
CHANGED
@@ -6,37 +6,49 @@ import json
|
|
6 |
|
7 |
# Define image transformation
|
8 |
transform_image = T.Compose([
|
9 |
-
T.Resize(224),
|
10 |
T.CenterCrop(224),
|
11 |
T.ToTensor(),
|
12 |
T.Normalize([0.5], [0.5])
|
13 |
])
|
14 |
|
15 |
-
def load_image(
|
16 |
"""
|
17 |
-
Load an image and return a tensor that can be used as an input to
|
18 |
"""
|
19 |
-
img = Image.open(
|
20 |
-
transformed_img = transform_image(img)
|
21 |
return transformed_img
|
22 |
|
23 |
-
# Load
|
24 |
-
|
25 |
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
26 |
-
|
27 |
-
|
28 |
|
29 |
-
# Load the classifier
|
30 |
clf = joblib.load('svm_model.joblib')
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
# Predict class for a new image
|
37 |
-
def predict(image_path):
|
38 |
new_image = load_image(image_path).to(device)
|
|
|
|
|
39 |
with torch.no_grad():
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
prediction = clf.predict(embedding)
|
|
|
42 |
return prediction[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Define image transformation
|
8 |
transform_image = T.Compose([
|
9 |
+
T.Resize(224),
|
10 |
T.CenterCrop(224),
|
11 |
T.ToTensor(),
|
12 |
T.Normalize([0.5], [0.5])
|
13 |
])
|
14 |
|
15 |
+
def load_image(img_path: str) -> torch.Tensor:
|
16 |
"""
|
17 |
+
Load an image and return a tensor that can be used as an input to the model.
|
18 |
"""
|
19 |
+
img = Image.open(img_path).convert("RGB")
|
20 |
+
transformed_img = transform_image(img).unsqueeze(0)
|
21 |
return transformed_img
|
22 |
|
23 |
+
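# Load the DINO ViT-S/16 backbone for feature extraction (the hub entry 'facebookresearch/dino:main' / 'dino_vits16' is the original DINO, not DINOv2)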
|
24 |
+
dinov2_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
|
25 |
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
26 |
+
dinov2_model.to(device)
|
27 |
+
dinov2_model.eval() # Set the model to evaluation mode
|
28 |
|
29 |
+
# Load the pre-trained SVM classifier
|
30 |
clf = joblib.load('svm_model.joblib')
|
31 |
|
32 |
+
# Function to predict the class for a new image
|
33 |
+
def predict(image_path: str):
|
34 |
+
# Load and transform the image
|
|
|
|
|
|
|
35 |
new_image = load_image(image_path).to(device)
|
36 |
+
|
37 |
+
# Extract features with the DINO backbone (gradients are disabled for inference)
|
38 |
with torch.no_grad():
|
39 |
+
features = dinov2_model(new_image)
|
40 |
+
|
41 |
+
# Flatten features to 2D for SVM input
|
42 |
+
embedding = features.cpu().numpy().reshape(1, -1)
|
43 |
+
|
44 |
+
# Predict the class using the SVM classifier
|
45 |
prediction = clf.predict(embedding)
|
46 |
+
|
47 |
return prediction[0]
|
48 |
+
|
49 |
+
# If running as a script
|
50 |
+
if __name__ == "__main__":
|
51 |
+
import sys
|
52 |
+
image_path = sys.argv[1] # Get image path from command line arguments
|
53 |
+
predicted_class = predict(image_path)
|
54 |
+
print("Predicted class:", predicted_class)
|