# Age_prediction / app.py
# yunusajib's picture
# update app.py
# 1e4a4a4 verified
import torch
import cv2
import tempfile
import gradio as gr
from ultralytics import YOLO
from ultralytics.nn.tasks import DetectionModel
from ultralytics.nn.modules.conv import Conv
# Classes that must be allow-listed so the checkpoint can be unpickled.
# NOTE(review): presumably needed because torch.load runs with weights_only=True
# in the installed torch version; the list was built up from classes actually
# encountered while loading — confirm it is still complete for the checkpoint.
_SAFE_CHECKPOINT_CLASSES = [
    DetectionModel,
    torch.nn.modules.container.Sequential,
    Conv,
]
torch.serialization.add_safe_globals(_SAFE_CHECKPOINT_CLASSES)

# Single YOLOv8-nano model instance shared by every request handled by this app.
model = YOLO("yolov8n.pt")
# Object tracking function
def track_objects(video_input):
# Read uploaded video
cap = cv2.VideoCapture(video_input)
# Create a temporary output video file
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
out = cv2.VideoWriter(tmp_file.name, fourcc, fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Run YOLOv8 tracking
results = model.track(frame, persist=True, tracker="bytetrack.yaml")[0]
# Get annotated frame
annotated_frame = results.plot()
out.write(annotated_frame)
cap.release()
out.release()
return tmp_file.name
# Gradio interface
# NOTE(review): the title and labels promise *people* tracking, but track_objects
# passes no class filter to model.track, so presumably every class the detector
# recognises is tracked — confirm whether a person-only filter was intended.
demo = gr.Interface(
    fn=track_objects,
    inputs=gr.Video(label="Upload a video to track people"),
    outputs=gr.Video(label="Tracked Output"),
    title="People Tracking with YOLOv8",
    description="Upload a video and track people with YOLOv8 and ByteTrack",
)

# Start the server only when this file is executed as a script (`python app.py`),
# so importing the module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()