File size: 2,902 Bytes
c4412fa
d152885
 
8af6c47
d152885
 
faea0f9
8af6c47
b200e0c
9fe8876
872be7d
 
faea0f9
 
 
 
 
 
872be7d
faea0f9
872be7d
b14627c
faea0f9
9d9a64e
872be7d
 
 
9d9a64e
 
faea0f9
9d9a64e
872be7d
 
 
9d9a64e
 
faea0f9
872be7d
 
ec69de3
d295684
 
 
872be7d
faea0f9
872be7d
faea0f9
 
 
 
 
 
 
872be7d
faea0f9
 
 
872be7d
faea0f9
 
872be7d
 
faea0f9
 
 
 
b9f68bd
19c3ae0
e09b3fa
b02a893
 
8af6c47
d152885
faea0f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin


# Turn off Streamlit's 'deprecation.showfileUploaderEncoding' notice.
st.set_option('deprecation.showfileUploaderEncoding', False)

# Load the OWL-ViT processor and detection model and keep the model on the CPU.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so both objects are re-created each time; consider wrapping the
# loading in st.cache_resource if the installed Streamlit version provides it.
device = torch.device("cpu")
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
model.eval()  # inference mode (e.g. disables dropout)

st.title('Zero-shot Object Detection')

# Input image and query image upload, shown side by side.
# Both images are normalised to 3-channel RGB on load: uploaded PNGs can be
# RGBA, palette ("P") or grayscale ("L"), and cv2.cvtColor(..., COLOR_BGR2RGB)
# used later rejects single-channel (grayscale) arrays.
col1, col2 = st.columns(2)
with col1:
    uploaded_image = st.file_uploader("Upload input image(image to predict)", type=["jpg", "jpeg", "png"])
    if uploaded_image is not None:
        image = Image.open(uploaded_image).convert("RGB")
        st.image(image, caption='Input Image', use_column_width=True)
with col2:
    uploaded_query = st.file_uploader("Upload query image(image contains object we wanna predict)", type=["jpg", "jpeg", "png"])
    if uploaded_query is not None:
        query_image = Image.open(uploaded_query).convert("RGB")
        st.image(query_image, caption='Query Image', use_column_width=True)

# Detection score threshold chosen by the user.
threshold_ratio = st.slider('Select threshold ratio:', min_value=0.0, max_value=1.0, step=0.1, value=0.6)

# Box colour in RGB order (blue) and stroke width in pixels for drawn detections.
BOX_COLOR_RGB = (0, 0, 255)
BOX_THICKNESS = 5
# Overlap threshold passed to post-processing for non-maximum suppression.
NMS_THRESHOLD = 0.3

start_button = st.button('Start prediction')
images_ready = uploaded_image is not None and uploaded_query is not None
if images_ready and start_button:
    # Normalise to RGB so grayscale / RGBA / palette uploads are handled here too.
    input_rgb = image.convert("RGB")
    query_rgb = query_image.convert("RGB")

    # PIL's size is (width, height); post-processing is given (height, width).
    target_sizes = torch.Tensor([input_rgb.size[::-1]])
    inputs = processor(images=input_rgb, query_images=query_rgb, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()

    results = processor.post_process_image_guided_detection(
        outputs=outputs,
        threshold=threshold_ratio,
        nms_threshold=NMS_THRESHOLD,
        target_sizes=target_sizes,
    )
    boxes = results[0]["boxes"]

    # Draw directly on an RGB copy of the input. The previous code swapped to
    # BGR, drew (255, 0, 0) and swapped back, which displays the same blue box.
    # Drawing no longer touches matplotlib's global figure, which used to
    # accumulate one never-shown image per box on every run.
    output_image = np.array(input_rgb)
    for box in boxes:
        # Each box is a corner pair: (x0, y0) top-left, (x1, y1) bottom-right.
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        cv2.rectangle(output_image, (x0, y0), (x1, y1), BOX_COLOR_RGB, BOX_THICKNESS)

    st.image(output_image, caption='Predicted Image', use_column_width=True)
elif images_ready:
    # Both images are present; the user just has not started a run yet.
    st.write("Click 'Start prediction' to run detection.")
else:
    st.write('Please upload an image and a query image.')