# Hugging Face Spaces app (Space status shown at capture time: Sleeping).
import contextlib
import os
import shutil
import tempfile

import cv2
import gradio as gr
import torch

from model import CNNLSTMClassifier
from utils import extract_frames
# Build the classifier once at import time and load the trained weights.
model = CNNLSTMClassifier()
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of executing arbitrary pickled objects.
# (Requires torch >= 1.13.)
model.load_state_dict(
    torch.load("lbw_classifier.pt", map_location="cpu", weights_only=True)
)
# Switch modules such as dropout / batch-norm (if present) to inference behaviour.
model.eval()

# Maps the model's output index to a display label. Presumably this matches the
# label encoding used during training; confirm against the training code.
classes = ["Not LBW", "LBW"]
def _resolve_video_path(video_file):
    """Return the filesystem path held in a Gradio video payload.

    Accepts either a dict carrying the path under ``"name"`` or a plain path value,
    which is returned unchanged.
    """
    if isinstance(video_file, dict) and "name" in video_file:
        return video_file["name"]
    return video_file


def _classify(video_path):
    """Run the classifier on the clip and return ``(predicted_index, probability)``."""
    frames = extract_frames(video_path)
    with torch.no_grad():
        output = model(frames)
    # Indexing row 0 along dim=1 assumes the model returns logits shaped
    # (1, num_classes) for a single clip. TODO: confirm against CNNLSTMClassifier.
    pred = torch.argmax(output, dim=1).item()
    prob = torch.softmax(output, dim=1)[0][pred].item()
    return pred, prob


def _annotate_video(video_path, label, color, fallback_fps=25.0):
    """Write a copy of ``video_path`` with ``label`` drawn on every frame; return its path.

    Args:
        video_path: Source video to read.
        label: Text overlaid on each frame.
        color: BGR colour tuple for the text (OpenCV convention).
        fallback_fps: Frame rate to use when the container reports none (0 or NaN).

    Raises:
        gr.Error: If the video cannot be opened, the output cannot be created,
            or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):  # also catches NaN; VideoWriter needs a positive rate
        fps = fallback_fps

    # A unique file per call, so concurrent requests cannot overwrite each other's output.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    font = cv2.FONT_HERSHEY_SIMPLEX
    out = None
    succeeded = False
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if out is None:
                # Size the writer from the first decoded frame, so it always matches
                # the frames actually written, even if the container's
                # width/height metadata is missing or wrong.
                height, width = frame.shape[:2]
                out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
                if not out.isOpened():
                    raise gr.Error("Could not create the annotated output video.")
            cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
            out.write(frame)
        if out is None:
            raise gr.Error("The uploaded video contains no readable frames.")
        succeeded = True
    finally:
        # Release both handles even if annotation fails part-way.
        cap.release()
        if out is not None:
            out.release()
        if not succeeded:
            # Best-effort cleanup of the partial/empty output; the original error propagates.
            with contextlib.suppress(OSError):
                os.remove(out_path)
    return out_path


def predict(video_file):
    """Classify an uploaded cricket clip and return a copy annotated with the verdict.

    Args:
        video_file: Value passed from the ``gr.Video`` input: either a path or a
            dict carrying the path under ``"name"``.

    Returns:
        Path to a newly written MP4 in which every frame shows the predicted label
        and its softmax probability, e.g. ``"LBW (87.50%)"``.

    Raises:
        gr.Error: If no video was supplied or the video cannot be read or written.
    """
    video_path = _resolve_video_path(video_file)
    if not video_path:
        raise gr.Error("Please upload a video.")

    pred, prob = _classify(video_path)
    label = f"{classes[pred]} ({prob:.2%})"
    # OpenCV colours are BGR: green for "LBW" (index 1), red for "Not LBW".
    color = (0, 255, 0) if pred == 1 else (0, 0, 255)
    return _annotate_video(video_path, label, color)
# Gradio UI: one video in, the annotated video returned by predict() out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Video(),
    outputs=gr.Video(),  # the annotated copy written by predict()
    title="Smart LBW Classifier",
    description="Upload a cricket video. The model will analyze the frames and overlay the LBW prediction.",
)

# Start the server only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()