pjdevelop committed on
Commit
ed36966
1 Parent(s): 19ddbe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -15
app.py CHANGED
@@ -1,26 +1,61 @@
1
-
2
- import gradio as gr
3
  import torch
4
- import numpy as np
5
  from PIL import Image
6
- from torchvision import transforms as T
7
  import joblib
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Load models
10
- dinov2_vits14 = torch.load('dinov2_vits14.pth', map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  clf = joblib.load('svm_model.joblib')
12
 
13
- # Transform for input image
14
- transform_image = T.Compose([T.ToTensor(), T.Resize(244), T.CenterCrop(224), T.Normalize([0.5], [0.5])])
 
15
 
16
- def predict(image):
17
- image = Image.fromarray(image)
18
- transformed_img = transform_image(image)[:3].unsqueeze(0)
19
  with torch.no_grad():
20
- embedding = dinov2_vits14(transformed_img)
21
- prediction = clf.predict(np.array(embedding[0].cpu()).reshape(1, -1))
22
  return prediction[0]
23
 
24
- iface = gr.Interface(fn=predict, inputs="image", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  iface.launch()
26
-
 
 
 
1
  import torch
2
+ import torchvision.transforms as T
3
  from PIL import Image
 
4
  import joblib
5
+ import json
6
+ import cv2
7
+ import gradio as gr
8
+
9
# Preprocessing applied to every image before it is handed to DINOv2:
# convert to a tensor, resize, center-crop to 224, then normalize with
# mean 0.5 / std 0.5.
# NOTE(review): Resize(244) may have been meant as 256 -- confirm against the
# preprocessing used when the SVM was trained before changing it.
_preprocess_steps = [
    T.ToTensor(),
    T.Resize(244),
    T.CenterCrop(224),
    T.Normalize([0.5], [0.5]),
]
transform_image = T.Compose(_preprocess_steps)
16
 
17
def load_image(img: str) -> torch.Tensor:
    """Load an image file and return a tensor usable as DINOv2 input.

    Args:
        img: Path to an image file that PIL can open.

    Returns:
        The output of ``transform_image`` for the RGB version of the image,
        with a leading batch dimension added via ``unsqueeze(0)``.
    """
    # Open inside a context manager so the underlying file handle is closed
    # as soon as the pixel data has been read.
    with Image.open(img) as image:
        # Force three channels up front. Slicing the tensor with [:3] only
        # drops an alpha channel; it cannot repair grayscale/palette images
        # (one channel) or CMYK images (wrong channels).
        rgb_image = image.convert("RGB")
    return transform_image(rgb_image).unsqueeze(0)
24
+
25
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the `dinov2_vits14` entry point from the facebookresearch/dinov2 hub
# repository, move it to the chosen device and switch it to evaluation mode.
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
dinov2_vits14.to(device)
dinov2_vits14.eval()

# Classifier that is applied to the DINOv2 embeddings.
clf = joblib.load('svm_model.joblib')

# Embeddings stored alongside the app.
# NOTE(review): nothing in the visible code reads `embeddings`; confirm it is
# still needed before paying the load cost at startup.
with open('all_embeddings.json', 'r') as f:
    embeddings = json.load(f)
37
 
38
+ # Predict class for a new image
39
def predict_image_class(image_path):
    """Return the classifier's prediction for the image at ``image_path``.

    The image is loaded with ``load_image``, embedded with the DINOv2 model
    with gradients disabled, flattened to a single row, and passed to
    ``clf.predict``; the first element of the prediction is returned.
    """
    batch = load_image(image_path).to(device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        features = dinov2_vits14(batch)
    feature_row = features.cpu().numpy().reshape(1, -1)
    return clf.predict(feature_row)[0]
45
 
46
+ # Gradio interface
47
def classify_image(image):
    """Gradio callback: classify an uploaded image and format the result.

    Args:
        image: Path of the uploaded image, or ``None`` when the form is
            submitted without an image.

    Returns:
        A human-readable string with the predicted class, or a prompt asking
        for an image when none was provided.
    """
    # Gradio passes None when no image was uploaded; answer with a message
    # instead of letting the image loader fail on a missing path.
    if image is None:
        return "Please upload an image to classify."
    predicted_class = predict_image_class(image)
    return f"Predicted class: {predicted_class}"
50
+
51
# Web UI: a single image upload (handed to the callback as a file path) and a
# text output showing the result of `classify_image`.
image_input = gr.Image(type="filepath")
iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs="text",
    title="Currency Classifier",
    description="Upload an image of currency to classify.",
)

# Start serving the interface.
iface.launch()