import cv2
import torch
import gradio as gr
from torchvision.utils import draw_bounding_boxes

# Load the detection model once at import time and switch it to inference mode.
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load "R_CNN.pth" from a trusted source. Recent PyTorch releases
# may also require weights_only=False to load a fully pickled model; confirm
# against the PyTorch version this app is deployed with.
model_path = "R_CNN.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
model.eval()

# Class names, indexed directly by the label ids the model predicts.
# NOTE(review): torchvision detection models commonly reserve label 0 for
# background -- confirm this list lines up with the ids used in training.
classes = ['creatures', 'fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']


def process_video(input_video):
    """Run the detector on every frame of a video and save an annotated copy.

    Args:
        input_video: Either a filesystem path (``str``) or a file-like object
            exposing the path through a ``.name`` attribute.

    Returns:
        The path of the annotated video (``'video_output.avi'``, XVID-encoded,
        written to the current working directory).

    Raises:
        ValueError: If no input was supplied or the video cannot be opened.
    """
    if input_video is None:
        raise ValueError("No input video was provided.")

    if isinstance(input_video, str):
        # The input is a filename.
        input_video_path = input_video
    else:
        # The input is a file object.
        input_video_path = input_video.name

    output_path = 'video_output.avi'
    threshold = 0.8  # Confidence threshold for bounding boxes

    cap = cv2.VideoCapture(input_video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 or NaN; the writer needs a positive rate.
            fps = 30.0

        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # HWC uint8 frame -> 1xCxHxW float32 tensor scaled to [0, 1].
            # NOTE(review): OpenCV decodes frames as BGR and the channel order
            # is passed to the model unchanged; if the model was trained on
            # RGB images a cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) is needed
            # here -- confirm against the training pipeline.
            img = torch.from_numpy(frame).permute(2, 0, 1).to(torch.float32) / 255.0
            with torch.no_grad():
                pred = model(img.unsqueeze(0))[0]

            # Evaluate the score filter once and reuse it for boxes and labels.
            keep = pred['scores'] > threshold

            # Draw on at most the first three channels of the original frame.
            canvas = torch.from_numpy(frame[:, :, :3])
            if keep.any():
                labels = [classes[i] for i in pred['labels'][keep].tolist()]
                canvas = draw_bounding_boxes(
                    canvas.permute(2, 0, 1),
                    pred['boxes'][keep],
                    labels,
                    width=4,
                ).permute(1, 2, 0)

            # The permuted tensor is a strided view; hand OpenCV a contiguous
            # HWC uint8 array.
            out.write(canvas.contiguous().cpu().numpy())
    finally:
        # Always release the capture and writer, even if inference fails.
        cap.release()
        if out is not None:
            out.release()

    return output_path


video_input = gr.Video(label="Input Video")
# process_video returns the path of a video file, so the output component must
# be a Video (an Image component cannot render an .avi path).
processed_video = gr.Video(label="Processed Video")

interface = gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=processed_video,
    title="Object Detection in Video",
    description="Detect objects in a video using the trained model.",
)

interface.launch()