File size: 9,011 Bytes
703780e
 
 
 
ce8c119
703780e
 
 
 
ce8c119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703780e
 
 
 
 
 
 
 
 
 
 
 
 
 
ce8c119
 
 
 
703780e
 
 
 
 
 
 
 
 
 
 
 
ce8c119
 
 
 
703780e
 
 
 
 
ce8c119
703780e
 
ce8c119
 
 
 
 
 
703780e
ce8c119
 
 
 
 
 
 
 
 
 
703780e
ce8c119
 
703780e
ce8c119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703780e
ce8c119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703780e
 
 
 
 
 
 
ce8c119
 
 
 
703780e
ce8c119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# app.py
import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import supervision as sv
from scipy.spatial import distance as dist
from ultralytics import YOLO

# Sidebar: institution branding followed by the project team roster.
_team_roster = (
    "Harsh Vardhan Lal 2262069 6BTCSAIML B",
    "Harsheet Sandeep Thakur 2262070 6BTCSAIML B",
    "Nandalal C B 2262115 6BTCSAIML B",
    "Athul Nambiar 2262041 6BTCSAIML B",
)

with st.sidebar:
    st.image("https://cs.christuniversity.in/softex/resources/img/christ_university_Black.png", width=200)

    # Centered HTML headings under the logo (raw HTML needs unsafe_allow_html).
    for _heading_html in (
        "<h4 style='text-align: center; margin-top: 0;'>Christ University Kengeri Campus</h4>",
        "<h5 style='text-align: center; margin-top: 0;'>Sports Department</h5>",
    ):
        st.markdown(_heading_html, unsafe_allow_html=True)

    st.markdown("### Team Members")
    # One indented markdown bullet per member, framed exactly like the
    # original triple-quoted literal (leading newline, trailing indent).
    st.markdown("\n" + "".join(f"    - {_member}\n" for _member in _team_roster) + "    ")

# Detector weights and a ByteTrack tracker. Both are built at module level,
# so each execution of this script starts with a fresh tracker.
model = YOLO('yolov8s.pt')
tracker = sv.ByteTrack()

# Page header and user inputs.
st.title("⚽️ Player Tracking System")
uploaded_video = st.file_uploader("Upload match video", type=["mp4", "mov"])
calibration_dist = st.number_input(
    "Field width in meters (for speed calibration):",
    value=68.0,
)

# Ensure the per-player tracking store exists without overwriting it.
st.session_state.setdefault('player_data', {})

# Reserved slot, inside its own container, that the processing loop
# overwrites with the live tracking table.
live_table_container = st.container()
live_table = live_table_container.empty()

# Tuning constants for the tracking analytics.
_MAX_JUMP_FRACTION = 0.1  # max plausible centroid jump, as a fraction of frame width
_RECENT_FRAMES = 30       # players unseen for this many frames drop off the live table
_MIN_SAMPLES = 10         # minimum samples before a player appears in final stats/chart
_FALLBACK_FPS = 30.0      # assumed frame rate when the container reports none


def _save_upload_to_tempfile(upload):
    """Copy an uploaded file into a named temporary file and return its path.

    The file is closed, which flushes the write buffer, before the path is
    returned. This matters because ``cv2.VideoCapture`` reopens the path
    independently and would otherwise risk reading a truncated file. The
    upload's extension is preserved on the temp file. The caller is
    responsible for deleting the file.
    """
    suffix = os.path.splitext(upload.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # getvalue() returns the whole buffer regardless of read position.
        tmp.write(upload.getvalue())
    return tmp.name


def _update_player(player_data, player_id, centroid, timestamp, frame_count,
                   max_jump_px, pixels_per_meter):
    """Record one tracked observation and return the player's speed in m/s.

    Args:
        player_data: dict of player_id -> per-player record (mutated in place).
        player_id: integer track ID from the tracker.
        centroid: (x, y) pixel centre of the player's bounding box.
        timestamp: time of this frame in seconds.
        frame_count: index of this frame.
        max_jump_px: displacements of this many pixels or more are treated as
            tracking glitches. The new position is stored, but distance and
            speed are not updated and the previous speed is returned.
        pixels_per_meter: calibration factor for converting pixels to metres.

    Returns:
        The speed to display for this observation (0.0 for a new player).
    """
    record = player_data.get(player_id)
    if record is None:
        player_data[player_id] = {
            'positions': [centroid],
            'timestamps': [timestamp],
            'distance': 0.0,
            'speeds': [0.0],          # Initialize with zero speed
            'speed_times': [timestamp],  # timestamps aligned 1:1 with 'speeds'
            'last_seen': frame_count,
        }
        return 0.0

    # Measure against the previous observation before appending this one.
    pixel_dist = dist.euclidean(record['positions'][-1], centroid)
    time_diff = timestamp - record['timestamps'][-1]

    record['positions'].append(centroid)
    record['timestamps'].append(timestamp)
    record['last_seen'] = frame_count

    if pixel_dist >= max_jump_px:
        # Unrealistic jump: keep the position but carry the previous speed.
        return record['speeds'][-1]

    speed_px = pixel_dist / time_diff if time_diff > 0 else 0.0
    speed = speed_px / pixels_per_meter
    record['distance'] += pixel_dist / pixels_per_meter
    record['speeds'].append(speed)
    # Append the time alongside the speed so the two lists stay equal length
    # for plotting, even when some observations were rejected as jumps.
    record['speed_times'].append(timestamp)
    return speed


def _live_rows(player_data, frame_count, recent_frames=_RECENT_FRAMES):
    """Build live-table rows for players seen within the last *recent_frames*."""
    return [
        {
            "Player": f"Player {player_id}",
            "Current Speed (m/s)": round(data['speeds'][-1], 2),
            "Distance (m)": round(data['distance'], 2),
            "Time (s)": round(data['timestamps'][-1], 2),
        }
        for player_id, data in player_data.items()
        if frame_count - data['last_seen'] < recent_frames
    ]


def _summary_rows(player_data, min_positions=_MIN_SAMPLES):
    """Build final per-player statistics for players with more than *min_positions* positions."""
    rows = []
    for player_id, data in player_data.items():
        # Only include players with significant tracking data.
        if len(data['positions']) <= min_positions:
            continue
        speeds = data['speeds']
        avg_speed = np.mean(speeds) if speeds else 0
        max_speed = np.max(speeds) if speeds else 0
        stamps = data['timestamps']
        duration = stamps[-1] - stamps[0] if len(stamps) > 1 else 0
        rows.append({
            "Player": f"Player {player_id}",
            "Total Distance (m)": round(data['distance'], 2),
            "Avg Speed (m/s)": round(avg_speed, 2),
            "Max Speed (m/s)": round(max_speed, 2),
            "Duration (s)": round(duration, 2),
        })
    return rows


def _speed_chart(player_data, min_samples=_MIN_SAMPLES):
    """Return a figure plotting speed over time for players with more than *min_samples* speeds."""
    fig, ax = plt.subplots(figsize=(10, 6))
    plotted = False
    for player_id, data in player_data.items():
        if len(data['speeds']) > min_samples:
            ax.plot(data['speed_times'], data['speeds'], label=f"Player {player_id}")
            plotted = True

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Speed (m/s)')
    ax.set_title('Player Speed Over Time')
    if plotted:  # legend() with no labelled lines only emits a warning
        ax.legend()
    ax.grid(True)
    return fig


if uploaded_video:
    video_path = _save_upload_to_tempfile(uploaded_video)
    cap = cv2.VideoCapture(video_path)

    # Clear previous player data; `player_data` aliases the session-state dict.
    player_data = {}
    st.session_state.player_data = player_data
    player_count = 0  # Track the number of distinct players

    try:
        if not cap.isOpened():
            st.error("Could not open the uploaded video.")
            st.stop()

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # written this way so NaN is rejected too
            fps = _FALLBACK_FPS
            st.warning("Video FPS unavailable; assuming 30 FPS for speed calculations.")
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        pixels_per_meter = frame_width / calibration_dist if calibration_dist > 0 else 1.0
        max_jump_px = frame_width * _MAX_JUMP_FRACTION

        st_frame = st.empty()
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            timestamp = frame_count / fps

            # Detect, keep only the person class (class 0 in COCO), then track.
            results = model(frame, verbose=False)[0]
            detections = sv.Detections.from_ultralytics(results)
            detections = detections[detections.class_id == 0]
            detections = tracker.update_with_detections(detections)

            # tracker_id is a per-detection int array; guard the whole array.
            if detections.tracker_id is not None:
                for box, track_id in zip(detections.xyxy, detections.tracker_id):
                    player_id = int(track_id)
                    x1, y1, x2, y2 = box
                    centroid = (int((x1 + x2) / 2), int((y1 + y2) / 2))

                    if player_id not in player_data:
                        player_count += 1
                    speed = _update_player(
                        player_data, player_id, centroid, timestamp,
                        frame_count, max_jump_px, pixels_per_meter,
                    )

                    # Draw player info on frame
                    label = f"ID:{player_id} Speed:{speed:.1f}m/s"
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(frame, label, (int(x1), int(y1 - 10)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

            # Update live table with recently seen players only.
            live_rows = _live_rows(player_data, frame_count)
            if live_rows:
                live_table.dataframe(pd.DataFrame(live_rows), use_container_width=True, hide_index=True)

            st_frame.image(frame, channels="BGR")
            frame_count += 1
    finally:
        # Runs even if processing is interrupted mid-loop.
        cap.release()
        try:
            os.unlink(video_path)
        except OSError:
            pass  # best-effort cleanup of the temp copy; never fail the page over it

    # Display final analytics
    st.subheader(f"Player Statistics ({player_count} players detected)")
    stats_df = pd.DataFrame(_summary_rows(player_data))
    st.dataframe(stats_df, use_container_width=True, hide_index=True)

    # Create speed comparison chart
    if player_data:
        st.subheader("Player Speed Comparison")
        fig = _speed_chart(player_data)
        st.pyplot(fig)
        plt.close(fig)  # pyplot keeps figures registered until explicitly closed