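"""yoga_position.py -- classify yoga poses in a recorded video or a live webcam stream.

Usage:
    python yoga_position.py <video_path>.mp4    # recorded video file
    python yoga_position.py live                # webcam input

Each frame is run through MediaPipe Pose; the detected landmarks are classified
with a k-NN pose classifier built from sample CSVs, the per-class scores are
smoothed with an exponential moving average, and the results are written to an
annotated output video and a per-frame CSV derived from the input path.
"""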
import tqdm
import cv2
import numpy as np
import re
import os
from mediapipe.python.solutions import drawing_utils as mp_drawing
import mediapipe as mp
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
# from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer
import argparse
from PoseClassification.utils import show_image
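# PoseClassification is the local package shipped with this repo; it provides
# the pose embedding, k-NN classifier, EMA smoothing, and visualization helpers
# used below (its API follows the MediaPipe pose classification example).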
def main():
    # Parse command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="path to the input .mp4 video, or 'live' to use the webcam")
    args = parser.parse_args()
    video_path_in = args.video_path
    direct_video = False
    if video_path_in == "live":
        video_path_in = 'data/live.mp4'
        direct_video = True
    # Derive output paths from the input path (escape the dot so only the
    # ".mp4" extension is replaced).
    video_path_out = re.sub(r'\.mp4$', r'_classified_video.mp4', video_path_in)
    results_classification_path_out = re.sub(r'\.mp4$', r'_classified_results.csv', video_path_in)
    # Direct video stream (webcam) branch -- not fully supported for now.
    if direct_video:
        video_cap = cv2.VideoCapture(0)
        video_fps = 30
        video_width = 1280
        video_height = 720
        class_name = 'tree'
# Initialize tracker, classifier and current position.
# Initialize tracker.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()
# Folder with pose class CSVs. That should be the same folder you used while
# building classifier to output CSVs.
pose_samples_folder = 'data/yoga_poses_csvs_out'
# Initialize embedder.
pose_embedder = FullBodyPoseEmbedding()
# Initialize classifier.
# Check that you are using the same parameters as during bootstrapping.
pose_classifier = PoseClassifier(
pose_samples_folder=pose_samples_folder,
pose_embedder=pose_embedder,
top_n_by_max_distance=30,
top_n_by_mean_distance=10)
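        # The classifier is a k-NN over the sample CSVs: candidates are first
        # pre-filtered by maximum coordinate distance, then ranked by mean
        # distance (parameter meaning assumed from the MediaPipe example).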
# Initialize list of results
position_list=[]
frame_list=[]
# Initialize EMA smoothing.
pose_classification_filter = EMADictSmoothing(
window_size=10,
alpha=0.2)
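        # The EMA filter keeps an exponentially weighted average of the
        # per-class scores over roughly the last `window_size` frames,
        # damping frame-to-frame jitter in the classification.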
# Initialize renderer.
pose_classification_visualizer = PoseClassificationVisualizer(
class_name=class_name,
plot_x_max=1000,
# Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
plot_y_max=10)
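        # Note: in this live branch the visualizer is initialized but never called;
        # frames are annotated with cv2.putText and shown via cv2.imshow instead.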
        # Open the output video writer. Note: this live branch currently only
        # displays frames with cv2.imshow; nothing is written to out_video.
        out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
        # Initialize frame counter and current position.
        frame_idx = 0
        current_position = {"none": 10.0}
        output_frame = None
try:
with tqdm.tqdm(position=0, leave=True) as pbar:
while True:
                    # Record current_position and frame_idx at each iteration.
                    position_list.append(current_position)
                    frame_list.append(frame_idx)
                    # Append both values to the results CSV as we go.
                    with open(results_classification_path_out, 'a') as f:
                        f.write(f'{frame_idx};{current_position}\n')
success, input_frame = video_cap.read()
if not success:
print("Unable to read input video frame, breaking!")
break
# Run pose tracker
input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
result = pose_tracker.process(image=input_frame_rgb)
pose_landmarks = result.pose_landmarks
# Prepare the output frame
output_frame = input_frame.copy()
# Add a white banner on top
banner_height = 180
output_frame[0:banner_height, :] = (255, 255, 255) # White color
                    # Load the logo image (read on every frame here; it could be
                    # hoisted out of the loop) and overlay it in the upper-right corner.
                    logo = cv2.imread("src/logo_impredalam.jpg")
                    if logo is not None:
                        logo_height, logo_width = logo.shape[:2]
                        logo = cv2.resize(logo, (logo_width // 3, logo_height // 3))  # Resize to 1/3 scale
                        output_frame[0:logo.shape[0], output_frame.shape[1] - logo.shape[1]:] = logo
if pose_landmarks is not None:
mp_drawing.draw_landmarks(
image=output_frame,
landmark_list=pose_landmarks,
connections=mp_pose.POSE_CONNECTIONS,
)
# Get landmarks
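                        # MediaPipe Pose returns 33 normalized landmarks; x/y are scaled
                        # to pixel coordinates and z by the frame width, which should match
                        # the convention used when the sample CSVs were built.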
frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
pose_landmarks = np.array(
[
[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
for lmk in pose_landmarks.landmark
],
dtype=np.float32,
)
assert pose_landmarks.shape == (
33,
3,
), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
# Classify the pose on the current frame
pose_classification = pose_classifier(pose_landmarks)
# Smooth classification using EMA
pose_classification_filtered = pose_classification_filter(pose_classification)
current_position=pose_classification_filtered
# Count repetitions
# repetitions_count = repetition_counter(pose_classification_filtered)
# Display repetitions count on the frame
# cv2.putText(
# output_frame,
# f"Push-Ups: {repetitions_count}",
# (10, 30),
# cv2.FONT_HERSHEY_SIMPLEX,
# 1,
# (0, 0, 0),
# 2,
# cv2.LINE_AA,
# )
# Display classified pose on the frame
cv2.putText(
output_frame,
f"Pose: {current_position}",
(10, 70),
cv2.FONT_HERSHEY_SIMPLEX,
1.2, # Smaller font size
(0, 0, 0),
1, # Thinner line
cv2.LINE_AA,
)
else:
# If no landmarks are detected, still display the last count
# repetitions_count = repetition_counter.n_repeats
# cv2.putText(
# output_frame,
# f"Push-Ups: {repetitions_count}",
# (10, 30),
# cv2.FONT_HERSHEY_SIMPLEX,
# 1,
# (0, 255, 0),
# 2,
# cv2.LINE_AA,
# )
                        # No pose detected on this frame: fall back to the placeholder position.
                        current_position = {"none": 10.0}
cv2.putText(
output_frame,
f"Pose: {current_position}",
(10, 70),
cv2.FONT_HERSHEY_SIMPLEX,
1.2, # Smaller font size
(0, 0, 0),
1, # Thinner line
cv2.LINE_AA,
)
cv2.imshow("Yoga position classification", output_frame)
key = cv2.waitKey(1) & 0xFF
if key == ord("q"):
break
elif key == ord("r"):
# repetition_counter.reset()
print("Counter reset!")
frame_idx += 1
pbar.update()
        finally:
            pose_tracker.close()
            video_cap.release()
            out_video.release()
            cv2.destroyAllWindows()
    # Recorded-video branch: process the file given by video_path_in.
    else:
        assert isinstance(video_path_in, str), "Error in video path format, not a string. Abort."
# Open video and get video parameters and check if video is OK
video_cap = cv2.VideoCapture(video_path_in)
video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
video_fps = video_cap.get(cv2.CAP_PROP_FPS)
video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        assert isinstance(video_n_frames, float), 'Error in input video frame count type. Abort.'
        assert video_n_frames > 0.0, 'Error in input video: no frames found. Abort.'
class_name='tree'
# Initialize tracker, classifier and current position.
# Initialize tracker.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()
# Folder with pose class CSVs. That should be the same folder you used while
# building classifier to output CSVs.
pose_samples_folder = 'data/yoga_poses_csvs_out'
# Initialize embedder.
pose_embedder = FullBodyPoseEmbedding()
# Initialize classifier.
# Check that you are using the same parameters as during bootstrapping.
pose_classifier = PoseClassifier(
pose_samples_folder=pose_samples_folder,
pose_embedder=pose_embedder,
top_n_by_max_distance=30,
top_n_by_mean_distance=10)
# Initialize list of results
position_list=[]
frame_list=[]
# Initialize EMA smoothing.
pose_classification_filter = EMADictSmoothing(
window_size=10,
alpha=0.2)
# Initialize renderer.
pose_classification_visualizer = PoseClassificationVisualizer(
class_name=class_name,
plot_x_max=video_n_frames,
# Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
plot_y_max=10)
# Open output video.
out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
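        # 'mp4v' selects an MPEG-4 codec; frames passed to out_video.write()
        # must match (video_width, video_height) or the output may be unreadable.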
        # Initialize frame counter and current position.
        frame_idx = 0
        current_position = {"none": 10.0}
        output_frame = None
with tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar:
while True:
                # Record current_position and frame_idx at each iteration.
                position_list.append(current_position)
                frame_list.append(frame_idx)
                # Append both values to the results CSV as we go.
                with open(results_classification_path_out, 'a') as f:
                    f.write(f'{frame_idx};{current_position}\n')
# Get next frame of the video.
success, input_frame = video_cap.read()
if not success:
print("unable to read input video frame, breaking!")
break
# Run pose tracker.
input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
result = pose_tracker.process(image=input_frame)
pose_landmarks = result.pose_landmarks
# Draw pose prediction.
output_frame = input_frame.copy()
if pose_landmarks is not None:
mp_drawing.draw_landmarks(
image=output_frame,
landmark_list=pose_landmarks,
connections=mp_pose.POSE_CONNECTIONS)
if pose_landmarks is not None:
# Get landmarks.
frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
pose_landmarks = np.array([[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
for lmk in pose_landmarks.landmark], dtype=np.float32)
assert pose_landmarks.shape == (33, 3), 'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)
# Classify the pose on the current frame.
pose_classification = pose_classifier(pose_landmarks)
# Smooth classification using EMA.
pose_classification_filtered = pose_classification_filter(pose_classification)
current_position=pose_classification_filtered
# Count repetitions.
# repetitions_count = repetition_counter(pose_classification_filtered)
else:
# No pose => no classification on current frame.
pose_classification = None
                    # Still add an empty classification to the filter to maintain correct
                    # smoothing for future frames.
pose_classification_filtered = pose_classification_filter(dict())
pose_classification_filtered = None
current_position='None'
# Don't update the counter presuming that person is 'frozen'. Just
# take the latest repetitions count.
# repetitions_count = repetition_counter.n_repeats
# Draw classification plot and repetition counter.
output_frame = pose_classification_visualizer(
frame=output_frame,
pose_classification=pose_classification,
pose_classification_filtered=pose_classification_filtered,
repetitions_count='0'
)
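                # The visualizer is expected to return a PIL image, hence the
                # np.array(...) conversion and RGB->BGR swap before writing.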
# Save the output frame.
out_video.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))
# Show intermediate frames of the video to track progress.
if frame_idx % 50 == 0:
show_image(output_frame)
frame_idx += 1
pbar.update()
# Close output video.
out_video.release()
# Release MediaPipe resources.
pose_tracker.close()
# Show the last frame of the video.
if output_frame is not None:
show_image(output_frame)
video_cap.release()
        # Return the last smoothed classification: a score dict when a pose was
        # detected on the final frame, or 'None' otherwise. Expected classes:
        # ['Chair', 'Cobra', 'Dog', 'Goddess', 'Plank', 'Tree', 'Warrior', 'None' (not fallen), 'Fall'].
        return current_position
if __name__ == "__main__":
main()