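# app.py: Gradio demo that classifies a cricket delivery as LBW / Not LBW
# with a CNN+LSTM model and returns the clip annotated with the verdict.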
import gradio as gr
import torch
from model import CNNLSTMClassifier
from utils import extract_frames
import cv2
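
# Load the trained classifier weights on CPU and switch to inference mode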
model = CNNLSTMClassifier()
model.load_state_dict(torch.load("lbw_classifier.pt", map_location='cpu'))
model.eval()
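
# Class names indexed by the model's output logit position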
classes = ["Not LBW", "LBW"]

def predict(video_file):
    # Gradio may pass the upload as a dict (older versions) or as a plain filepath
    if isinstance(video_file, dict) and "name" in video_file:
        video_path = video_file["name"]
    else:
        video_path = video_file

    # Predict: run the classifier over the extracted frame tensor
    frames = extract_frames(video_path)
    with torch.no_grad():
        output = model(frames)
        pred = torch.argmax(output, dim=1).item()
        prob = torch.softmax(output, dim=1)[0][pred].item()
    label = f"{classes[pred]} ({prob:.2%})"
    # Create an annotated copy of the video with the prediction overlaid
    cap = cv2.VideoCapture(video_path)
    out_path = "/tmp/annotated_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # Some containers report 0 fps; fall back to a default so the writer still works
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    font = cv2.FONT_HERSHEY_SIMPLEX
    # OpenCV colors are BGR: green for an LBW call, red otherwise
    color = (0, 255, 0) if pred == 1 else (0, 0, 255)

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
        out.write(frame)

    cap.release()
    out.release()
    return out_path
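
# Gradio UI: a video upload in, the annotated video back out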
iface = gr.Interface(
    fn=predict,
    inputs=gr.Video(),
    outputs=gr.Video(),  # the annotated copy of the uploaded clip
    title="Smart LBW Classifier",
    description="Upload a cricket video. The model will analyze the frames and overlay the LBW prediction.",
)

iface.launch()