Spaces:
Sleeping
Sleeping
import cv2 | |
import torch | |
import gradio as gr | |
from torchvision.utils import draw_bounding_boxes | |
# Load the trained detector once at import time so every request reuses it.
# NOTE(review): "/content/" is a Colab-style path — confirm the checkpoint
# actually lives there in the environment this app is deployed to.
model_path = "/content/R_CNN.pth"
# Choose the device before loading so the checkpoint's tensors can be mapped
# onto it; a checkpoint saved from a GPU session otherwise fails to load on a
# CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The file holds a whole pickled model (it is used directly as a module below),
# not a state_dict, so full unpickling is required; newer torch releases default
# ``weights_only`` to True and would refuse it. Only load trusted checkpoints:
# unpickling can execute arbitrary code.
model = torch.load(model_path, map_location=device, weights_only=False)
model = model.to(device)
model.eval()  # inference mode (e.g. disables dropout / batch-norm updates)
# Class names, looked up by the integer label the detector predicts for a box.
# NOTE(review): if this is a torchvision detection model, label 0 is reserved
# for background, so "creatures" at index 0 is presumably an unused
# placeholder (dataset super-category) — confirm against the training label map.
classes = """
    creatures fish jellyfish penguin puffin shark starfish stingray
""".split()
# Define function for processing video | |
def process_video(input_video): | |
output_path = 'video_output.avi' | |
cap = cv2.VideoCapture(input_video.name) | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
fourcc = cv2.VideoWriter_fourcc(*'XVID') | |
out = cv2.VideoWriter(output_path, fourcc, fps, (int(cap.get(3)), int(cap.get(4)))) | |
threshold = 0.8 # Confidence threshold for bounding boxes | |
while True: | |
ret, frame = cap.read() | |
if not ret: | |
break | |
img = torch.tensor(frame.transpose(2, 0, 1) / 255.0, dtype=torch.float32).to(device) | |
img = img.unsqueeze(0) | |
with torch.no_grad(): | |
prediction = model(img) | |
pred = prediction[0] | |
img_int = torch.tensor(frame, dtype=torch.uint8) | |
if img_int.shape[2] > 3: | |
img_int = img_int[:, :, :3] | |
drawn_frame = draw_bounding_boxes( | |
img_int.permute(2, 0, 1), | |
pred['boxes'][pred['scores'] > threshold], | |
[classes[i] for i in pred['labels'][pred['scores'] > threshold].tolist()], | |
width=4 | |
).permute(1, 2, 0) | |
drawn_frame = drawn_frame.cpu().numpy() | |
out.write(drawn_frame) | |
cap.release() | |
out.release() | |
return output_path | |
# Build the Gradio UI: one uploaded video in, the annotated video out.
# ``gr.Video`` replaces the ``gr.inputs`` / ``gr.outputs`` namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4.
video_input = gr.Video(label="Input Video")
processed_video = gr.Video(label="Processed Video")
demo = gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=processed_video,
    title="Object Detection in Video",
    description="Detect objects in a video using the trained model.",
)
# ``server_name`` is a ``launch()`` option, not an ``Interface`` one; passed to
# the constructor it never reached the server, which stayed bound to localhost.
demo.launch(server_name="0.0.0.0")