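"""Streamlit demo: one-shot, image-guided object detection with OWL-ViT.

Upload an input image and a query image; the app highlights regions of the
input image that match the object shown in the query image.
"""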
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Load the OWL-ViT model and processor once; st.cache_resource keeps them
# across Streamlit reruns so the weights are not reloaded on every interaction.
@st.cache_resource
def load_model():
    processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
    model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
    return processor, model.eval()

processor, model = load_model()
device = torch.device("cpu")
model.to(device)
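# Image-guided detection is one-shot: OWL-ViT embeds the query image and
# matches it against the input image's patches, so no text labels are needed.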
st.title('Zero-shot Object Detection')
# Input image and query image upload
col1, col2 = st.columns(2)
with col1:
    uploaded_image = st.file_uploader("Upload input image (the image to run detection on)", type=["jpg", "jpeg", "png"])
    if uploaded_image is not None:
        image = Image.open(uploaded_image).convert("RGB")
        st.image(image, caption='Input Image', use_column_width=True)
with col2:
    uploaded_query = st.file_uploader("Upload query image (an image of the object to detect)", type=["jpg", "jpeg", "png"])
    if uploaded_query is not None:
        query_image = Image.open(uploaded_query).convert("RGB")
        st.image(query_image, caption='Query Image', use_column_width=True)
# Threshold slider: minimum box score for a detection to be kept
threshold_ratio = st.slider('Select threshold ratio:', min_value=0.0, max_value=1.0, step=0.1, value=0.6)
start_button = st.button('Start prediction')
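# Run detection only once both images are uploaded and the button is pressed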
if uploaded_image is not None and uploaded_query is not None and start_button:
    # Run image-guided detection: the query image is embedded and matched
    # against patch embeddings of the input image.
    target_sizes = torch.Tensor([image.size[::-1]])  # (height, width) of the input image
    inputs = processor(images=image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)
    # Convert the PIL image (RGB) to a BGR array for drawing with OpenCV
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold_ratio, nms_threshold=0.3, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]
    # Draw predicted bounding boxes and scores
    for box, score in zip(boxes, scores):
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]  # post-processed boxes are (x1, y1, x2, y2) corners
        img = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 5)
        # Place the score label below the box, or just above its bottom edge
        # when the box sits too close to the bottom of the image
        if y2 + 25 > img.shape[0]:
            text_y = y2 - 10
        else:
            text_y = y2 + 25
        cv2.putText(img, f"{score.item():.2f}", (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), 2)
    # Convert BGR back to RGB for display in Streamlit
    output_image = img[:, :, ::-1]
    st.image(output_image, caption='Predicted Image', use_column_width=True)
else:
    st.write('Please upload an input image and a query image, then press "Start prediction".')
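
# To run locally (assumed dependencies, not pinned here):
#   pip install streamlit torch transformers opencv-python-headless pillow
#   streamlit run app.py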