Spaces:
Runtime error
Runtime error
import os
import tempfile

import cv2
import numpy as np
import streamlit as st
from streamlit_image_coordinates import streamlit_image_coordinates
from ultralytics import YOLO

from detection import create_colors_info, detect
def main(): | |
st.set_page_config(page_title="AI Powered Web Application for Football Tactical Analysis", layout="wide", initial_sidebar_state="expanded") | |
st.title("Football Players Detection With Team Prediction & Tactical Map") | |
st.subheader(":red[Works only with Tactical Camera footage]") | |
st.sidebar.title("Main Settings") | |
demo_selected = st.sidebar.radio(label="Select Demo Video", options=["Demo 1", "Demo 2"], horizontal=True) | |
## Sidebar Setup | |
st.sidebar.markdown('---') | |
st.sidebar.subheader("Video Upload") | |
input_vide_file = st.sidebar.file_uploader('Upload a video file', type=['mp4','mov', 'avi', 'm4v', 'asf']) | |
demo_vid_paths={ | |
"Demo 1":'demo_vid_1.mp4', | |
"Demo 2":'demo_vid_2.mp4' | |
} | |
demo_vid_path = demo_vid_paths[demo_selected] | |
demo_team_info = { | |
"Demo 1":{"team1_name":"France", | |
"team2_name":"Switzerland", | |
"team1_p_color":'#1E2530', | |
"team1_gk_color":'#F5FD15', | |
"team2_p_color":'#FBFCFA', | |
"team2_gk_color":'#B1FCC4', | |
}, | |
"Demo 2":{"team1_name":"Chelsea", | |
"team2_name":"Manchester City", | |
"team1_p_color":'#29478A', | |
"team1_gk_color":'#DC6258', | |
"team2_p_color":'#90C8FF', | |
"team2_gk_color":'#BCC703', | |
} | |
} | |
selected_team_info = demo_team_info[demo_selected] | |
tempf = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) | |
if not input_vide_file: | |
tempf.name = demo_vid_path | |
demo_vid = open(tempf.name, 'rb') | |
demo_bytes = demo_vid.read() | |
st.sidebar.text('Demo video') | |
st.sidebar.video(demo_bytes) | |
else: | |
tempf.write(input_vide_file.read()) | |
demo_vid = open(tempf.name, 'rb') | |
demo_bytes = demo_vid.read() | |
st.sidebar.text('Input video') | |
st.sidebar.video(demo_bytes) | |
# Load the YOLOv8 players detection model | |
model_players = YOLO("models/Yolo8L Players/weights/best.pt") | |
# Load the YOLOv8 field keypoints detection model | |
model_keypoints = YOLO("models/Yolo8M Field Keypoints/weights/best.pt") | |
st.sidebar.markdown('---') | |
st.sidebar.subheader("Team Names") | |
team1_name = st.sidebar.text_input(label='First Team Name', value=selected_team_info["team1_name"]) | |
team2_name = st.sidebar.text_input(label='Second Team Name', value=selected_team_info["team2_name"]) | |
st.sidebar.markdown('---') | |
## Page Setup | |
tab1, tab2, tab3 = st.tabs(["How to use?", "Team Colors", "Model Hyperparameters & Detection"]) | |
with tab1: | |
st.header(':blue[Welcome!]') | |
st.subheader('Main Application Functionalities:', divider='blue') | |
st.markdown(""" | |
1. Football players, referee, and ball detection. | |
2. Players team prediction. | |
3. Estimation of players and ball positions on a tactical map. | |
4. Ball Tracking. | |
""") | |
st.subheader('How to use?', divider='blue') | |
st.markdown(""" | |
**There are two demo videos that are automaticaly loaded when you start the app, alongside the recommended settings and hyperparameters** | |
1. Upload a video to analyse, using the sidebar menu "Browse files" button. | |
2. Enter the team names that corresponds to the uploaded video in the text fields in the sidebar menu. | |
3. Access the "Team colors" tab in the main page. | |
4. Select a frame where players and goal keepers from both teams can be detected. | |
5. Follow the instruction on the page to pick each team colors. | |
6. Go to the "Model Hyperpramerters & Detection" tab, adjust hyperparameters and select the annotation options. (Default hyperparameters are recommended) | |
7. Run Detection! | |
8. If "save outputs" option was selected the saved video can be found in the "outputs" directory | |
""") | |
st.write("Version 0.0.1") | |
with tab2: | |
t1col1, t1col2 = st.columns([1,1]) | |
with t1col1: | |
cap_temp = cv2.VideoCapture(tempf.name) | |
frame_count = int(cap_temp.get(cv2.CAP_PROP_FRAME_COUNT)) | |
frame_nbr = st.slider(label="Select frame", min_value=1, max_value=frame_count, step=1, help="Select frame to pick team colors from") | |
cap_temp.set(cv2.CAP_PROP_POS_FRAMES, frame_nbr) | |
success, frame = cap_temp.read() | |
with st.spinner('Detecting players in selected frame..'): | |
results = model_players(frame, conf=0.7) | |
bboxes = results[0].boxes.xyxy.cpu().numpy() | |
labels = results[0].boxes.cls.cpu().numpy() | |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
detections_imgs_list = [] | |
detections_imgs_grid = [] | |
padding_img = np.ones((80,60,3),dtype=np.uint8)*255 | |
for i, j in enumerate(list(labels)): | |
if int(j) == 0: | |
bbox = bboxes[i,:] | |
obj_img = frame_rgb[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] | |
obj_img = cv2.resize(obj_img, (60,80)) | |
detections_imgs_list.append(obj_img) | |
detections_imgs_grid.append([detections_imgs_list[i] for i in range(len(detections_imgs_list)//2)]) | |
detections_imgs_grid.append([detections_imgs_list[i] for i in range(len(detections_imgs_list)//2, len(detections_imgs_list))]) | |
if len(detections_imgs_list)%2 != 0: | |
detections_imgs_grid[0].append(padding_img) | |
concat_det_imgs_row1 = cv2.hconcat(detections_imgs_grid[0]) | |
concat_det_imgs_row2 = cv2.hconcat(detections_imgs_grid[1]) | |
concat_det_imgs = cv2.vconcat([concat_det_imgs_row1,concat_det_imgs_row2]) | |
st.write("Detected players") | |
value = streamlit_image_coordinates(concat_det_imgs, key="numpy") | |
#value_radio_dic = defaultdict(lambda: None) | |
st.markdown('---') | |
radio_options =[f"{team1_name} P color", f"{team1_name} GK color",f"{team2_name} P color", f"{team2_name} GK color"] | |
active_color = st.radio(label="Select which team color to pick from the image above", options=radio_options, horizontal=True, | |
help="Chose team color you want to pick and click on the image above to pick the color. Colors will be displayed in boxes below.") | |
if value is not None: | |
picked_color = concat_det_imgs[value['y'], value['x'], :] | |
st.session_state[f"{active_color}"] = '#%02x%02x%02x' % tuple(picked_color) | |
st.write("Boxes below can be used to manually adjust selected colors.") | |
cp1, cp2, cp3, cp4 = st.columns([1,1,1,1]) | |
with cp1: | |
hex_color_1 = st.session_state[f"{team1_name} P color"] if f"{team1_name} P color" in st.session_state else selected_team_info["team1_p_color"] | |
team1_p_color = st.color_picker(label=' ', value=hex_color_1, key='t1p') | |
st.session_state[f"{team1_name} P color"] = team1_p_color | |
with cp2: | |
hex_color_2 = st.session_state[f"{team1_name} GK color"] if f"{team1_name} GK color" in st.session_state else selected_team_info["team1_gk_color"] | |
team1_gk_color = st.color_picker(label=' ', value=hex_color_2, key='t1gk') | |
st.session_state[f"{team1_name} GK color"] = team1_gk_color | |
with cp3: | |
hex_color_3 = st.session_state[f"{team2_name} P color"] if f"{team2_name} P color" in st.session_state else selected_team_info["team2_p_color"] | |
team2_p_color = st.color_picker(label=' ', value=hex_color_3, key='t2p') | |
st.session_state[f"{team2_name} P color"] = team2_p_color | |
with cp4: | |
hex_color_4 = st.session_state[f"{team2_name} GK color"] if f"{team2_name} GK color" in st.session_state else selected_team_info["team2_gk_color"] | |
team2_gk_color = st.color_picker(label=' ', value=hex_color_4, key='t2gk') | |
st.session_state[f"{team2_name} GK color"] = team2_gk_color | |
st.markdown('---') | |
with t1col2: | |
extracted_frame = st.empty() | |
extracted_frame.image(frame, use_column_width=True, channels="BGR") | |
colors_dic, color_list_lab = create_colors_info(team1_name, st.session_state[f"{team1_name} P color"], st.session_state[f"{team1_name} GK color"], | |
team2_name, st.session_state[f"{team2_name} P color"], st.session_state[f"{team2_name} GK color"]) | |
with tab3: | |
t2col1, t2col2 = st.columns([1,1]) | |
with t2col1: | |
player_model_conf_thresh = st.slider('PLayers Detection Confidence Threshold', min_value=0.0, max_value=1.0, value=0.6) | |
keypoints_model_conf_thresh = st.slider('Field Keypoints PLayers Detection Confidence Threshold', min_value=0.0, max_value=1.0, value=0.7) | |
keypoints_displacement_mean_tol = st.slider('Keypoints Displacement RMSE Tolerance (pixels)', min_value=-1, max_value=100, value=7, | |
help="Indicates the maximum allowed average distance between the position of the field keypoints\ | |
in current and previous detections. It is used to determine wether to update homography matrix or not. ") | |
detection_hyper_params = { | |
0: player_model_conf_thresh, | |
1: keypoints_model_conf_thresh, | |
2: keypoints_displacement_mean_tol | |
} | |
with t2col2: | |
num_pal_colors = st.slider(label="Number of palette colors", min_value=1, max_value=5, step=1, value=3, | |
help="How many colors to extract form detected players bounding-boxes? It is used for team prediction.") | |
st.markdown("---") | |
save_output = st.checkbox(label='Save output', value=False) | |
if save_output: | |
output_file_name = st.text_input(label='File Name (Optional)', placeholder='Enter output video file name.') | |
else: | |
output_file_name = None | |
st.markdown("---") | |
bcol1, bcol2 = st.columns([1,1]) | |
with bcol1: | |
nbr_frames_no_ball_thresh = st.number_input("Ball track reset threshold (frames)", min_value=1, max_value=10000, | |
value=30, help="After how many frames with no ball detection, should the track be reset?") | |
ball_track_dist_thresh = st.number_input("Ball track distance threshold (pixels)", min_value=1, max_value=1280, | |
value=100, help="Maximum allowed distance between two consecutive balls detection to keep the current track.") | |
max_track_length = st.number_input("Maximum ball track length (Nbr. detections)", min_value=1, max_value=1000, | |
value=35, help="Maximum total number of ball detections to keep in tracking history") | |
ball_track_hyperparams = { | |
0: nbr_frames_no_ball_thresh, | |
1: ball_track_dist_thresh, | |
2: max_track_length | |
} | |
with bcol2: | |
st.write("Annotation options:") | |
bcol21t, bcol22t = st.columns([1,1]) | |
with bcol21t: | |
show_k = st.toggle(label="Show Keypoints Detections", value=False) | |
show_p = st.toggle(label="Show Players Detections", value=True) | |
with bcol22t: | |
show_pal = st.toggle(label="Show Color Palettes", value=True) | |
show_b = st.toggle(label="Show Ball Tracks", value=True) | |
plot_hyperparams = { | |
0: show_k, | |
1: show_pal, | |
2: show_b, | |
3: show_p | |
} | |
st.markdown('---') | |
bcol21, bcol22, bcol23, bcol24 = st.columns([1.5,1,1,1]) | |
with bcol21: | |
st.write('') | |
with bcol22: | |
ready = True if (team1_name == '') or (team2_name == '') else False | |
start_detection = st.button(label='Start Detection', disabled=ready) | |
with bcol23: | |
stop_btn_state = True if not start_detection else False | |
stop_detection = st.button(label='Stop Detection', disabled=stop_btn_state) | |
with bcol24: | |
st.write('') | |
stframe = st.empty() | |
cap = cv2.VideoCapture(tempf.name) | |
status = False | |
if start_detection and not stop_detection: | |
st.toast(f'Detection Started!') | |
status = detect(cap, stframe, output_file_name, save_output, model_players, model_keypoints, | |
detection_hyper_params, ball_track_hyperparams, plot_hyperparams, | |
num_pal_colors, colors_dic, color_list_lab) | |
else: | |
try: | |
# Release the video capture object and close the display window | |
cap.release() | |
except: | |
pass | |
if status: | |
st.toast(f'Detection Completed!') | |
cap.release() | |
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # A SystemExit raised from main() is deliberately ignored so the script ends quietly.
        pass