# Hugging Face Spaces status: Runtime error
import streamlit as st | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import torch | |
import matplotlib.pyplot as plt | |
from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
from transformers.image_utils import ImageFeatureExtractionMixin | |
# The original script called
#     st.set_option('deprecation.showfileUploaderEncoding', False)
# here. That config key no longer exists in current Streamlit releases, and
# setting an unknown option raises StreamlitAPIException at startup, so the
# call is intentionally omitted.

_MODEL_NAME = "google/owlvit-base-patch32"

# Prefer Streamlit's resource cache so the checkpoint is loaded once per
# server process instead of on every script rerun. Older Streamlit versions
# fall back to experimental_singleton, and finally to a no-op decorator
# (which reproduces the load-on-every-rerun behaviour).
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda fn: fn)
)


@_cache_resource
def _load_owlvit(model_name=_MODEL_NAME, device_type="cpu"):
    """Load the OWL-ViT detector and its processor.

    Args:
        model_name: Hugging Face checkpoint id passed to ``from_pretrained``.
        device_type: string accepted by ``torch.device`` (defaults to CPU).

    Returns:
        A ``(model, processor, device)`` tuple. The model is moved to
        ``device`` and put in eval mode.
    """
    dev = torch.device(device_type)
    net = OwlViTForObjectDetection.from_pretrained(model_name).to(dev)
    net.eval()
    proc = OwlViTProcessor.from_pretrained(model_name)
    return net, proc, dev


model, processor, device = _load_owlvit()
st.title('Zero-shot Object Detection')

# Both uploaders accept the same file extensions.
_ACCEPTED_IMAGE_TYPES = ["jpg", "jpeg", "png"]

# Two side-by-side upload widgets:
# - the left column takes the image to search in;
# - the right column takes an example image of the object to look for.
# Each image is previewed as soon as it is uploaded. `image` and
# `query_image` are only bound once the corresponding file has been provided.
col1, col2 = st.columns(2)

with col1:
    uploaded_image = st.file_uploader(
        "Upload input image(image to predict)",
        type=_ACCEPTED_IMAGE_TYPES,
    )
    if uploaded_image is not None:
        image = Image.open(uploaded_image)
        st.image(image, caption='Input Image', use_column_width=True)

with col2:
    uploaded_query = st.file_uploader(
        "Upload query image(image contains object we wanna predict)",
        type=_ACCEPTED_IMAGE_TYPES,
    )
    if uploaded_query is not None:
        query_image = Image.open(uploaded_query)
        st.image(query_image, caption='Query Image', use_column_width=True)
# Detection controls:
# - a score threshold in [0, 1] that is passed to post-processing;
# - the button the prediction step checks before running the model.
threshold_ratio = st.slider(
    'Select threshold ratio:',
    min_value=0.0,
    max_value=1.0,
    step=0.1,
    value=0.6,
)
start_button = st.button('Start prediction')
# Run image-guided detection only when both images are present and the user
# clicked "Start prediction"; otherwise tell the user what is still missing.
if uploaded_image is not None and uploaded_query is not None and start_button:
    # PNG uploads may be RGBA or palette images. The processor and the cv2
    # drawing below expect 3-channel RGB, so normalise both inputs first.
    target_rgb = image.convert("RGB")
    query_rgb = query_image.convert("RGB")

    # Post-processing wants (height, width) per image; PIL's .size is (w, h).
    target_sizes = torch.Tensor([target_rgb.size[::-1]])
    inputs = processor(
        images=target_rgb, query_images=query_rgb, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # Bring the tensors used by post-processing back to the CPU, in case
    # `device` is ever switched to an accelerator.
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs,
        threshold=threshold_ratio,
        nms_threshold=0.3,
        target_sizes=target_sizes,
    )
    boxes, scores = results[0]["boxes"], results[0]["scores"]

    # Draw straight onto an RGB array. np.array() returns a fresh, contiguous,
    # writable copy, which the cv2 drawing functions require. Blue in RGB
    # order matches the colour the original BGR round-trip produced.
    canvas = np.array(target_rgb)
    box_color = (0, 0, 255)
    canvas_height = canvas.shape[0]
    for box, score in zip(boxes, scores):
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        cv2.rectangle(canvas, (x0, y0), (x1, y1), box_color, 5)
        # Label the box with its score: just below the bottom edge, or inside
        # it when that would fall off the image. Uses the real image height
        # instead of a fixed pixel size.
        text_y = y1 - 10 if y1 + 25 > canvas_height else y1 + 25
        cv2.putText(
            canvas,
            "{:.2f}".format(float(score)),
            (x0, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            box_color,
            2,
        )

    st.image(canvas, caption='Predicted Image', use_column_width=True)
    if len(boxes) == 0:
        st.info('No matching objects found above the selected threshold.')
elif uploaded_image is None or uploaded_query is None:
    st.write('Please upload an image and a query image.')
else:
    # Both images are present but the button has not been pressed on this run.
    st.write('Press "Start prediction" to run detection.')