"""Video object detection demo built on OWLv2 and Gradio.

Frames are sampled from a video at ``OUTPUT_FPS``, each sampled frame is run
through ``google/owlv2-base-patch16`` with a free-text query, detections with
confidence >= ``SCORE_THRESHOLD`` are drawn onto the frame, and the highest
drawn confidence per frame is plotted as a one-row heatmap.
"""

import base64
import os
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import Owlv2ForObjectDetection, Owlv2Processor

# Check if CUDA is available, otherwise use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)

SAMPLE_VIDEO_PATH = "Drone Video of African Wildlife Wild Botswan.mp4"
OUTPUT_FPS = 3  # frames sampled per second of video
BATCH_SIZE = 1  # frames per model forward pass
SCORE_THRESHOLD = 0.5  # minimum confidence for a detection to be drawn
LABEL_OFFSET = 30  # pixels between a box's top edge and its label text


def _load_font(size=30):
    """Return Arial at *size*, falling back to PIL's built-in default font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except IOError:
        return ImageFont.load_default()


def _detect_and_annotate(frames, target, font):
    """Run OWLv2 on a batch of PIL frames and draw detections onto them in place.

    Args:
        frames: RGB ``PIL.Image`` objects; they are drawn on in place.
        target: Free-text query describing the object to find.
        font: Font used for the box labels.

    Returns:
        One score per frame: the highest (rounded) confidence that was drawn,
        or 0 when no detection reached ``SCORE_THRESHOLD``.
    """
    texts = [[target]] * len(frames)
    inputs = processor(text=texts, images=frames, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        outputs = model(**inputs)

    # OWLv2 pads each image to a square (bottom/right) before inference, so the
    # predicted boxes are relative to that square. Scaling by the padded side
    # length maps them back onto the original frame's pixel grid; scaling by
    # (height, width) would stretch boxes on non-square video.
    target_sizes = torch.tensor(
        [[max(frame.size), max(frame.size)] for frame in frames], dtype=torch.float32
    )
    # Post-process the whole batch at once: target_sizes needs one row per image.
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)

    max_scores = []
    for frame, result in zip(frames, results):
        draw = ImageDraw.Draw(frame)
        max_score = 0
        for box, score in zip(result["boxes"], result["scores"]):
            if score.item() < SCORE_THRESHOLD:
                continue
            box = [round(coord, 2) for coord in box.tolist()]
            confidence = round(score.item(), 3)
            draw.rectangle(box, outline="red", width=4)
            # Clamp so labels of boxes touching the top edge stay on-screen.
            text_position = (box[0], max(0, box[1] - LABEL_OFFSET))
            draw.text(text_position, f"{target}: {confidence}", fill="white", font=font)
            max_score = max(max_score, confidence)
        max_scores.append(max_score)
    return max_scores


def process_video(video_path, target, progress=gr.Progress()):
    """Sample frames from a video, detect *target* in each, and annotate them.

    Args:
        video_path: Path to the video file (``None`` if nothing was uploaded).
        target: Free-text description of the object to detect.
        progress: Gradio progress tracker wrapped around the frame loop.

    Returns:
        ``(frames, scores, None)`` on success, where ``frames`` are annotated RGB
        numpy arrays and ``scores`` the per-frame max confidence; otherwise
        ``(None, None, error_message)``.
    """
    if video_path is None:
        return None, None, "Error: No video uploaded"
    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"
    if not target or not target.strip():
        return None, None, "Error: No target object specified"

    cap = cv2.VideoCapture(video_path)
    try:  # release the capture even if decoding or inference raises
        if not cap.isOpened():
            return None, None, f"Error: Unable to open video file at {video_path}"

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep FPS as a float: truncating e.g. 29.97 to 29 skews every seek.
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if original_fps <= 0 or frame_count <= 0:
            return None, None, f"Error: Unable to read frame rate or frame count of {video_path}"

        frame_duration = 1 / OUTPUT_FPS
        video_duration = frame_count / original_fps
        font = _load_font()  # load once, not once per frame

        processed_frames = []
        frame_scores = []
        batch = []

        def flush_batch():
            """Run detection on the pending batch and move it into the results."""
            frame_scores.extend(_detect_and_annotate(batch, target, font))
            processed_frames.extend(np.array(frame) for frame in batch)
            batch.clear()

        for timestamp in progress.tqdm(np.arange(0, video_duration, frame_duration)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(timestamp * original_fps))
            ret, img = cap.read()
            if not ret:
                break
            batch.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            if len(batch) >= BATCH_SIZE:
                flush_batch()
        # Leftover frames: a partial final batch, or one cut short by a failed read.
        if batch:
            flush_batch()
    finally:
        cap.release()

    return processed_frames, frame_scores, None


def _render_heatmap_png(frame_scores):
    """Render per-frame scores as a one-row 'hot' heatmap and return PNG bytes."""
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        heat = ax.imshow([frame_scores], cmap='hot', aspect='auto')
        fig.colorbar(heat, ax=ax, label='Confidence')
        ax.set_title('Object Detection Heatmap')
        ax.set_xlabel('Frame')
        ax.set_yticks([])
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)  # free the figure even if rendering fails
    return buf.getvalue()


def create_heatmap(frame_scores):
    """Return the detection heatmap for *frame_scores* as a base64-encoded PNG string."""
    return base64.b64encode(_render_heatmap_png(frame_scores)).decode('utf-8')


def load_sample_frame(video_path):
    """Return the first frame of *video_path* as an RGB array, or None if unreadable."""
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def gradio_app():
    """Build the Gradio Blocks UI: upload, per-frame browsing, heatmap and errors."""
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        output_image = gr.Image(label="Processed Frame")
        heatmap_output = gr.Image(label="Detection Heatmap")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        gr.Image(value=load_sample_frame(SAMPLE_VIDEO_PATH), label="Sample Video Frame")
        use_sample_button = gr.Button("Use Sample Video")

        processed_frames = gr.State([])
        frame_scores = gr.State([])
        result_outputs = [processed_frames, frame_scores, output_image,
                          heatmap_output, error_output, frame_slider]

        # Gradio only reports progress for a gr.Progress() that is a default
        # argument of the event handler itself, so it is declared here.
        def process_and_update(video, target, progress=gr.Progress()):
            frames, scores, error = process_video(video, target, progress)
            if error is None and not frames:
                error = "Error: No frames could be read from the video"
            if error is not None:
                # error_output starts hidden; reveal it so the message is seen.
                return ([], [], None, None,
                        gr.Textbox(value=error, visible=True),
                        gr.Slider(maximum=100, value=0))
            # gr.Image cannot display a bare base64 string, so hand it a PIL image.
            heatmap = Image.open(BytesIO(_render_heatmap_png(scores)))
            return (frames, scores, frames[0], heatmap,
                    gr.Textbox(value="", visible=False),
                    gr.Slider(maximum=len(frames) - 1, value=0))

        def update_frame(frame_index, frames):
            index = int(frame_index)  # slider values may arrive as floats
            if frames and 0 <= index < len(frames):
                return frames[index]
            return None

        def use_sample_video(target, progress=gr.Progress()):
            return process_and_update(SAMPLE_VIDEO_PATH, target, progress)

        video_input.upload(process_and_update, inputs=[video_input, target_input], outputs=result_outputs)
        frame_slider.change(update_frame, inputs=[frame_slider, processed_frames], outputs=[output_image])
        use_sample_button.click(use_sample_video, inputs=[target_input], outputs=result_outputs)

    return app


if __name__ == "__main__":
    app = gradio_app()
    app.launch(share=True)