Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,78 +1,81 @@
 import streamlit as st
 import cv2
-import skimage
 import numpy as np
 from PIL import Image
 import torch
 import matplotlib.pyplot as plt
 from transformers import OwlViTProcessor, OwlViTForObjectDetection
 from transformers.image_utils import ImageFeatureExtractionMixin
-import requests
 
-
+st.set_option('deprecation.showfileUploaderEncoding', False)
+
 model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
 processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
 device = torch.device("cpu")
 model = model.to(device)
 model.eval()
 
-
-st.set_page_config(layout="wide")
-title = """<h1 id="title">Zero-shot Object Detection</h1>"""
-st.markdown(title, unsafe_allow_html=True)
-col1, col2 = st.beta_columns(2)
+st.title('Zero-shot Object Detection')
 
-#
+# Input image and query image upload
+col1, col2 = st.beta_columns(2)
 with col1:
-    st.
-    uploaded_image
+    uploaded_image = st.file_uploader("Upload input image", type=["jpg", "jpeg", "png"])
+    if uploaded_image is not None:
+        image = Image.open(uploaded_image)
+        st.image(image, caption='Input Image', use_column_width=True)
 
-# Upload the query image
 with col2:
-    st.
-
-
-
+    uploaded_query = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
+    if uploaded_query is not None:
+        query_image = Image.open(uploaded_query)
+        st.image(query_image, caption='Query Image', use_column_width=True)
 
-#
-threshold_ratio = st.slider(
+# Threshold ratio bar and class name input
+threshold_ratio = st.slider('Select threshold ratio:', min_value=0.0, max_value=1.0, step=0.1, value=0.6)
+class_name = st.text_input('Enter class name:', value='agumon')
 
-
-
-
-# Process the input and query images
-if uploaded_image is not None:
-    image = Image.open(uploaded_image)
+if uploaded_image is not None and uploaded_query is not None:
+    # Process input and query image
+    text_queries = [class_name]
     target_sizes = torch.Tensor([image.size[::-1]])
-    inputs = processor(images=image, query_images=query_image,
+    inputs = processor(images=image, query_images=query_image, return_tensors="pt").to(device)
 
-    # Run the model
     with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)
 
-    # Post-process the results
    img = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
+
    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold_ratio, nms_threshold=0.3, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]
 
-    # Draw
+    # Draw predicted bounding boxes and text
    for box, score in zip(boxes, scores):
        box = [int(i) for i in box.tolist()]
-        cx,
-        img = cv2.rectangle(img, box[:2], box[2:], (255,
+        cx, cy, x, y = box
+        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
        if box[3] + 25 > 768:
            y = box[3] - 10
        else:
            y = box[3] + 25
 
        plt.imshow(img[:, :, ::-1])
-        plt.text(
+        plt.text(
+            cx,
+            cy,
+            class_name + str(round(score.tolist(), 2)),
+            ha="left",
+            va="top",
+            color="red",
+            bbox={
+                "facecolor": "white",
+                "edgecolor": "red",
+                "boxstyle": "square,pad=.3"
+            })
 
-
-    output_image = plt.gcf().canvas.tostring_rgb()
-    plt.clf()
+    output_image = img[:, :, ::-1]
    st.image(output_image, caption='Predicted Image', use_column_width=True)
 
 else:
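
A likely factor in the "Runtime error" badge at the top of this page: the committed file still uses two pre-1.0 Streamlit APIs. st.beta_columns was removed from later Streamlit releases in favor of st.columns, and the deprecation.showfileUploaderEncoding config option was dropped, so the st.set_option call for it can itself raise an exception. A minimal sketch of the modern equivalents, assuming the Space runs a current Streamlit release:

# st.set_option('deprecation.showfileUploaderEncoding', False)  # option removed; delete this line on modern Streamlit
col1, col2 = st.columns(2)  # stable replacement for st.beta_columns(2)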
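
One functional quirk survives the commit: the bounding box is drawn into the OpenCV array img, but the score label is drawn with plt.text onto a matplotlib figure that is never rasterized back into an array (the old code attempted that via plt.gcf().canvas.tostring_rgb(), which this commit removes). Since output_image is now the raw array, the labels will not appear in the image Streamlit displays. A sketch that keeps the annotation entirely in OpenCV, assuming img is the BGR array and box the [x1, y1, x2, y2] list from the loop above:

x1, y1, x2, y2 = box
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 5)   # red box (OpenCV uses BGR channel order)
label = class_name + " " + str(round(score.item(), 2))   # e.g. "agumon 0.87"
cv2.putText(img, label, (x1, max(y1 - 10, 20)),          # clamp so the text stays inside the frame
            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
output_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)      # convert once to RGB for st.image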
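
For reference, the detection flow the app wires into Streamlit can be exercised as a standalone script. This is a minimal sketch of the same transformers API calls used in the diff; the file names input.jpg and query.jpg are placeholder paths, and the 0.6 threshold mirrors the slider's default:

import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model.eval()

image = Image.open("input.jpg").convert("RGB")        # target image to search
query_image = Image.open("query.jpg").convert("RGB")  # exemplar of the object to find

# The processor resizes and normalizes both images and returns PyTorch tensors.
inputs = processor(images=image, query_images=query_image, return_tensors="pt")

with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# target_sizes is (height, width) so boxes are rescaled to the original image.
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
)
for box, score in zip(results[0]["boxes"], results[0]["scores"]):
    print([round(v, 1) for v in box.tolist()], round(score.item(), 3))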