import tempfile
import numpy as np
import streamlit as st
from streamlit_image_coordinates import streamlit_image_coordinates
import cv2
from ultralytics import YOLO
from detection import create_colors_info, detect
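
# Streamlit front-end for the football tactical analysis app: it loads the YOLOv8
# player and field-keypoint models, lets the user pick team colors from detected
# player crops, and launches the detection / tactical-map pipeline implemented in
# detection.detect().
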
def main():
    st.set_page_config(page_title="AI Powered Web Application for Football Tactical Analysis", layout="wide", initial_sidebar_state="expanded")
    st.title("Football Players Detection With Team Prediction & Tactical Map")
    st.subheader(":red[Works only with Tactical Camera footage]")

    st.sidebar.title("Main Settings")
    demo_selected = st.sidebar.radio(label="Select Demo Video", options=["Demo 1", "Demo 2"], horizontal=True)
    ## Sidebar Setup
    st.sidebar.markdown('---')
    st.sidebar.subheader("Video Upload")
    input_video_file = st.sidebar.file_uploader('Upload a video file', type=['mp4', 'mov', 'avi', 'm4v', 'asf'])

    demo_vid_paths = {
        "Demo 1": 'demo_vid_1.mp4',
        "Demo 2": 'demo_vid_2.mp4'
    }
    demo_vid_path = demo_vid_paths[demo_selected]
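    # Default team names and kit colors for the bundled demo videos; they pre-fill
    # the sidebar text inputs and the color pickers in the "Team Colors" tab.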
    demo_team_info = {
        "Demo 1": {"team1_name": "France",
                   "team2_name": "Switzerland",
                   "team1_p_color": '#1E2530',
                   "team1_gk_color": '#F5FD15',
                   "team2_p_color": '#FBFCFA',
                   "team2_gk_color": '#B1FCC4',
                   },
        "Demo 2": {"team1_name": "Chelsea",
                   "team2_name": "Manchester City",
                   "team1_p_color": '#29478A',
                   "team1_gk_color": '#DC6258',
                   "team2_p_color": '#90C8FF',
                   "team2_gk_color": '#BCC703',
                   }
    }
    selected_team_info = demo_team_info[demo_selected]
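    # The selected video is exposed through a named temporary file so that
    # cv2.VideoCapture can open it by path (st.file_uploader only provides an
    # in-memory buffer); for the demos, the temp file name is pointed at the
    # bundled demo video instead.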
    tempf = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    if not input_video_file:
        tempf.name = demo_vid_path
        with open(tempf.name, 'rb') as demo_vid:
            demo_bytes = demo_vid.read()

        st.sidebar.text('Demo video')
        st.sidebar.video(demo_bytes)
    else:
        tempf.write(input_video_file.read())
        with open(tempf.name, 'rb') as demo_vid:
            demo_bytes = demo_vid.read()

        st.sidebar.text('Input video')
        st.sidebar.video(demo_bytes)
    # Load the YOLOv8 players detection model
    model_players = YOLO("models/Yolo8L Players/weights/best.pt")
    # Load the YOLOv8 field keypoints detection model
    model_keypoints = YOLO("models/Yolo8M Field Keypoints/weights/best.pt")
    st.sidebar.markdown('---')
    st.sidebar.subheader("Team Names")
    team1_name = st.sidebar.text_input(label='First Team Name', value=selected_team_info["team1_name"])
    team2_name = st.sidebar.text_input(label='Second Team Name', value=selected_team_info["team2_name"])
    st.sidebar.markdown('---')
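    # The main page is organised into three tabs: usage instructions, team color
    # picking, and hyperparameter tuning / detection.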
    ## Page Setup
    tab1, tab2, tab3 = st.tabs(["How to use?", "Team Colors", "Model Hyperparameters & Detection"])
    with tab1:
        st.header(':blue[Welcome!]')
        st.subheader('Main Application Functionalities:', divider='blue')
        st.markdown("""
1. Football players, referee, and ball detection.
2. Players team prediction.
3. Estimation of players and ball positions on a tactical map.
4. Ball tracking.
""")
        st.subheader('How to use?', divider='blue')
        st.markdown("""
**Two demo videos are loaded automatically when you start the app, alongside the recommended settings and hyperparameters.**
1. Upload a video to analyse, using the "Browse files" button in the sidebar menu.
2. Enter the team names that correspond to the uploaded video in the text fields of the sidebar menu.
3. Access the "Team Colors" tab on the main page.
4. Select a frame where players and goalkeepers from both teams can be detected.
5. Follow the instructions on the page to pick each team's colors.
6. Go to the "Model Hyperparameters & Detection" tab, adjust the hyperparameters and select the annotation options. (The default hyperparameters are recommended.)
7. Run detection!
8. If the "Save output" option was selected, the saved video can be found in the "outputs" directory.
""")
        st.write("Version 0.0.1")
    with tab2:
        t1col1, t1col2 = st.columns([1, 1])
        with t1col1:
            cap_temp = cv2.VideoCapture(tempf.name)
            frame_count = int(cap_temp.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_nbr = st.slider(label="Select frame", min_value=1, max_value=frame_count, step=1, help="Select frame to pick team colors from")
            # Slider values are 1-based while OpenCV frame positions are 0-based
            cap_temp.set(cv2.CAP_PROP_POS_FRAMES, frame_nbr - 1)
            success, frame = cap_temp.read()
            with st.spinner('Detecting players in selected frame...'):
                results = model_players(frame, conf=0.7)
                bboxes = results[0].boxes.xyxy.cpu().numpy()
                labels = results[0].boxes.cls.cpu().numpy()
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
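                # Keep only class 0 (player) detections: crop each one from the RGB
                # frame, resize it to a 60x80 thumbnail, and tile the thumbnails into
                # two rows, padding with a white tile when the count is odd.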
                detections_imgs_list = []
                detections_imgs_grid = []
                padding_img = np.ones((80, 60, 3), dtype=np.uint8) * 255
                for i, j in enumerate(list(labels)):
                    if int(j) == 0:
                        bbox = bboxes[i, :]
                        obj_img = frame_rgb[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
                        obj_img = cv2.resize(obj_img, (60, 80))
                        detections_imgs_list.append(obj_img)
                detections_imgs_grid.append([detections_imgs_list[i] for i in range(len(detections_imgs_list)//2)])
                detections_imgs_grid.append([detections_imgs_list[i] for i in range(len(detections_imgs_list)//2, len(detections_imgs_list))])
                if len(detections_imgs_list) % 2 != 0:
                    detections_imgs_grid[0].append(padding_img)
                concat_det_imgs_row1 = cv2.hconcat(detections_imgs_grid[0])
                concat_det_imgs_row2 = cv2.hconcat(detections_imgs_grid[1])
                concat_det_imgs = cv2.vconcat([concat_det_imgs_row1, concat_det_imgs_row2])
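            # The tiled crops are shown as a clickable image; a click returns pixel
            # coordinates, from which the currently selected team color is sampled.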
st.write("Detected players")
value = streamlit_image_coordinates(concat_det_imgs, key="numpy")
#value_radio_dic = defaultdict(lambda: None)
st.markdown('---')
radio_options =[f"{team1_name} P color", f"{team1_name} GK color",f"{team2_name} P color", f"{team2_name} GK color"]
active_color = st.radio(label="Select which team color to pick from the image above", options=radio_options, horizontal=True,
help="Chose team color you want to pick and click on the image above to pick the color. Colors will be displayed in boxes below.")
if value is not None:
picked_color = concat_det_imgs[value['y'], value['x'], :]
st.session_state[f"{active_color}"] = '#%02x%02x%02x' % tuple(picked_color)
st.write("Boxes below can be used to manually adjust selected colors.")
            cp1, cp2, cp3, cp4 = st.columns([1, 1, 1, 1])
            with cp1:
                hex_color_1 = st.session_state[f"{team1_name} P color"] if f"{team1_name} P color" in st.session_state else selected_team_info["team1_p_color"]
                team1_p_color = st.color_picker(label=' ', value=hex_color_1, key='t1p')
                st.session_state[f"{team1_name} P color"] = team1_p_color
            with cp2:
                hex_color_2 = st.session_state[f"{team1_name} GK color"] if f"{team1_name} GK color" in st.session_state else selected_team_info["team1_gk_color"]
                team1_gk_color = st.color_picker(label=' ', value=hex_color_2, key='t1gk')
                st.session_state[f"{team1_name} GK color"] = team1_gk_color
            with cp3:
                hex_color_3 = st.session_state[f"{team2_name} P color"] if f"{team2_name} P color" in st.session_state else selected_team_info["team2_p_color"]
                team2_p_color = st.color_picker(label=' ', value=hex_color_3, key='t2p')
                st.session_state[f"{team2_name} P color"] = team2_p_color
            with cp4:
                hex_color_4 = st.session_state[f"{team2_name} GK color"] if f"{team2_name} GK color" in st.session_state else selected_team_info["team2_gk_color"]
                team2_gk_color = st.color_picker(label=' ', value=hex_color_4, key='t2gk')
                st.session_state[f"{team2_name} GK color"] = team2_gk_color
            st.markdown('---')
        with t1col2:
            extracted_frame = st.empty()
            extracted_frame.image(frame, use_column_width=True, channels="BGR")

        colors_dic, color_list_lab = create_colors_info(team1_name, st.session_state[f"{team1_name} P color"], st.session_state[f"{team1_name} GK color"],
                                                        team2_name, st.session_state[f"{team2_name} P color"], st.session_state[f"{team2_name} GK color"])
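    # "Model Hyperparameters & Detection" tab: detection and tracking hyperparameters,
    # annotation options, and the Start/Stop detection controls.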
    with tab3:
        t2col1, t2col2 = st.columns([1, 1])
        with t2col1:
            player_model_conf_thresh = st.slider('Players Detection Confidence Threshold', min_value=0.0, max_value=1.0, value=0.6)
            keypoints_model_conf_thresh = st.slider('Field Keypoints Detection Confidence Threshold', min_value=0.0, max_value=1.0, value=0.7)
            keypoints_displacement_mean_tol = st.slider('Keypoints Displacement RMSE Tolerance (pixels)', min_value=-1, max_value=100, value=7,
                                                        help="Maximum allowed average distance between the positions of the field keypoints in the current and previous detections. It is used to determine whether to update the homography matrix or not.")
            detection_hyper_params = {
                0: player_model_conf_thresh,
                1: keypoints_model_conf_thresh,
                2: keypoints_displacement_mean_tol
            }
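        # Number of dominant colors to extract from each detected player's bounding
        # box; the extracted palettes are used for team prediction.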
with t2col2:
num_pal_colors = st.slider(label="Number of palette colors", min_value=1, max_value=5, step=1, value=3,
help="How many colors to extract form detected players bounding-boxes? It is used for team prediction.")
st.markdown("---")
save_output = st.checkbox(label='Save output', value=False)
if save_output:
output_file_name = st.text_input(label='File Name (Optional)', placeholder='Enter output video file name.')
else:
output_file_name = None
st.markdown("---")
        bcol1, bcol2 = st.columns([1, 1])
        with bcol1:
            nbr_frames_no_ball_thresh = st.number_input("Ball track reset threshold (frames)", min_value=1, max_value=10000,
                                                        value=30, help="After how many frames with no ball detection should the track be reset?")
            ball_track_dist_thresh = st.number_input("Ball track distance threshold (pixels)", min_value=1, max_value=1280,
                                                     value=100, help="Maximum allowed distance between two consecutive ball detections to keep the current track.")
            max_track_length = st.number_input("Maximum ball track length (Nbr. detections)", min_value=1, max_value=1000,
                                               value=35, help="Maximum total number of ball detections to keep in the tracking history.")
            ball_track_hyperparams = {
                0: nbr_frames_no_ball_thresh,
                1: ball_track_dist_thresh,
                2: max_track_length
            }
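        # Annotation options: toggles controlling which overlays are drawn on the output frames.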
        with bcol2:
            st.write("Annotation options:")
            bcol21t, bcol22t = st.columns([1, 1])
            with bcol21t:
                show_k = st.toggle(label="Show Keypoints Detections", value=False)
                show_p = st.toggle(label="Show Players Detections", value=True)
            with bcol22t:
                show_pal = st.toggle(label="Show Color Palettes", value=True)
                show_b = st.toggle(label="Show Ball Tracks", value=True)
            plot_hyperparams = {
                0: show_k,
                1: show_pal,
                2: show_b,
                3: show_p
            }
        st.markdown('---')
        bcol21, bcol22, bcol23, bcol24 = st.columns([1.5, 1, 1, 1])
        with bcol21:
            st.write('')
        with bcol22:
            # Detection can only start once both team names have been entered
            missing_team_name = (team1_name == '') or (team2_name == '')
            start_detection = st.button(label='Start Detection', disabled=missing_team_name)
        with bcol23:
            stop_btn_state = not start_detection
            stop_detection = st.button(label='Stop Detection', disabled=stop_btn_state)
        with bcol24:
            st.write('')
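        # Placeholder and video capture handed to detect(), which is expected to stream
        # annotated frames into the placeholder while the video is processed.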
        stframe = st.empty()
        cap = cv2.VideoCapture(tempf.name)
        status = False

        if start_detection and not stop_detection:
            st.toast('Detection Started!')
            status = detect(cap, stframe, output_file_name, save_output, model_players, model_keypoints,
                            detection_hyper_params, ball_track_hyperparams, plot_hyperparams,
                            num_pal_colors, colors_dic, color_list_lab)
        else:
            try:
                # Release the video capture object and close the display window
                cap.release()
            except Exception:
                pass
        if status:
            st.toast('Detection Completed!')
            cap.release()

if __name__ == '__main__':
    try:
        main()
    except SystemExit:
        pass