# YOGAI / yoga_position_gradio.py
import argparse
import os
import re

import cv2
import mediapipe as mp
import numpy as np
import tqdm
from mediapipe.python.solutions import drawing_utils as mp_drawing

from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
# from PoseClassification.utils import RepetitionCounter
from PoseClassification.utils import show_image
from PoseClassification.visualize import PoseClassificationVisualizer
def check_major_current_position(positions_detected: dict, threshold_position) -> str:
    '''
    Return the major position among those detected in the frame, or 'none'.
    INPUTS
    positions_detected :
        dict of positions given by the pose classifier and pose_classification_filtered,
        e.g. {'pose1': 8.0, 'pose2': 2.0}
    threshold_position :
        scores strictly below this value are considered the 'none' position
    OUTPUT
    major_position :
        string with the position (classes from the classifier, or 'none')
    '''
    if max(positions_detected.values()) < float(threshold_position):
        major_position = 'none'
    else:
        major_position = max(positions_detected, key=positions_detected.get)
    return major_position
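# Illustrative usage (hypothetical smoothed scores, threshold of 8.0):
#   check_major_current_position({'tree': 8.5, 'dog': 1.5}, 8.0)  -> 'tree'
#   check_major_current_position({'tree': 4.0, 'dog': 1.0}, 8.0)  -> 'none'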
def yoga_position_classifier():
    # Load arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="string video path in, or 'live' for webcam")
    args = parser.parse_args()
    video_path_in = args.video_path
    direct_video = False
    if video_path_in == "live":
        video_path_in = 'data/live.mp4'
        direct_video = True
    # Build output paths by replacing the '.mp4' extension (escaped dot, anchored at the end)
    video_path_out = re.sub(r'\.mp4$', r'_classified_video.mp4', video_path_in)
    results_classification_path_out = re.sub(r'\.mp4$', r'_classified_results.csv', video_path_in)
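    # For example, a hypothetical input 'data/session.mp4' would give
    # 'data/session_classified_video.mp4' and 'data/session_classified_results.csv'.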
# Initialize tracker, classifier and current position.
# Initialize tracker.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()
    # Folder with pose class CSVs. It should be the same folder that was used
    # when building the classifier to output the CSVs.
pose_samples_folder = 'data/yoga_poses_csvs_out'
# Initialize embedder.
pose_embedder = FullBodyPoseEmbedding()
# Initialize classifier.
# Check that you are using the same parameters as during bootstrapping.
pose_classifier = PoseClassifier(
pose_samples_folder=pose_samples_folder,
pose_embedder=pose_embedder,
top_n_by_max_distance=30,
top_n_by_mean_distance=10)
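    # Assuming PoseClassifier follows the MediaPipe-style k-NN scheme: the 30 samples
    # closest by maximum coordinate distance are preselected, the 10 closest of those
    # by mean distance are kept, and each class score is the number of kept neighbours
    # belonging to that class (so scores range from 0 to 10 here).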
# Initialize EMA smoothing.
pose_classification_filter = EMADictSmoothing(
window_size=10,
alpha=0.2)
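    # Assuming EMADictSmoothing implements a standard exponential moving average:
    # per-class scores are blended over the last `window_size` frames with weights
    # proportional to alpha * (1 - alpha)**k (k frames back in time), so a single
    # noisy frame cannot flip the displayed pose.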
# Initialize list of results
position_list=[]
frame_list=[]
    # Case of a live video stream (webcam)
    if direct_video:
video_cap = cv2.VideoCapture(0)
    # Case of a video file path
    else:
assert type(video_path_in)==str, "Error in video path format, not a string. Abort."
# Open video and get video parameters and check if video is OK
video_cap = cv2.VideoCapture(video_path_in)
video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
assert type(video_n_frames)==float, 'Error in input video frames type. Abort.'
assert video_n_frames>0.0, 'Error in input video frames number : no frame. Abort.'
video_fps = video_cap.get(cv2.CAP_PROP_FPS)
video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
class_names=['chair', 'cobra', 'dog', 'goddess', 'plank', 'tree', 'warrior', 'none']
position_threshold = 8.0
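    # With scores on a 0-10 scale (10 nearest neighbours), a threshold of 8.0 means a
    # pose is only reported when at least ~8 of the smoothed neighbour votes agree;
    # this interpretation assumes the k-NN scoring described above.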
# Open output video.
out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
# Initialize results
frame_idx = 0
current_position = {"none":10.0}
output_frame = None
position_timer = 0
previous_position_major = 'none'
try:
with tqdm.tqdm(position=0, leave=True) as pbar:
while True:
                # Get current time from the beginning of the video
time_sec = float(frame_idx*(1/video_fps))
# Get current major position (str)
current_position_major = check_major_current_position(current_position, position_threshold)
success, input_frame = video_cap.read()
if not success:
print("Unable to read input video frame, breaking!")
break
# Run pose tracker
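                # MediaPipe expects RGB input while OpenCV reads frames as BGR,
                # hence the colour conversion below.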
input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
result = pose_tracker.process(image=input_frame_rgb)
pose_landmarks = result.pose_landmarks
# Prepare the output frame
output_frame = input_frame.copy()
# Add a white banner on top
banner_height = int(video_height//10)
output_frame[0:banner_height, :] = (255, 255, 255) # White color
# Load the logo image
logo = cv2.imread("src/logo_impredalam.jpg")
logo_height, logo_width = logo.shape[:2]
logo_height_rescaled = banner_height
logo_width_rescaled = int((logo_width*logo_height_rescaled)// logo_height )
logo = cv2.resize(logo, (logo_width_rescaled, logo_height_rescaled)) # Resize to banner scale
# Overlay the logo on the upper right corner
output_frame[0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :] = (logo)
# If landmarks are detected
if pose_landmarks is not None:
mp_drawing.draw_landmarks(
image=output_frame,
landmark_list=pose_landmarks,
connections=mp_pose.POSE_CONNECTIONS,)
# Get landmarks
frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
pose_landmarks = np.array(
[
[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
for lmk in pose_landmarks.landmark
],
dtype=np.float32,)
assert pose_landmarks.shape == (33,3,), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
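                    # Note: MediaPipe returns 33 pose landmarks with normalized coordinates;
                    # x and y were scaled to pixels above by frame width/height, and z (which
                    # uses roughly the same scale as x, per MediaPipe) by frame width.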
# Classify the pose on the current frame
pose_classification = pose_classifier(pose_landmarks)
# Smooth classification using EMA
pose_classification_filtered = pose_classification_filter(pose_classification)
current_position=pose_classification_filtered
current_position_major=check_major_current_position(current_position, position_threshold)
# If no landmarks are detected
else:
current_position={'none':10.0}
current_position_major=check_major_current_position(current_position, position_threshold)
                # Whether or not landmarks were detected:
                # update the position timer according to the current and previous positions
if current_position_major==previous_position_major:
#increase position_timer
position_timer+=(1/video_fps)
else:
previous_position_major=current_position_major
position_timer=0
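                # Example: at 30 fps, 90 consecutive frames of the same pose give
                # position_timer = 90 * (1/30) = 3.0 seconds.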
# Display current position on frame
                cv2.putText(
                    output_frame,
                    f"Pose: {current_position_major}",
                    (int(video_width // 50), int(banner_height // 3)),  # coord
                    cv2.FONT_HERSHEY_SIMPLEX,
                    float(0.9 * (video_height / video_width)),  # Font size
                    (0, 0, 0),  # color
                    1,  # Thinner line
                    cv2.LINE_AA,)
# Display current position timer on frame
                cv2.putText(
                    output_frame,
                    f"Duration: {int(position_timer)} seconds",
                    (int(video_width // 50), int((2 * banner_height) // 3)),  # coord
                    cv2.FONT_HERSHEY_SIMPLEX,
                    float(0.9 * (video_height / video_width)),  # Font size
                    (0, 0, 0),  # color
                    1,  # Thinner line
                    cv2.LINE_AA,)
                # Write the annotated frame to the output video and display it
                out_video.write(output_frame)
                cv2.imshow("Yoga position", output_frame)
# Add current_position (dict) and frame index to list (output file for debug)
position_list.append(current_position)
frame_list.append(frame_idx)
# Output file for debug
with open(results_classification_path_out, 'a') as f:
f.write(f'{frame_idx},{current_position}\n')
key = cv2.waitKey(1) & 0xFF
if key == ord("q"):
break
elif key == ord("r"):
current_position = {'none':10.0}
print("Position reset !")
frame_idx += 1
pbar.update()
finally:
pose_tracker.close()
video_cap.release()
cv2.destroyAllWindows()
# Close output video.
out_video.release()
return frame_list, position_list
if __name__ == "__main__":
yoga_position_classifier()