# Uploaded by athulnambiar; last commit ce8c119 ("Update app.py", verified).
# app.py
import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import supervision as sv
from scipy.spatial import distance as dist
from ultralytics import YOLO
# Sidebar: university branding followed by the project team roster.
# Calling methods on st.sidebar directly places each element in the sidebar,
# the same as issuing st.* calls inside a `with st.sidebar:` context.
_sidebar = st.sidebar
_sidebar.image(
    "https://cs.christuniversity.in/softex/resources/img/christ_university_Black.png",
    width=200,
)
_sidebar.markdown(
    "<h4 style='text-align: center; margin-top: 0;'>Christ University Kengeri Campus</h4>",
    unsafe_allow_html=True,
)
_sidebar.markdown(
    "<h5 style='text-align: center; margin-top: 0;'>Sports Department</h5>",
    unsafe_allow_html=True,
)
_sidebar.markdown("### Team Members")
_sidebar.markdown("""
- Harsh Vardhan Lal 2262069 6BTCSAIML B
- Harsheet Sandeep Thakur 2262070 6BTCSAIML B
- Nandalal C B 2262115 6BTCSAIML B
- Athul Nambiar 2262041 6BTCSAIML B
""")
# Initialize components.
# Streamlit re-executes this whole script on every widget interaction, so the
# YOLO model is kept in this session's state: the weights are loaded once per
# session instead of on every rerun. It is stored per session rather than
# shared globally so concurrent sessions do not use one predictor instance.
if 'yolo_model' not in st.session_state:
    st.session_state.yolo_model = YOLO('yolov8s.pt')
model = st.session_state.yolo_model
# The tracker is intentionally recreated on every run: ByteTrack keeps track
# state between update calls, and IDs from a previous run should not carry over.
tracker = sv.ByteTrack()
# Streamlit UI
st.title("⚽️ Player Tracking System")
uploaded_video = st.file_uploader("Upload match video", type=["mp4", "mov"])
# Field width is used to convert pixel distances to metres. A zero or negative
# width is meaningless (and previously made speeds silently fall back to pixel
# units while still being labelled m/s), so the widget rejects it outright.
calibration_dist = st.number_input(
    "Field width in meters (for speed calibration):",
    min_value=1.0,
    value=68.0,
)
# Initialize session state: per-player tracking records keyed by tracker ID.
if 'player_data' not in st.session_state:
    st.session_state.player_data = {}
# Create a placeholder for the live tracking table; a single-element
# placeholder can be overwritten in place on each processed frame.
live_table_container = st.container()
live_table = live_table_container.empty()
def _save_upload(uploaded_file):
    """Copy a Streamlit upload into a closed temporary file and return its path.

    The file must be closed (and therefore flushed) before OpenCV opens it by
    name; otherwise the end of the video may still be in Python's write
    buffer. ``delete=False`` keeps the file after closing, so the caller is
    responsible for removing it.
    """
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # getvalue() returns the whole buffer independent of the read position.
        tmp.write(uploaded_file.getvalue())
    return tmp.name


def _record_observation(player_data, player_id, centroid, timestamp,
                        frame_index, max_jump_px, pixels_per_meter):
    """Add one tracked observation to ``player_data``; return ``(speed, is_new)``.

    Each record holds parallel ``positions``, ``timestamps`` and ``speeds``
    lists, a cumulative ``distance`` in metres and the ``last_seen`` frame
    index. The three lists always grow together so speeds can be plotted
    against timestamps.

    A new ID starts at zero speed and distance. For a known ID, a centroid jump
    of ``max_jump_px`` or more is treated as unrealistic motion: no distance is
    added and the previous speed is carried forward (and recorded).
    """
    record = player_data.get(player_id)
    if record is None:
        player_data[player_id] = {
            'positions': [centroid],
            'timestamps': [timestamp],
            'distance': 0.0,
            'speeds': [0.0],
            'last_seen': frame_index,
        }
        return 0.0, True

    pixel_dist = dist.euclidean(record['positions'][-1], centroid)
    if pixel_dist < max_jump_px:
        time_diff = timestamp - record['timestamps'][-1]
        speed_px = pixel_dist / time_diff if time_diff > 0 else 0.0
        speed = speed_px / pixels_per_meter
        record['distance'] += pixel_dist / pixels_per_meter
    else:
        speed = record['speeds'][-1] if record['speeds'] else 0.0

    record['positions'].append(centroid)
    record['timestamps'].append(timestamp)
    record['speeds'].append(speed)
    record['last_seen'] = frame_index
    return speed, False


def _live_rows(player_data, frame_index, recent_frames=30):
    """Build live-table rows for players seen within the last ``recent_frames`` frames."""
    return [
        {
            "Player": f"Player {player_id}",
            "Current Speed (m/s)": round(data['speeds'][-1] if data['speeds'] else 0, 2),
            "Distance (m)": round(data['distance'], 2),
            "Time (s)": round(data['timestamps'][-1], 2) if data['timestamps'] else 0,
        }
        for player_id, data in player_data.items()
        if frame_index - data['last_seen'] < recent_frames
    ]


def _stats_rows(player_data, min_positions=10):
    """Build summary rows for players tracked in more than ``min_positions`` frames."""
    rows = []
    for player_id, data in player_data.items():
        if len(data['positions']) <= min_positions:
            continue
        speeds = data['speeds']
        avg_speed = np.mean(speeds) if speeds else 0
        max_speed = np.max(speeds) if speeds else 0
        times = data['timestamps']
        duration = times[-1] - times[0] if len(times) > 1 else 0
        rows.append({
            "Player": f"Player {player_id}",
            "Total Distance (m)": round(data['distance'], 2),
            "Avg Speed (m/s)": round(avg_speed, 2),
            "Max Speed (m/s)": round(max_speed, 2),
            "Duration (s)": round(duration, 2),
        })
    return rows


def _plot_speed_chart(player_data, min_samples=10):
    """Render one speed-over-time line per player with more than ``min_samples`` samples."""
    fig, ax = plt.subplots(figsize=(10, 6))
    for player_id, data in player_data.items():
        if len(data['speeds']) > min_samples:
            ax.plot(data['timestamps'], data['speeds'], label=f"Player {player_id}")
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Speed (m/s)')
    ax.set_title('Player Speed Over Time')
    # legend() warns when no labelled lines exist, so only add it if needed.
    if ax.get_lines():
        ax.legend()
    ax.grid(True)
    st.pyplot(fig)
    # pyplot keeps figures registered until closed; free it so reruns don't pile up.
    plt.close(fig)


if uploaded_video:
    video_path = _save_upload(uploaded_video)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            st.error("Could not open the uploaded video.")
            st.stop()

        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS; `not fps > 0` catches both and
        # avoids dividing by zero when converting frame indices to seconds.
        if not fps > 0:
            fps = 30.0
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        # Assumes the frame spans the full field width -- a rough calibration.
        pixels_per_meter = frame_width / calibration_dist if calibration_dist > 0 else 1.0
        # Motion gate: centroid jumps of 10% of the frame width or more are ignored.
        max_jump_px = frame_width * 0.1

        st_frame = st.empty()
        frame_count = 0
        # Clear previous player data
        st.session_state.player_data = {}
        player_data = st.session_state.player_data
        player_count = 0  # Number of distinct tracker IDs seen

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            timestamp = frame_count / fps

            results = model(frame, verbose=False)[0]
            detections = sv.Detections.from_ultralytics(results)
            # Keep people only (class 0 in COCO). A boolean mask stays a valid
            # index on frames with no detections, whereas np.array([]) built from
            # an empty list is float-typed and cannot be used as an index.
            detections = detections[detections.class_id == 0]
            # Apply tracking - ByteTrack assigns IDs that persist across frames.
            detections = tracker.update_with_detections(detections)

            track_ids = detections.tracker_id
            if track_ids is not None:
                for box, track_id in zip(detections.xyxy, track_ids):
                    player_id = int(track_id)
                    x1, y1, x2, y2 = box
                    centroid = (int((x1 + x2) / 2), int((y1 + y2) / 2))

                    speed, is_new = _record_observation(
                        player_data, player_id, centroid, timestamp,
                        frame_count, max_jump_px, pixels_per_meter,
                    )
                    if is_new:
                        player_count += 1

                    # Draw player info on frame
                    label = f"ID:{player_id} Speed:{speed:.1f}m/s"
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(frame, label, (int(x1), int(y1 - 10)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

            # Update live table (players seen in the last 30 frames).
            live_rows = _live_rows(player_data, frame_count)
            if live_rows:
                live_table.dataframe(pd.DataFrame(live_rows), use_container_width=True, hide_index=True)

            st_frame.image(frame, channels="BGR")
            frame_count += 1
    finally:
        cap.release()
        try:
            os.unlink(video_path)
        except OSError:
            pass  # Best-effort cleanup; a leftover temp file is not fatal.

    # Display final analytics
    st.subheader(f"Player Statistics ({player_count} players detected)")
    stats_df = pd.DataFrame(_stats_rows(player_data))
    st.dataframe(stats_df, use_container_width=True, hide_index=True)

    # Create speed comparison chart
    if player_data:
        st.subheader("Player Speed Comparison")
        _plot_speed_chart(player_data)