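"""Classify yoga poses in a video, frame by frame.

Reads an input video (or the webcam when the argument is 'live'), runs
MediaPipe pose tracking, classifies each frame with a k-NN pose classifier,
smooths the scores with an EMA filter, and writes both an annotated video
and a CSV of per-frame classifications next to the input file.
"""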
import argparse
import os
import re

import cv2
import mediapipe as mp
import numpy as np
import tqdm
from mediapipe.python.solutions import drawing_utils as mp_drawing

from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing, show_image
# from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer

def main():
    # Load arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="input video path, or 'live' to use the webcam")
    args = parser.parse_args()
    video_path_in = args.video_path
    direct_video = False
    if video_path_in == "live":
        # Placeholder path: only used to derive the output file names below.
        video_path_in = 'data/live.mp4'
        direct_video = True
    # Anchor the pattern so only the trailing '.mp4' extension is replaced.
    video_path_out = re.sub(r'\.mp4$', '_classified_video.mp4', video_path_in)
    results_classification_path_out = re.sub(r'\.mp4$', '_classified_results.csv', video_path_in)
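    # Example invocations (assuming this file is saved as main.py):
    #   python main.py data/session.mp4  -> writes data/session_classified_video.mp4
    #                                       and data/session_classified_results.csv
    #   python main.py live              -> classifies the webcam stream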
    # Branch for a live webcam stream.
    if direct_video:
        video_cap = cv2.VideoCapture(0)
        video_fps = 30
        # Query the actual capture size so the output writer matches the frames.
        video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        class_name = 'tree'
        # Initialize tracker, classifier and current position.
        # Initialize tracker.
        mp_pose = mp.solutions.pose
        pose_tracker = mp_pose.Pose()
        # Folder with pose class CSVs. It should be the same folder that was used
        # when building the classifier CSVs.
        pose_samples_folder = 'data/yoga_poses_csvs_out'
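        # Assumed layout (as in the MediaPipe bootstrapping helper): one
        # <class_name>.csv per pose class, each row holding a sample name
        # followed by that sample's 33 * 3 landmark coordinates.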
        # Initialize embedder.
        pose_embedder = FullBodyPoseEmbedding()
        # Initialize classifier.
        # Check that you are using the same parameters as during bootstrapping.
        pose_classifier = PoseClassifier(
            pose_samples_folder=pose_samples_folder,
            pose_embedder=pose_embedder,
            top_n_by_max_distance=30,
            top_n_by_mean_distance=10)
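        # As in the MediaPipe sample this follows, the classifier is a k-NN:
        # `top_n_by_max_distance` first keeps the 30 samples closest by maximum
        # per-coordinate distance (outlier removal), then `top_n_by_mean_distance`
        # votes among the 10 nearest by mean distance, so each class score
        # lies in [0, 10].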
        # Initialize lists of results.
        position_list = []
        frame_list = []
        # Initialize EMA smoothing.
        pose_classification_filter = EMADictSmoothing(
            window_size=10,
            alpha=0.2)
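        # The EMA filter damps frame-to-frame jitter: each class score becomes a
        # weighted average of its last `window_size` values, with weights decaying
        # roughly as (1 - alpha) ** age, so brief misclassifications fade quickly.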
        # Initialize renderer.
        pose_classification_visualizer = PoseClassificationVisualizer(
            class_name=class_name,
            plot_x_max=1000,
            # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
            plot_y_max=10)
        # Open output video.
        out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'),
                                    video_fps, (video_width, video_height))
        # Initialize frame counter and current position.
        frame_idx = 0
        current_position = {"none": 10.0}
        output_frame = None
        try:
            with tqdm.tqdm(position=0, leave=True) as pbar:
                while True:
                    # Append the current position and frame index on every iteration.
                    position_list.append(current_position)
                    frame_list.append(frame_idx)
                    # Stream both values to the results file as we go.
                    with open(results_classification_path_out, 'a') as f:
                        f.write(f'{frame_idx};{current_position}\n')
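                    # Each CSV line pairs a frame index with the smoothed score
                    # dict, e.g. 42;{'tree': 7.3, 'chair': 1.2} (illustrative values).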
                    success, input_frame = video_cap.read()
                    if not success:
                        print("Unable to read input video frame, breaking!")
                        break
                    # Run pose tracker.
                    input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                    result = pose_tracker.process(image=input_frame_rgb)
                    pose_landmarks = result.pose_landmarks
                    # Prepare the output frame.
                    output_frame = input_frame.copy()
                    # Add a white banner on top.
                    banner_height = 180
                    output_frame[0:banner_height, :] = (255, 255, 255)
                    # Load the logo image; cv2.imread returns None if the file is missing.
                    logo = cv2.imread("src/logo_impredalam.jpg")
                    if logo is not None:
                        logo_height, logo_width = logo.shape[:2]
                        logo = cv2.resize(logo, (logo_width // 3, logo_height // 3))  # 1/3 scale
                        # Overlay the logo on the upper right corner.
                        output_frame[0:logo.shape[0], output_frame.shape[1] - logo.shape[1]:] = logo
                    if pose_landmarks is not None:
                        mp_drawing.draw_landmarks(
                            image=output_frame,
                            landmark_list=pose_landmarks,
                            connections=mp_pose.POSE_CONNECTIONS,
                        )
                        # Convert the 33 normalized landmarks to pixel coordinates;
                        # z is scaled by frame width since MediaPipe reports z in
                        # roughly the same scale as x.
                        frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                        pose_landmarks = np.array(
                            [[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                             for lmk in pose_landmarks.landmark],
                            dtype=np.float32,
                        )
                        assert pose_landmarks.shape == (33, 3), \
                            "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
                        # Classify the pose on the current frame.
                        pose_classification = pose_classifier(pose_landmarks)
                        # Smooth classification using EMA.
                        pose_classification_filtered = pose_classification_filter(pose_classification)
                        current_position = pose_classification_filtered
                        # Count repetitions (disabled).
                        # repetitions_count = repetition_counter(pose_classification_filtered)
                        # Display the classified pose on the frame.
                        cv2.putText(
                            output_frame,
                            f"Pose: {current_position}",
                            (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1.2,
                            (0, 0, 0),
                            1,
                            cv2.LINE_AA,
                        )
                    else:
                        # No landmarks detected: fall back to the default position.
                        current_position = {"none": 10.0}
                        cv2.putText(
                            output_frame,
                            f"Pose: {current_position}",
                            (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1.2,
                            (0, 0, 0),
                            1,
                            cv2.LINE_AA,
                        )
                    # Save the annotated frame to the output video and show it live.
                    out_video.write(output_frame)
                    cv2.imshow("Yoga position classification", output_frame)
                    key = cv2.waitKey(1) & 0xFF
                    if key == ord("q"):
                        break
                    elif key == ord("r"):
                        # repetition_counter.reset()
                        print("Counter reset!")
                    frame_idx += 1
                    pbar.update()
        finally:
            pose_tracker.close()
            video_cap.release()
            out_video.release()
            cv2.destroyAllWindows()
    # Branch for a recorded video given by video_path_in.
    else:
        assert isinstance(video_path_in, str), "Error in video path format, not a string. Abort."
        # Open the video, read its parameters and check that it is valid.
        video_cap = cv2.VideoCapture(video_path_in)
        video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
        video_fps = video_cap.get(cv2.CAP_PROP_FPS)
        video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        assert isinstance(video_n_frames, float), 'Error in input video frame count type. Abort.'
        assert video_n_frames > 0.0, 'Error in input video frame count: no frames. Abort.'
        class_name = 'tree'
        # Initialize tracker, classifier and current position.
        # Initialize tracker.
        mp_pose = mp.solutions.pose
        pose_tracker = mp_pose.Pose()
        # Folder with pose class CSVs. It should be the same folder that was used
        # when building the classifier CSVs.
        pose_samples_folder = 'data/yoga_poses_csvs_out'
        # Initialize embedder.
        pose_embedder = FullBodyPoseEmbedding()
        # Initialize classifier.
        # Check that you are using the same parameters as during bootstrapping.
        pose_classifier = PoseClassifier(
            pose_samples_folder=pose_samples_folder,
            pose_embedder=pose_embedder,
            top_n_by_max_distance=30,
            top_n_by_mean_distance=10)
        # Initialize lists of results.
        position_list = []
        frame_list = []
        # Initialize EMA smoothing.
        pose_classification_filter = EMADictSmoothing(
            window_size=10,
            alpha=0.2)
        # Initialize renderer.
        pose_classification_visualizer = PoseClassificationVisualizer(
            class_name=class_name,
            plot_x_max=video_n_frames,
            # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
            plot_y_max=10)
        # Open output video.
        out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'),
                                    video_fps, (video_width, video_height))
        # Initialize frame counter and current position.
        frame_idx = 0
        current_position = {"none": 10.0}
        output_frame = None
        with tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar:
            while True:
                # Append the current position and frame index on every iteration.
                position_list.append(current_position)
                frame_list.append(frame_idx)
                # Stream both values to the results file as we go.
                with open(results_classification_path_out, 'a') as f:
                    f.write(f'{frame_idx};{current_position}\n')
                # Get next frame of the video.
                success, input_frame = video_cap.read()
                if not success:
                    print("Unable to read input video frame, breaking!")
                    break
                # Run pose tracker. The frame stays in RGB from here on and is
                # converted back to BGR when written out.
                input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                result = pose_tracker.process(image=input_frame)
                pose_landmarks = result.pose_landmarks
                # Draw pose prediction.
                output_frame = input_frame.copy()
                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS)
                    # Get landmarks in pixel coordinates.
                    frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                    pose_landmarks = np.array(
                        [[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                         for lmk in pose_landmarks.landmark],
                        dtype=np.float32)
                    assert pose_landmarks.shape == (33, 3), \
                        'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)
                    # Classify the pose on the current frame.
                    pose_classification = pose_classifier(pose_landmarks)
                    # Smooth classification using EMA.
                    pose_classification_filtered = pose_classification_filter(pose_classification)
                    current_position = pose_classification_filtered
                    # Count repetitions (disabled).
                    # repetitions_count = repetition_counter(pose_classification_filtered)
                else:
                    # No pose => no classification on the current frame.
                    pose_classification = None
                    # Still feed an empty classification to the filter to maintain
                    # correct smoothing for future frames.
                    pose_classification_filter(dict())
                    pose_classification_filtered = None
                    current_position = {"none": 10.0}
                    # Don't update the counter, presuming the person is 'frozen'.
                    # Just keep the latest repetitions count.
                    # repetitions_count = repetition_counter.n_repeats
                # Draw the classification plot and repetition counter.
                output_frame = pose_classification_visualizer(
                    frame=output_frame,
                    pose_classification=pose_classification,
                    pose_classification_filtered=pose_classification_filtered,
                    repetitions_count='0')
                # Save the output frame (convert back from RGB to BGR).
                out_video.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))
                # Show intermediate frames of the video to track progress.
                if frame_idx % 50 == 0:
                    show_image(output_frame)
                frame_idx += 1
                pbar.update()
        # Close the output video.
        out_video.release()
        # Release MediaPipe resources.
        pose_tracker.close()
        # Show the last frame of the video.
        if output_frame is not None:
            show_image(output_frame)
        video_cap.release()
    # Last smoothed classification: a dict of scores whose keys come from
    # ['Chair', 'Cobra', 'Dog', 'Goddess', 'Plank', 'Tree', 'Warrior',
    #  'None' (= not fallen), 'Fall'].
    return current_position

if __name__ == "__main__":
    main()