IbrahimHasani's picture
Update app.py
3f492a7 verified
raw
history blame
No virus
8.23 kB
import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from torchvision import transforms
from PIL import Image, ImageDraw
import cv2
import torch.nn.functional as F
import tempfile
import os
from SuperGluePretrainedNetwork.models.matching import Matching
from SuperGluePretrainedNetwork.models.utils import read_image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Run on the GPU when CUDA is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# OWL-ViT checkpoint used for image-guided detection, and its processor.
_OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"
mixin = OwlViTForObjectDetection.from_pretrained(_OWLVIT_CHECKPOINT)
processor = OwlViTProcessor.from_pretrained(_OWLVIT_CHECKPOINT)
model = mixin.to(device)

# SuperPoint keypoint extraction + SuperGlue matching ("outdoor" weights).
_MATCHING_CONFIG = {
    'superpoint': {
        'nms_radius': 4,
        'keypoint_threshold': 0.005,
        'max_keypoints': 1024,
    },
    'superglue': {
        'weights': 'outdoor',
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
    },
}
matching = Matching(_MATCHING_CONFIG)
matching = matching.eval().to(device)
# Utility functions
def preprocess_image(image):
    """Turn a PIL image into a normalised 1x3x224x224-style batch tensor.

    The image is resized to 224x224, converted to a tensor and normalised
    per channel with the standard ImageNet mean/std values; a leading batch
    dimension of size 1 is added before returning.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)
def save_array_to_temp_image(arr):
    """Write ``arr`` to a fresh temporary ``.png`` file and return its path.

    Channels are passed through ``cv2.COLOR_BGR2RGB`` (i.e. the first and
    third channels are swapped) before saving. The file is created with
    ``delete=False`` and is NOT removed here - the caller owns cleanup.
    """
    picture = Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))
    # Only reserve a unique name; the handle is closed before PIL writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
        path = handle.name
    picture.save(path)
    return path
def stitch_images(images):
    """Stack PIL ``images`` vertically (top to bottom, left-aligned).

    The canvas is as wide as the widest image and as tall as all images
    combined; any area not covered by an image keeps ``Image.new``'s
    default (black) fill. An empty input yields a 100x100 gray placeholder.
    """
    if not images:
        return Image.new('RGB', (100, 100), color='gray')

    canvas_width = max(im.width for im in images)
    canvas_height = sum(im.height for im in images)
    canvas = Image.new('RGB', (canvas_width, canvas_height))

    top = 0
    for im in images:
        canvas.paste(im, (0, top))
        top = top + im.height
    return canvas
def unified_matching_plot2(image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, path=None, show_keypoints=False, fast_viz=False, opencv_display=False, opencv_title='matches', small_text=None):
    """Render two images side by side with their matches and return the plot as an RGB array.

    ``image0`` is drawn on the left and ``image1`` on the right; the shorter
    image is zero-padded at the bottom so both share a height. Each matched
    pair ``mkpts0[i] <-> mkpts1[i]`` is joined by a red line, matched points
    are marked blue (left) and green (right), and the ``text`` lines form
    the figure title.

    Args:
        image0, image1: images as arrays (or array-likes) with the same
            number of dimensions, e.g. both 2-D grayscale.
        kpts0, kpts1: all keypoints as rows of (x, y); arrays or torch
            tensors. Drawn only when ``show_keypoints`` is true.
        mkpts0, mkpts1: row-aligned matched keypoints as rows of (x, y);
            arrays or torch tensors.
        color: unused (match lines are red); kept for signature compatibility.
        text: iterable of title lines.
        show_keypoints: also scatter every keypoint in red.
        path, fast_viz, opencv_display, opencv_title, small_text: unused;
            kept for signature compatibility.

    Returns:
        uint8 RGB array (H, W, 3) of the rendered figure.
    """
    from io import BytesIO  # previously referenced without being imported

    def _as_array(values):
        # Accept torch tensors (possibly on GPU or requiring grad) and arrays.
        if hasattr(values, 'detach'):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    image0, image1 = np.asarray(image0), np.asarray(image1)
    kpts0, kpts1 = _as_array(kpts0), _as_array(kpts1)
    mkpts0, mkpts1 = _as_array(mkpts0), _as_array(mkpts1)

    # np.hstack requires equal heights: pad the shorter image at the bottom.
    height = max(image0.shape[0], image1.shape[0])

    def _pad_to_height(img):
        padding = [(0, height - img.shape[0])] + [(0, 0)] * (img.ndim - 1)
        return np.pad(img, padding)

    canvas = np.hstack([_pad_to_height(image0), _pad_to_height(image1)])
    offset = image0.shape[1]  # x-shift for points that live in image1

    # A single figure, always closed; no plt.show(), which would block (or
    # clear the figure) on interactive backends before it is saved.
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        ax.plot([mkpts0[:, 0], mkpts1[:, 0] + offset], [mkpts0[:, 1], mkpts1[:, 1]], 'r', lw=0.5)
        ax.scatter(mkpts0[:, 0], mkpts0[:, 1], s=2, marker='o', color='b')
        ax.scatter(mkpts1[:, 0] + offset, mkpts1[:, 1], s=2, marker='o', color='g')
        if show_keypoints:
            ax.scatter(kpts0[:, 0], kpts0[:, 1], color='r', s=1)
            ax.scatter(kpts1[:, 0] + offset, kpts1[:, 1], color='r', s=1)
        # cmap only affects 2-D (grayscale) input; RGB input ignores it.
        ax.imshow(canvas, cmap='gray', aspect='auto')
        fig.suptitle('\n'.join(text), fontsize=20, fontweight='bold')
        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format='png')
            png_bytes = buf.getvalue()
    finally:
        plt.close(fig)

    encoded = np.frombuffer(png_bytes, dtype=np.uint8)
    img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Main functions
def detect_and_crop(target_image, query_image, threshold=0.5, nms_threshold=0.3):
    """Find regions of ``target_image`` resembling ``query_image`` with OWL-ViT.

    Runs image-guided detection and crops every detected box out of the
    target image.

    Args:
        target_image: PIL image to search in. It is not modified.
        query_image: PIL image of the object to look for.
        threshold: score threshold forwarded to
            ``processor.post_process_image_guided_detection``.
        nms_threshold: NMS threshold forwarded to the same post-processing.

    Returns:
        ``(crops, annotated)``: ``crops`` is a list of non-empty numpy crops
        taken from ``cv2.cvtColor(np.array(target_image), COLOR_BGR2RGB)``
        (channels 0 and 2 swapped relative to the PIL data), and
        ``annotated`` is a copy of ``target_image`` with every detected box
        outlined in red. ``([], None)`` when nothing is detected.
    """
    target_sizes = torch.Tensor([target_image.size[::-1]])  # (height, width)
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_BGR2RGB)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
    boxes = results[0]["boxes"]
    if len(boxes) == 0:
        return [], None

    img_height, img_width = img.shape[:2]
    crops = []
    for box in boxes:
        x1, y1, x2, y2 = [int(v) for v in box.tolist()]
        # Boxes may extend past the image border; a negative start index
        # would wrap around in numpy slicing, so clamp to the image first.
        x1, x2 = max(0, x1), min(img_width, x2)
        y1, y2 = max(0, y1), min(img_height, y2)
        crop = img[y1:y2, x1:x2]
        if crop.size != 0:
            crops.append(crop)

    # Draw on a copy so the caller's image is left untouched and can be
    # reused (e.g. for another detection pass) without the red boxes.
    annotated = target_image.copy()
    draw = ImageDraw.Draw(annotated)
    for box in boxes:
        draw.rectangle(box.tolist(), outline="red", width=3)
    return crops, annotated
def _count_homography_inliers(points0, points1, reproj_threshold=5.0):
    """Count point pairs that are RANSAC inliers of a homography points0 -> points1.

    Returns 0 when there are fewer than the 4 correspondences a homography
    needs, when OpenCV raises, or when it returns no inlier mask.
    """
    if len(points0) < 4:
        return 0
    try:
        _, mask = cv2.findHomography(points0, points1, cv2.RANSAC, reproj_threshold)
    except cv2.error:
        return 0
    return 0 if mask is None else int(np.count_nonzero(mask))


def image_matching_no_pyramid(query_img, target_img, visualize=True):
    """Match ``query_img`` against ``target_img`` with SuperPoint + SuperGlue.

    Both images are written to temporary PNGs via ``save_array_to_temp_image``
    and reloaded with SuperGlue's ``read_image`` (resize argument ``[1280]``;
    presumably a max-side resize - confirm against SuperGlue's utils). The
    temporary files are deleted after loading.

    Args:
        query_img, target_img: PIL images or arrays accepted by ``np.array``.
        visualize: when true, also render the matches with
            ``unified_matching_plot2``.

    Returns:
        None if either image could not be read; otherwise a dict of
        single-element lists: ``valid`` (keypoints of the query with a
        SuperGlue match), ``inliers`` (RANSAC homography inliers among the
        matches, 0 when no homography can be estimated) and
        ``visualized_image`` (RGB array, or None when ``visualize`` is false).
    """
    temp_query = save_array_to_temp_image(np.array(query_img))
    temp_target = save_array_to_temp_image(np.array(target_img))
    try:
        image1, inp1, _scales1 = read_image(temp_target, device, [640 * 2], 0, True)
        image0, inp0, _scales0 = read_image(temp_query, device, [640 * 2], 0, True)
    finally:
        # Only the arrays/tensors returned by read_image are used below, so
        # the delete=False temp files can go now instead of piling up.
        for temp_path in (temp_query, temp_target):
            try:
                os.remove(temp_path)
            except OSError:
                pass
    if image0 is None or image1 is None:
        return None

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        pred = matching({'image0': inp0, 'image1': inp1})
    pred = {k: v[0] for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']

    valid = matches > -1  # -1 marks query keypoints without a match
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]
    color = cm.jet(mconf.detach().cpu().numpy())[:len(mkpts0)]
    valid_count = int(valid.sum())

    mkpts0_np = mkpts0.detach().cpu().numpy()
    mkpts1_np = mkpts1.detach().cpu().numpy()
    num_inliers = _count_homography_inliers(mkpts0_np, mkpts1_np)

    visualized_img = None
    if visualize:
        # Keyword arguments: the former positional call shifted values into
        # the wrong parameters (e.g. path=True).
        visualized_img = unified_matching_plot2(
            image0, image1,
            kpts0.detach().cpu().numpy(), kpts1.detach().cpu().numpy(),
            mkpts0_np, mkpts1_np, color, text=['Matches'])

    return {
        'valid': [valid_count],
        'inliers': [num_inliers],
        'visualized_image': [visualized_img],
    }
def check_object_in_image(query_image, target_image, threshold=50, scale_factor=(0.33, 0.66, 1)):
    """Decide whether the query object appears in the target image.

    OWL-ViT (``detect_and_crop``) proposes candidate boxes in
    ``target_image``. Each crop is matched against ``query_image`` with
    ``image_matching_no_pyramid``. The object is reported present when any
    crop reaches ``threshold`` homography inliers.

    Args:
        query_image: PIL image of the object to look for.
        target_image: PIL image to search in.
        threshold: minimum inlier count for a crop to count as a match.
        scale_factor: currently unused; kept for interface compatibility.

    Returns:
        dict with ``is_present`` (bool), ``images`` (one-element list holding
        the vertically stitched match visualizations), ``object detection
        inliers`` (list of int, one per matched crop) and ``bbox_image`` (as
        returned by ``detect_and_crop``; None when nothing was detected).
    """
    cropped_images, bbox_image = detect_and_crop(target_image, query_image)

    visuals = []
    inlier_counts = []
    for crop in cropped_images:
        # Crops are in swapped (BGR) channel order; swap back in memory.
        # The result is pixel-identical to writing them to a temp PNG and
        # reopening it, but leaves no undeleted temp files behind.
        crop_rgb = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        result = image_matching_no_pyramid(query_image, crop_rgb, visualize=True)
        if not result:
            continue
        visuals.extend(Image.fromarray(img) for img in result['visualized_image'] if img is not None)
        # A missing inlier count is treated as "no inliers".
        inlier_counts.extend(0 if value is None else int(value) for value in result['inliers'])

    is_present = any(count >= threshold for count in inlier_counts)
    return {
        'is_present': is_present,
        'images': [stitch_images(visuals)],
        'object detection inliers': inlier_counts,
        'bbox_image': bbox_image,
    }
def interface(poster_source, media_source, threshold, scale_factor):
    """Gradio callback: report whether the poster appears in the media image.

    Args:
        poster_source: PIL query image (the poster).
        media_source: PIL target image (the media).
        threshold: minimum homography inlier count for a positive result.
        scale_factor: forwarded to ``check_object_in_image``.

    Returns:
        The result dict produced by ``check_object_in_image``.
    """
    # One pass is enough. Re-running check_object_in_image with identical
    # arguments would only repeat the same detection and matching work.
    return check_object_in_image(poster_source, media_source, threshold, scale_factor)
# Markdown instructions shown under the app title.
_DESCRIPTION = """
**Instructions:**
1. **Upload a Query Image (Poster)**: Select an image file that contains the object you want to detect.
2. **Upload a Target Image (Media)**: Select an image file where you want to detect the object.
3. **Set Threshold**: Adjust the slider to set the threshold for object detection.
4. **Set Scale Factors**: Select the scale factors for image pyramid.
5. **View Results**: The result will show whether the object is present in the image along with additional details.
"""

# Widgets in the order interface() receives them:
# poster_source, media_source, threshold, scale_factor.
# NOTE: scale_factor is passed to check_object_in_image, which ignores it.
_INPUTS = [
    gr.Image(type="pil", label="Upload a Query Image (Poster)"),
    gr.Image(type="pil", label="Upload a Target Image (Media)"),
    gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Threshold"),
    gr.CheckboxGroup(choices=[0.33, 0.66, 1.0], value=[0.33, 0.66, 1.0], label="Scale Factors"),
]

iface = gr.Interface(
    fn=interface,
    inputs=_INPUTS,
    outputs=[gr.JSON(label="Result")],
    title="Object Detection in Image",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    iface.launch()