# IbrahimHasani's picture
# Update app.py
# 3f1f1a2 verified
import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection, ResNetModel
from torchvision import transforms
from PIL import Image
import cv2
import torch.nn.functional as F
import tempfile
import os
# Load models
# Pick the compute device once; both models below are moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 backbone, switched to eval mode; used to embed the query image
# and the detected crops.
resnet = ResNetModel.from_pretrained("microsoft/resnet-50").to(device)
resnet.eval()

# OWL-ViT processor and detector used for image-guided detection.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
mixin = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model = mixin.to(device)
# Preprocess the image
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor for the ResNet model.

    The image is converted to 3-channel RGB when necessary, resized to
    224x224, converted to a float tensor and normalized per channel with the
    mean/std constants below (the standard ImageNet channel statistics).

    Args:
        image: Input image; expected to be a ``PIL.Image.Image``.

    Returns:
        The transformed tensor with a leading batch dimension of size 1
        (added by ``unsqueeze(0)``).
    """
    # Normalize is configured for exactly three channels, so grayscale ("L"),
    # palette ("P") or alpha ("RGBA") images would fail further down.
    # Converting up front makes every PIL mode acceptable; RGB input is
    # passed through untouched.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)
def extract_embedding(image):
    """Compute the ResNet embedding of ``image`` without tracking gradients.

    The image is run through ``preprocess_image``, moved to the module-level
    ``device`` and fed to the module-level ``resnet`` model.

    Args:
        image: Image accepted by ``preprocess_image``.

    Returns:
        The model output's ``pooler_output`` tensor.
    """
    batch = preprocess_image(image).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return resnet(batch).pooler_output
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embedding tensors.

    The similarity is computed along dimension 1 (the default of
    ``torch.nn.functional.cosine_similarity``), so that dimension is reduced
    away in the result.
    """
    similarity = F.cosine_similarity(embedding1, embedding2, dim=1)
    return similarity
def l2_distance(embedding1, embedding2):
    """Return the L2 (Euclidean) norm of ``embedding1 - embedding2``.

    No dimension is passed to ``torch.norm``, so the norm is taken over all
    elements of the difference and the result is a 0-dimensional tensor.
    """
    difference = embedding1 - embedding2
    return torch.norm(difference, p=2)
def save_array_to_temp_image(arr):
    """Write an image array to a new temporary PNG file and return its path.

    The array's channels are swapped with ``cv2.COLOR_BGR2RGB`` before saving.
    The file is created with ``delete=False``, so it stays on disk until the
    caller removes it.

    Args:
        arr: Image array accepted by ``cv2.cvtColor`` (BGR channel order is
            assumed by the conversion).

    Returns:
        Path of the written ``.png`` file.
    """
    rgb_image = Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))
    # Only the unique path is needed; the handle is closed on leaving the
    # ``with`` block so PIL can write to the path itself.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
        png_path = handle.name
    rgb_image.save(png_path)
    return png_path
def detect_and_crop(target_image, query_image, threshold=0.6, nms_threshold=0.3):
    """Detect regions of ``target_image`` that match ``query_image`` and crop them.

    Runs the module-level OWL-ViT ``model`` in image-guided detection mode,
    post-processes the predictions with ``processor`` and cuts every returned
    box out of a channel-swapped copy of the target image.

    Args:
        target_image: PIL image to search in.
        query_image: Image showing the object to look for.
        threshold: Forwarded as ``threshold`` to
            ``processor.post_process_image_guided_detection``.
        nms_threshold: Forwarded as ``nms_threshold`` to the same call.

    Returns:
        A list of NumPy arrays, one per non-empty crop; an empty list when
        nothing was detected.
    """
    # PIL reports (width, height); reversing yields (height, width).
    target_sizes = torch.Tensor([target_image.size[::-1]])
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # Swap the first and third channel of the target pixels. Assuming the PIL
    # image is RGB, the crops end up in BGR order, which is what
    # ``save_array_to_temp_image`` expects as input.
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]

    # Post-processing is done on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs,
        threshold=threshold,
        nms_threshold=nms_threshold,
        target_sizes=target_sizes,
    )

    crops = []
    for box in results[0]["boxes"]:
        x1, y1, x2, y2 = (int(value) for value in box.tolist())
        # Predicted boxes may extend past the image border. A negative index
        # would wrap around in NumPy slicing and yield an empty or wrong
        # region, so clamp every coordinate to the image bounds first.
        x1 = max(0, min(x1, width))
        x2 = max(0, min(x2, width))
        y1 = max(0, min(y1, height))
        y2 = max(0, min(y2, height))
        crop = img[y1:y2, x1:x2]
        if crop.size != 0:
            crops.append(crop)
    return crops
def _bgr_array_to_pil(arr):
    """Convert a BGR image array to an RGB PIL image entirely in memory."""
    return Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))


def process_video(video_path, query_image, skipframes=0):
    """Compare every object detected in a video against ``query_image``.

    Every ``skipframes + 1``-th frame is passed to ``detect_and_crop``; each
    returned crop is embedded with ``extract_embedding`` and compared with the
    embedding of the query image.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        query_image: Image of the object to look for.
        skipframes: Number of frames to skip between two processed frames.

    Returns:
        A list with one ``{'l2_dist': float, 'cos': float}`` dict per detected
        crop, or ``None`` when the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    all_results = []
    try:
        # The query embedding does not depend on the frame or the crop, so it
        # is computed once instead of once per detection.
        query_embedding = extract_embedding(query_image)
        stride = skipframes + 1
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % stride == 0:
                # Frames and crops are converted in memory; writing them to
                # never-deleted temporary PNG files leaked disk space.
                for crop in detect_and_crop(_bgr_array_to_pil(frame), query_image):
                    crop_embedding = extract_embedding(_bgr_array_to_pil(crop))
                    all_results.append({
                        'l2_dist': l2_distance(query_embedding, crop_embedding).item(),
                        'cos': cosine_similarity(query_embedding, crop_embedding).item(),
                    })
            frame_count += 1
    finally:
        # Release the capture even if detection or embedding raises.
        cap.release()
    return all_results
def process_videos_and_compare(image, video, skipframes=5, threshold=0.47):
    """Summarize how closely the objects detected in ``video`` match ``image``.

    Args:
        image: Query image of the object to look for.
        video: Path of the video, forwarded to ``process_video``.
        skipframes: Number of frames to skip between processed frames.
        threshold: Minimum average cosine similarity for ``is_present``.

    Returns:
        A dict with the average and median L2 distance and cosine similarity,
        the corresponding cosine distances (``1 - similarity``) and the
        ``is_present`` flag. When ``process_video`` yields no results, the
        distances are ``inf``, the similarities ``0`` and ``is_present`` is
        ``False``.
    """
    def median(values):
        # ``values`` must already be sorted.
        half, is_odd = divmod(len(values), 2)
        if is_odd:
            return values[half]
        return (values[half - 1] + values[half]) / 2

    results = process_video(video, image, skipframes)

    # No detections (or the video could not be opened): report "not present".
    if not results:
        return {
            "avg_l2_dist": float('inf'),
            "avg_cos": 0,
            "median_l2_dist": float('inf'),
            "median_cos": 0,
            "avg_cos_dist": float('inf'),
            "median_cos_dist": float('inf'),
            "is_present": False
        }

    distances = [entry['l2_dist'] for entry in results]
    similarities = [entry['cos'] for entry in results]

    mean_distance = sum(distances) / len(distances)
    mean_similarity = sum(similarities) / len(similarities)
    mid_distance = median(sorted(distances))
    mid_similarity = median(sorted(similarities))

    return {
        "avg_l2_dist": mean_distance,
        "avg_cos": mean_similarity,
        "median_l2_dist": mid_distance,
        "median_cos": mid_similarity,
        "avg_cos_dist": 1 - mean_similarity,
        "median_cos_dist": 1 - mid_similarity,
        "is_present": mean_similarity >= threshold
    }
def interface(video, image, skipframes, threshold):
    """Gradio callback: forward the UI inputs to ``process_videos_and_compare``.

    The UI supplies the video first and the image second, while the comparison
    function takes the image first, so the arguments are passed by name.
    """
    return process_videos_and_compare(
        image=image,
        video=video,
        skipframes=skipframes,
        threshold=threshold,
    )
# Gradio UI definition. The order of ``inputs`` matches the positional
# parameters of ``interface``: video, image, skipframes, threshold.
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Video(label="Upload a Video"),
        # The query image is handed to the callback as a PIL image.
        gr.Image(type="pil", label="Upload a Query Image"),
        # Integer steps from 0 to 10, default 5.
        gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Skip Frames"),
        # Cosine-similarity cut-off in [0, 1], default 0.47.
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.47, label="Threshold")
    ],
    outputs=[
        # The result dict returned by the callback is shown as JSON.
        gr.JSON(label="Result")
    ],
    title="Object Detection in Video",
    description="""
**Instructions:**
1. **Upload a Video**: Select a video file to upload.
2. **Upload a Query Image**: Select an image file that contains the object you want to detect in the video.
3. **Set Skip Frames**: Adjust the slider to set the number of frames to skip between each processing.
4. **Set Threshold**: Adjust the slider to set the threshold for cosine similarity to determine if the object is present in the video.
5. **View Results**: The result will show the average and median distances and similarities, and whether the object is present in the video based on the threshold.
"""
)
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()