# fish_count / app.py
# Author: srinuksv -- commit 3f60342 (verified): "Update app.py"
# File size: 2.07 kB
import cv2
import torch
import gradio as gr
from torchvision.utils import draw_bounding_boxes
# Load the model
# The checkpoint path is hard-coded; the file must exist at import time.
model_path = "/content/R_CNN.pth"
# Resolve the target device first so the checkpoint can be remapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
# NOTE(review): this presumably loads a fully pickled model object (it is
# used as a callable module below); recent PyTorch releases may additionally
# require weights_only=False for that -- confirm against the pinned version.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# host instead of failing while deserializing CUDA tensors.
model = torch.load(model_path, map_location=device)
model = model.to(device)
# Switch to evaluation mode before running inference.
model.eval()
# Class labels. process_video indexes this list with the integer labels
# predicted by the model, so the order of the entries matters.
classes = [
    "creatures",
    "fish",
    "jellyfish",
    "penguin",
    "puffin",
    "shark",
    "starfish",
    "stingray",
]
# Define function for processing video
def process_video(input_video, threshold=0.8):
    """Run the detector on every frame of a video and write an annotated copy.

    Each frame is passed through the module-level ``model``; detections whose
    score exceeds ``threshold`` are drawn (box plus class label) on the frame,
    and the result is appended to an XVID-encoded AVI file.

    Args:
        input_video: Path of the input video, or an object whose ``name``
            attribute holds that path (e.g. a temporary-file wrapper).
        threshold: Minimum confidence score a detection needs to be drawn.

    Returns:
        The path of the annotated video, ``'video_output.avi'`` (relative to
        the current working directory).

    Raises:
        ValueError: If the input video cannot be opened.
    """
    output_path = 'video_output.avi'
    # Accept both a plain path string and a file-like object exposing `.name`;
    # presumably which one arrives depends on the Gradio version in use.
    video_path = getattr(input_video, "name", input_video)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path!r}")

    out = None
    try:
        # A reported rate of 0 means the property is unavailable; fall back to
        # a conventional default so the writer still gets a usable frame rate.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # HWC uint8 frame -> 1xCxHxW float tensor scaled to [0, 1].
            # NOTE(review): OpenCV decodes frames in BGR channel order; confirm
            # the model was trained on the same order.
            img = torch.tensor(
                frame.transpose(2, 0, 1) / 255.0, dtype=torch.float32
            ).to(device)
            img = img.unsqueeze(0)
            with torch.no_grad():
                prediction = model(img)
            pred = prediction[0]

            # Compute the confidence mask once and reuse it for boxes/labels.
            keep = pred['scores'] > threshold

            img_int = torch.tensor(frame, dtype=torch.uint8)
            if img_int.shape[2] > 3:
                # Drop any extra channel so the image is 3-channel for drawing.
                img_int = img_int[:, :, :3]

            drawn_frame = draw_bounding_boxes(
                img_int.permute(2, 0, 1),
                # The image lives on the CPU, so bring the boxes there too.
                pred['boxes'][keep].cpu(),
                [classes[i] for i in pred['labels'][keep].tolist()],
                width=4,
            ).permute(1, 2, 0)
            # permute() returns a strided view; hand OpenCV a contiguous array.
            out.write(drawn_frame.contiguous().cpu().numpy())
    finally:
        # Always release the capture and writer, even if inference fails.
        cap.release()
        if out is not None:
            out.release()
    return output_path
# Create Gradio interface
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces;
# they are kept here to stay compatible with the Gradio version this app was
# written for -- confirm against the pinned version before upgrading.
video_input = gr.inputs.Video(label="Input Video")
processed_video = gr.outputs.Video(label="Processed Video")
# server_name selects the network interface the server binds to and is an
# option of launch(), so it is passed there rather than to the Interface
# constructor. "0.0.0.0" makes the app reachable from other hosts.
gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=processed_video,
    title="Object Detection in Video",
    description="Detect objects in a video using the trained model.",
).launch(server_name="0.0.0.0")