File size: 1,697 Bytes
de8a79f
 
 
 
b3f84ea
 
 
de8a79f
 
 
 
 
 
 
e21f2de
 
 
 
b3f84ea
e21f2de
b3f84ea
e21f2de
de8a79f
 
 
 
b3f84ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de8a79f
 
 
af135f6
 
de8a79f
af135f6
de8a79f
 
af135f6
de8a79f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import torch
from model import CNNLSTMClassifier
from utils import extract_frames
import shutil
import os
import cv2

# Build the classifier once at import time so every request reuses the same
# instance. The weights file path is relative to the current working directory,
# and map_location='cpu' loads the saved tensors onto the CPU.
model = CNNLSTMClassifier()
model.load_state_dict(torch.load("lbw_classifier.pt", map_location='cpu'))
# Switch to evaluation mode (affects training-only layers such as dropout /
# batch-norm, if the architecture has any).
model.eval()

# Human-readable class names indexed by the model's output position.
# NOTE(review): assumes index 0 = Not LBW and index 1 = LBW, matching the label
# encoding used at training time — TODO confirm.
classes = ["Not LBW", "LBW"]

def predict(video_file):
    """Classify an uploaded video and return a copy annotated with the result.

    Args:
        video_file: Value delivered by the ``gr.Video`` input: either a file
            path, or a dict whose ``"name"`` key holds the path. ``None`` when
            nothing was uploaded.

    Returns:
        Path to an ``.mp4`` file containing every frame of the input with the
        predicted class name and its probability drawn in the top-left corner.

    Raises:
        gr.Error: If no video was supplied, the video cannot be opened, or the
            annotated output cannot be created.
    """
    video_path = _resolve_video_path(video_file)
    pred, prob = _classify(video_path)

    label = f"{classes[pred]} ({prob:.2%})"
    # OpenCV colours are BGR: green text for class 1 ("LBW"), red otherwise.
    color = (0, 255, 0) if pred == 1 else (0, 0, 255)
    return _annotate_video(video_path, label, color)


def _resolve_video_path(video_file):
    """Return the file path carried by a ``gr.Video`` input value."""
    if video_file is None:
        # Fail with a user-visible message instead of crashing in extract_frames.
        raise gr.Error("Please upload a video first.")
    # NOTE(review): the dict form presumably comes from some Gradio versions.
    if isinstance(video_file, dict) and "name" in video_file:
        return video_file["name"]
    return video_file


def _classify(video_path):
    """Run the model on the video; return ``(class_index, probability)``."""
    frames = extract_frames(video_path)
    with torch.no_grad():
        output = model(frames)
        # NOTE(review): .item() assumes a batch of one clip — confirm that
        # extract_frames returns a single-item batch.
        pred = torch.argmax(output, dim=1).item()
        prob = torch.softmax(output, dim=1)[0][pred].item()
    return pred, prob


# Frame rate used when OpenCV cannot read one from the input (it reports 0 for
# unsupported properties). 25 is an arbitrary but common default.
_FALLBACK_FPS = 25.0


def _annotate_video(video_path, label, color):
    """Write a copy of the video with ``label`` drawn on every frame.

    The output file name is fixed, so each call overwrites the previous result.
    The capture and writer are always released, even if an error occurs.

    Returns:
        Path of the annotated ``.mp4`` file.

    Raises:
        gr.Error: If the input cannot be opened or the writer cannot be created.
    """
    import tempfile  # local import keeps this block self-contained

    # Portable replacement for the previously hard-coded "/tmp" directory.
    out_path = os.path.join(tempfile.gettempdir(), "annotated_video.mp4")

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = _FALLBACK_FPS
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # NOTE(review): "mp4v" output may not play in every browser; H.264 may
        # be needed for reliable in-page playback — confirm.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not out.isOpened():
            # Without this check, frames would be dropped silently and a path
            # to an empty/missing file returned.
            raise gr.Error("Could not create the annotated video.")

        font = cv2.FONT_HERSHEY_SIMPLEX
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
            out.write(frame)
    finally:
        cap.release()
        if out is not None:
            out.release()

    return out_path

# Gradio UI: the user uploads a video and gets back the same video with the
# LBW prediction drawn on every frame.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Video(),
    outputs=gr.Video(),  # predict returns the path of the annotated video
    title="Smart LBW Classifier",
    description="Upload a cricket video. The model will analyze the frames and overlay the LBW prediction."
)


if __name__ == "__main__":
    # Start the web server only when run as a script, so importing this module
    # (e.g. from tests or another app) does not launch it as a side effect.
    iface.launch()