hkhho committed on
Commit 872be7d · 1 Parent(s): faea0f9

Update app.py

Files changed (1)
  1. app.py +37 -34
app.py CHANGED
@@ -1,78 +1,81 @@
  import streamlit as st
  import cv2
- import skimage
  import numpy as np
  from PIL import Image
  import torch
  import matplotlib.pyplot as plt
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
  from transformers.image_utils import ImageFeatureExtractionMixin
- import requests

- # Set up the model
  model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
  processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
  device = torch.device("cpu")
  model = model.to(device)
  model.eval()

- # Set up the page layout
- st.set_page_config(layout="wide")
- title = """<h1 id="title">Zero-shot Object Detection</h1>"""
- st.markdown(title, unsafe_allow_html=True)
- col1, col2 = st.beta_columns(2)

- # Upload the input image
  with col1:
-     st.subheader("Input Image")
-     uploaded_image = st.file_uploader(label="Upload an image", type=["jpg", "jpeg", "png"])

- # Upload the query image
  with col2:
-     st.subheader("Query Image")
-     query_url = "https://assets.gamepur.com/wp-content/uploads/2022/07/29191902/DIGIMON-SURVIVE_20220729122707-1.jpg"
-     query_image = Image.open(requests.get(query_url, stream=True).raw)
-     st.image(query_image, caption="Query Image")

- # Set the threshold ratio
- threshold_ratio = st.slider("Select a threshold ratio:", min_value=0.1, max_value=1.0, value=0.6, step=0.1)

- # Set the text query
- text_query = st.text_input("Enter the class name:")
-
- # Process the input and query images
- if uploaded_image is not None:
-     image = Image.open(uploaded_image)
      target_sizes = torch.Tensor([image.size[::-1]])
-     inputs = processor(images=image, query_images=query_image, text_queries=[text_query], return_tensors="pt").to(device)

-     # Run the model
      with torch.no_grad():
          outputs = model.image_guided_detection(**inputs)

-     # Post-process the results
      img = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
      outputs.logits = outputs.logits.cpu()
      outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
      results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold_ratio, nms_threshold=0.3, target_sizes=target_sizes)
      boxes, scores = results[0]["boxes"], results[0]["scores"]

-     # Draw the predicted bounding boxes
      for box, score in zip(boxes, scores):
          box = [int(i) for i in box.tolist()]
-         cx, cy, x, y = box
-         img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
          if box[3] + 25 > 768:
              y = box[3] - 10
          else:
              y = box[3] + 25

          plt.imshow(img[:,:,::-1])
-         plt.text(cx, cy, text_query + " " + str(round(score.tolist(), 2)), ha="left", va="top", color="red", bbox={"facecolor": "white", "edgecolor": "red", "boxstyle": "square,pad=.3"})

-     # Display the output image
-     output_image = plt.gcf().canvas.tostring_rgb()
-     plt.clf()
      st.image(output_image, caption='Predicted Image', use_column_width=True)

  else:
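Aside on the removed display path: plt.gcf().canvas.tostring_rgb() returns raw RGB bytes, not an image array, so handing it to st.image does not yield a rendered picture. If a Matplotlib-based route were kept, a small helper along these lines would produce a displayable array. This is only a sketch; figure_to_array is a hypothetical name, not something in the app, and it assumes the default Agg canvas:

import numpy as np
import matplotlib.pyplot as plt

def figure_to_array(fig):
    """Render a Matplotlib figure and return it as an H x W x 3 uint8 RGB array."""
    fig.canvas.draw()                                   # rasterize the figure first
    width, height = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    return buf.reshape(height, width, 3)

# e.g. output_image = figure_to_array(plt.gcf()) would give st.image an array it can show

The updated listing that follows sidesteps the issue by drawing boxes with OpenCV and passing the RGB array img[:,:,::-1] to st.image directly.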
 
  import streamlit as st
  import cv2
  import numpy as np
  from PIL import Image
  import torch
  import matplotlib.pyplot as plt
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
  from transformers.image_utils import ImageFeatureExtractionMixin

+ st.set_option('deprecation.showfileUploaderEncoding', False)
+
  model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
  processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
  device = torch.device("cpu")
  model = model.to(device)
  model.eval()

+ st.title('Zero-shot Object Detection')

+ # Input image and query image upload
+ col1, col2 = st.beta_columns(2)
  with col1:
+     uploaded_image = st.file_uploader("Upload input image", type=["jpg", "jpeg", "png"])
+     if uploaded_image is not None:
+         image = Image.open(uploaded_image)
+         st.image(image, caption='Input Image', use_column_width=True)

  with col2:
+     uploaded_query = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
+     if uploaded_query is not None:
+         query_image = Image.open(uploaded_query)
+         st.image(query_image, caption='Query Image', use_column_width=True)

+ # Threshold ratio bar and class name input
+ threshold_ratio = st.slider('Select threshold ratio:', min_value=0.0, max_value=1.0, step=0.1, value=0.6)
+ class_name = st.text_input('Enter class name:', value='agumon')

+ if uploaded_image is not None and uploaded_query is not None:
+     # Process input and query image
+     text_queries = [class_name]
      target_sizes = torch.Tensor([image.size[::-1]])
+     inputs = processor(images=image, query_images=query_image, return_tensors="pt").to(device)

      with torch.no_grad():
          outputs = model.image_guided_detection(**inputs)

      img = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
      outputs.logits = outputs.logits.cpu()
      outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
+
      results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold_ratio, nms_threshold=0.3, target_sizes=target_sizes)
      boxes, scores = results[0]["boxes"], results[0]["scores"]

+     # Draw predicted bounding boxes and text
      for box, score in zip(boxes, scores):
          box = [int(i) for i in box.tolist()]
+         cx, cy, x, y = box
+         img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
          if box[3] + 25 > 768:
              y = box[3] - 10
          else:
              y = box[3] + 25

          plt.imshow(img[:,:,::-1])
+         plt.text(
+             cx,
+             cy,
+             class_name + str(round(score.tolist(), 2)),
+             ha="left",
+             va="top",
+             color="red",
+             bbox={
+                 "facecolor": "white",
+                 "edgecolor": "red",
+                 "boxstyle": "square,pad=.3"
+             })

+     output_image = img[:,:,::-1]
      st.image(output_image, caption='Predicted Image', use_column_width=True)

  else:
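For readers who want to exercise the detection step outside Streamlit, here is a minimal sketch of the same OWL-ViT image-guided detection flow the updated app.py relies on. The paths target.jpg and query.jpg and the 0.6 / 0.3 thresholds are illustrative assumptions, not values taken from the commit:

import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model.eval()

image = Image.open("target.jpg")        # image to search in (hypothetical path)
query_image = Image.open("query.jpg")   # example of the object to find (hypothetical path)

# Preprocess both images; image-guided detection needs no text queries
inputs = processor(images=image, query_images=query_image, return_tensors="pt")

with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# Rescale predicted boxes back to the original image size (height, width)
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
)

for box, score in zip(results[0]["boxes"], results[0]["scores"]):
    print([round(v, 1) for v in box.tolist()], round(score.item(), 2))

The app wires these same calls to the st.file_uploader inputs and the st.slider threshold, then draws the returned boxes with cv2.rectangle before displaying the result.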