Spaces:
Sleeping
Sleeping
import cv2 | |
import torch | |
import gradio as gr | |
from torchvision.utils import draw_bounding_boxes | |
# Load the trained detector onto the CPU and switch it to inference mode.
# NOTE(review): torch.load unpickles the entire model object here, so only
# load checkpoints from a trusted source. Newer PyTorch releases default to
# weights_only=True, which rejects fully pickled models -- confirm the torch
# version this app is pinned to.
model_path = "R_CNN.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
model.eval()

# Human-readable class names. process_video indexes this list with the
# integer labels predicted by the model, so the order must match the label
# ids used during training.
classes = [
    'creatures',
    'fish',
    'jellyfish',
    'penguin',
    'puffin',
    'shark',
    'starfish',
    'stingray',
]
def _annotate_frame(frame, threshold):
    """Run the detector on one frame and return the frame with boxes drawn.

    Args:
        frame: Height x width x channels uint8 array as returned by
            ``cv2.VideoCapture.read``.
        threshold: Detections whose score is not above this value are dropped.

    Returns:
        A contiguous height x width x 3 uint8 NumPy array with the kept
        bounding boxes and their class names drawn on it.
    """
    # Strip any extra (e.g. alpha) channel up front so the model and the
    # drawing code both work on the same 3-channel image.
    if frame.shape[2] > 3:
        frame = frame[:, :, :3]

    # NOTE(review): OpenCV decodes frames in BGR order and they are passed to
    # the model in that order. If the model was trained on RGB images the
    # channels should be swapped before inference -- confirm against the
    # training pipeline.
    pixels = torch.as_tensor(frame, dtype=torch.uint8).permute(2, 0, 1)
    batch = (pixels.to(torch.float32) / 255.0).unsqueeze(0)

    with torch.no_grad():
        pred = model(batch)[0]

    # Compute the confidence mask once and reuse it for boxes and labels.
    keep = pred['scores'] > threshold
    labels = [classes[i] for i in pred['labels'][keep].tolist()]

    drawn = draw_bounding_boxes(pixels, pred['boxes'][keep], labels, width=4)
    # Back to height x width x channels; make it contiguous before handing
    # the buffer to OpenCV.
    return drawn.permute(1, 2, 0).contiguous().cpu().numpy()


def process_video(input_video, threshold=0.8):
    """Detect objects in every frame of a video and write an annotated copy.

    Args:
        input_video: Either a path string or a file-like object exposing the
            path through its ``name`` attribute.
        threshold: Confidence threshold for bounding boxes; only detections
            scoring above it are drawn. Defaults to 0.8.

    Returns:
        The path of the annotated video, ``'video_output.avi'``.

    Raises:
        ValueError: If no input was given or the video cannot be opened.
    """
    if input_video is None:
        raise ValueError("No input video was provided.")

    if isinstance(input_video, str):
        # The input is a filename.
        input_video_path = input_video
    else:
        # The input is a file object.
        input_video_path = input_video.name

    output_path = 'video_output.avi'
    cap = cv2.VideoCapture(input_video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report no frame rate (0 or NaN); fall back to a
            # usable value instead of passing an invalid one to the writer.
            fps = 30.0

        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(_annotate_frame(frame, threshold))
    finally:
        # Always release the capture and writer, even if inference fails.
        cap.release()
        if out is not None:
            out.release()

    return output_path
# Build the Gradio UI. Components are created directly from the top-level
# `gr` namespace (there is no 'outputs' submodule).
video_input = gr.Video(label="Input Video")
# process_video returns the path of a video file, so the output component
# must be a Video; an Image component cannot display it.
processed_video = gr.Video(label="Processed Video")

interface = gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=processed_video,
    title="Object Detection in Video",
    description="Detect objects in a video using the trained model.",
)
interface.launch()