pjdevelop commited on
Commit
75ea8f5
1 Parent(s): 629fcd8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +29 -17
model.py CHANGED
@@ -6,37 +6,49 @@ import json
6
 
7
  # Define image transformation
8
  transform_image = T.Compose([
9
- T.Resize(224), # Corrected to 224 to match CenterCrop
10
  T.CenterCrop(224),
11
  T.ToTensor(),
12
  T.Normalize([0.5], [0.5])
13
  ])
14
 
15
- def load_image(img: str) -> torch.Tensor:
16
  """
17
- Load an image and return a tensor that can be used as an input to DINOv2.
18
  """
19
- img = Image.open(img)
20
- transformed_img = transform_image(img)[:3].unsqueeze(0)
21
  return transformed_img
22
 
23
- # Load models for inference
24
- dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
25
  device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
26
- dinov2_vits14.to(device)
27
- dinov2_vits14.eval() # Set the model to evaluation mode
28
 
29
- # Load the classifier
30
  clf = joblib.load('svm_model.joblib')
31
 
32
- # Load the embeddings
33
- with open('all_embeddings.json', 'r') as f:
34
- embeddings = json.load(f)
35
-
36
- # Predict class for a new image
37
- def predict(image_path):
38
  new_image = load_image(image_path).to(device)
 
 
39
  with torch.no_grad():
40
- embedding = dinov2_vits14(new_image).cpu().numpy().reshape(1, -1)
 
 
 
 
 
41
  prediction = clf.predict(embedding)
 
42
  return prediction[0]
 
 
 
 
 
 
 
 
6
 
7
  # Define image transformation
8
  transform_image = T.Compose([
9
+ T.Resize(224),
10
  T.CenterCrop(224),
11
  T.ToTensor(),
12
  T.Normalize([0.5], [0.5])
13
  ])
14
 
15
+ def load_image(img_path: str) -> torch.Tensor:
16
  """
17
+ Load an image and return a tensor that can be used as an input to the model.
18
  """
19
+ img = Image.open(img_path).convert("RGB")
20
+ transformed_img = transform_image(img).unsqueeze(0)
21
  return transformed_img
22
 
23
+ # Load DINOv2 model for feature extraction
24
+ dinov2_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
25
  device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
26
+ dinov2_model.to(device)
27
+ dinov2_model.eval() # Set the model to evaluation mode
28
 
29
+ # Load the pre-trained SVM classifier
30
  clf = joblib.load('svm_model.joblib')
31
 
32
+ # Function to predict the class for a new image
33
+ def predict(image_path: str):
34
+ # Load and transform the image
 
 
 
35
  new_image = load_image(image_path).to(device)
36
+
37
+ # Extract features using DINOv2
38
  with torch.no_grad():
39
+ features = dinov2_model(new_image)
40
+
41
+ # Flatten features to 2D for SVM input
42
+ embedding = features.cpu().numpy().reshape(1, -1)
43
+
44
+ # Predict the class using the SVM classifier
45
  prediction = clf.predict(embedding)
46
+
47
  return prediction[0]
48
+
49
+ # If running as a script
50
+ if __name__ == "__main__":
51
+ import sys
52
+ image_path = sys.argv[1] # Get image path from command line arguments
53
+ predicted_class = predict(image_path)
54
+ print("Predicted class:", predicted_class)