JohnAlexander23 committed
Commit 0ad2d2a (1 parent: 51d22ab)

Create app.py

Files changed (1): app.py (+74, -0)
app.py ADDED
@@ -0,0 +1,74 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoProcessor, AutoModelForObjectDetection
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import random
+
+ # Constants
+ EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg'
+ THRESHOLD = 0.2
+
+ # Load model and processor (cached so they are only loaded once per session)
+ @st.cache_resource
+ def load_model():
+     model_id = 'onnx-community/yolov10m'
+     processor = AutoProcessor.from_pretrained(model_id)
+     model = AutoModelForObjectDetection.from_pretrained(model_id)
+     return processor, model
+
+ processor, model = load_model()
+
+ # Function to detect objects in the image
+ def detect(image):
+     # Preprocess image
+     inputs = processor(images=image, return_tensors="pt")
+
+     # Predict bounding boxes
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Extract bounding boxes and labels
+     target_sizes = torch.tensor([image.size[::-1]])
+     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=THRESHOLD)[0]
+
+     return results
+
+ # Function to render bounding boxes
+ def render_box(image, results):
+     fig = plt.figure(figsize=(10, 10))
+     plt.imshow(image)
+     ax = plt.gca()
+
+     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         if score < THRESHOLD:
+             continue
+
+         color = tuple([random.random() for _ in range(3)])  # Random color for each box
+         xmin, ymin, xmax, ymax = box
+
+         rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor=color, facecolor='none')
+         ax.add_patch(rect)
+         plt.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {score:.2f}", color=color, fontsize=12, bbox=dict(facecolor='white', alpha=0.5))
+
+     plt.axis('off')
+     st.pyplot(fig)
+
+ # Streamlit app
+ st.title("Object Detection with Hugging Face Transformers")
+
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
+
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file)
+     results = detect(image)
+     render_box(image, results)
+ else:
+     if st.button("Try Example Image"):
+         response = requests.get(EXAMPLE_URL)
+         image = Image.open(BytesIO(response.content))
+         results = detect(image)
+         render_box(image, results)
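
A hedged note on this file (an assumption, not part of the commit): 'onnx-community/yolov10m' hosts ONNX weights aimed at Transformers.js (the example image also comes from the transformers.js docs dataset), so AutoModelForObjectDetection.from_pretrained may not be able to load it in Python. A minimal sketch of load_model under that assumption, swapping in 'facebook/detr-resnet-50' as a stand-in PyTorch checkpoint that works with the same preprocessing and post_process_object_detection calls:

    import streamlit as st
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    @st.cache_resource
    def load_model():
        # Stand-in checkpoint (assumption): DETR instead of the original YOLOv10 ONNX export
        model_id = 'facebook/detr-resnet-50'
        processor = AutoImageProcessor.from_pretrained(model_id)
        model = AutoModelForObjectDetection.from_pretrained(model_id)
        return processor, model

The rest of app.py should work unchanged with this loader, since the image processor handles both preprocessing and post-processing and class names come from model.config.id2label. To try the app locally, installing streamlit, torch, transformers, pillow, requests, and matplotlib and then running "streamlit run app.py" should be enough.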