fish_count / app.py
srinuksv's picture
Update app.py
dde9364 verified
raw
history blame contribute delete
No virus
2.21 kB
import cv2
import torch
import gradio as gr
from torchvision.utils import draw_bounding_boxes
# Restore the saved model object onto the CPU and switch it to inference mode.
# NOTE(review): torch.load unpickles the whole object here (the result is
# called directly as a model), so R_CNN.pth must come from a trusted source.
model_path = "R_CNN.pth"
model = torch.load(model_path, map_location="cpu")
model.eval()

# Label names, looked up by the integer class index the model predicts.
classes = [
    'creatures',
    'fish',
    'jellyfish',
    'penguin',
    'puffin',
    'shark',
    'starfish',
    'stingray',
]
# Define function for processing video
def process_video(input_video, threshold=0.8):
    """Annotate every frame of a video with the model's detections.

    Each frame is run through the module-level ``model``; detections whose
    score exceeds ``threshold`` are drawn as labelled boxes and the frame is
    appended to ``video_output.avi``.

    Args:
        input_video: Path of the input video, or an object exposing that
            path as ``.name``.
        threshold: Minimum score (exclusive) a detection needs to be drawn.

    Returns:
        The path of the annotated video, ``'video_output.avi'``.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    # Accept either a filename or a file-like object carrying the filename.
    if isinstance(input_video, str):
        input_video_path = input_video
    else:
        input_video_path = input_video.name
    output_path = 'video_output.avi'

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {input_video_path!r}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The capture could not report a usable frame rate (0 or NaN);
            # fall back to a common default so the writer can still be opened.
            fps = 30.0
        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Model input: channels-first float tensor scaled to [0, 1] with a
            # leading batch dimension.
            # NOTE(review): the frame's channel order is passed through
            # unchanged (OpenCV decodes to BGR by default) -- confirm this
            # matches what the model was trained on.
            img = torch.tensor(frame.transpose(2, 0, 1) / 255.0, dtype=torch.float32)
            img = img.unsqueeze(0)
            with torch.no_grad():
                prediction = model(img)
            pred = prediction[0]

            # Drawing works on the uint8 pixels; keep at most three channels.
            img_int = torch.tensor(frame, dtype=torch.uint8)
            if img_int.shape[2] > 3:
                img_int = img_int[:, :, :3]

            # Compute the confidence mask once and reuse it for boxes and labels.
            keep = pred['scores'] > threshold
            boxes = pred['boxes'][keep]
            labels = [classes[i] for i in pred['labels'][keep].tolist()]

            drawn_frame = draw_bounding_boxes(
                img_int.permute(2, 0, 1),
                boxes,
                labels,
                width=4,
            ).permute(1, 2, 0)
            out.write(drawn_frame.cpu().numpy())
    finally:
        # Release the capture and writer even if inference or drawing fails,
        # so file handles are not leaked and the output file is finalized.
        cap.release()
        if out is not None:
            out.release()

    return output_path
# Build the Gradio UI: one uploaded video in, the annotated video out.
video_input = gr.Video(label="Input Video")
# process_video returns the path of a video file, so the output has to be a
# Video component; an Image component cannot display that file.
processed_video = gr.Video(label="Processed Video")  # No 'outputs' submodule
interface = gr.Interface(
    fn=process_video,
    inputs=video_input,
    outputs=processed_video,
    title="Object Detection in Video",
    description="Detect objects in a video using the trained model.",
)
interface.launch()