import argparse
import re

import cv2
import numpy as np
import tqdm

import mediapipe as mp
from mediapipe.python.solutions import drawing_utils as mp_drawing

from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
# from PoseClassification.utils import RepetitionCounter


def check_major_current_position(positions_detected: dict, threshold_position: float) -> str:
    '''
    Return the major position among those detected in the frame, or 'none'.
    INPUTS
    positions_detected :
        dict of scores given by the pose classifier after EMA filtering,
        e.g. {'pose1': 8.0, 'pose2': 2.0}
    threshold_position :
        scores strictly below this value are classified as 'none'
    OUTPUT
    major_position :
        string with the position (one of the classifier classes, or 'none')
    '''
    if max(positions_detected.values()) < float(threshold_position):
        major_position = 'none'
    else:
        major_position = max(positions_detected, key=positions_detected.get)
    return major_position
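
# Illustrative use (hypothetical scores): with a threshold of 8.0,
#   check_major_current_position({'tree': 8.5, 'chair': 1.5}, 8.0) -> 'tree'
#   check_major_current_position({'tree': 6.0, 'chair': 4.0}, 8.0) -> 'none'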


def yoga_position_classifier():
    # Load arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="input video path, or 'live' for webcam")
    args = parser.parse_args()
    video_path_in = args.video_path
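    # Example invocation (script name assumed):
    #   python yoga_position_classifier.py data/my_video.mp4
    #   python yoga_position_classifier.py live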
    direct_video = False
    if video_path_in == "live":
        video_path_in = 'data/live.mp4'
        direct_video = True
    # Derive output paths; escape the dot and anchor at the end of the string
    # so only the '.mp4' extension is rewritten.
    video_path_out = re.sub(r'\.mp4$', '_classified_video.mp4', video_path_in)
    results_classification_path_out = re.sub(r'\.mp4$', '_classified_results.csv', video_path_in)
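    # e.g. 'data/my_video.mp4' -> 'data/my_video_classified_video.mp4'
    #      and 'data/my_video_classified_results.csv'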
    # Initialize tracker, classifier and current position.
    # Initialize tracker.
    mp_pose = mp.solutions.pose
    pose_tracker = mp_pose.Pose()
    # Folder with pose class CSVs. That should be the same folder you used
    # while building the classifier to output CSVs.
    pose_samples_folder = 'data/yoga_poses_csvs_out'
    # Initialize embedder.
    pose_embedder = FullBodyPoseEmbedding()
    # Initialize classifier.
    # Check that you are using the same parameters as during bootstrapping.
    pose_classifier = PoseClassifier(
        pose_samples_folder=pose_samples_folder,
        pose_embedder=pose_embedder,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10)
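    # Assumption: PoseClassifier follows the MediaPipe k-NN sample: it first
    # keeps the 30 nearest training samples by maximum per-coordinate distance,
    # then lets the 10 nearest of those by mean distance vote on the class, so
    # raw scores sum to 10 (hence the 8.0-out-of-10 threshold used below).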
    # Initialize EMA smoothing.
    pose_classification_filter = EMADictSmoothing(
        window_size=10,
        alpha=0.2)
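    # Assumption: EMADictSmoothing exponentially weights the last 10 frames
    # (factor alpha=0.2), so a one-frame misclassification barely moves a class
    # score while a pose held across the window comes to dominate it.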
    # Initialize lists of results
    position_list = []
    frame_list = []
    # Open the capture: webcam if direct flux video, file otherwise
    if direct_video:
        video_cap = cv2.VideoCapture(0)
    else:
        assert isinstance(video_path_in, str), "Error in video path format, not a string. Abort."
        # Open video and check that it actually contains frames
        video_cap = cv2.VideoCapture(video_path_in)
        video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
        assert video_n_frames > 0.0, 'Error in input video frames number: no frame. Abort.'
    # Read video parameters from the capture (works for webcam and file alike).
    # Webcams may report 0 FPS, so fall back to a default in that case.
    video_fps = video_cap.get(cv2.CAP_PROP_FPS) or 30.0
    video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Classes known to the classifier, plus 'none'
    class_names = ['chair', 'cobra', 'dog', 'goddess', 'plank', 'tree', 'warrior', 'none']
    position_threshold = 8.0
    # Open output video.
    out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
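    # Note: VideoWriter frames must match this (width, height), which is why
    # the capture's own dimensions are reused for the output video.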
    # Load and rescale the logo once, rather than once per frame.
    banner_height = int(video_height // 10)
    logo = cv2.imread("src/logo_impredalam.jpg")
    if logo is not None:
        logo_height, logo_width = logo.shape[:2]
        logo_width_rescaled = int((logo_width * banner_height) // logo_height)
        logo = cv2.resize(logo, (logo_width_rescaled, banner_height))  # Resize to banner scale
    # Initialize results
    frame_idx = 0
    current_position = {"none": 10.0}
    output_frame = None
    position_timer = 0
    previous_position_major = 'none'
    try:
        with tqdm.tqdm(total=None if direct_video else int(video_n_frames), position=0, leave=True) as pbar:
            while True:
                success, input_frame = video_cap.read()
                if not success:
                    print("Unable to read input video frame, breaking!")
                    break
                # Run pose tracker
                input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                result = pose_tracker.process(image=input_frame_rgb)
                pose_landmarks = result.pose_landmarks
                # Prepare the output frame
                output_frame = input_frame.copy()
                # Add a white banner on top
                output_frame[0:banner_height, :] = (255, 255, 255)  # White color
                # Overlay the logo on the upper right corner (skipped if missing)
                if logo is not None:
                    output_frame[0:logo.shape[0], output_frame.shape[1] - logo.shape[1]:] = logo
                # If landmarks are detected
                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS,)
                    # Get landmarks as a (33, 3) array of pixel-scale coordinates
                    frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                    pose_landmarks = np.array(
                        [
                            [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                            for lmk in pose_landmarks.landmark
                        ],
                        dtype=np.float32,)
                    assert pose_landmarks.shape == (33, 3), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
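                    # MediaPipe gives x/y normalized to the frame and z at
                    # roughly the same scale as x, so z is rescaled by the
                    # frame width, as in the MediaPipe classification sample.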
                    # Classify the pose on the current frame
                    pose_classification = pose_classifier(pose_landmarks)
                    # Smooth classification using EMA
                    pose_classification_filtered = pose_classification_filter(pose_classification)
                    current_position = pose_classification_filtered
                    current_position_major = check_major_current_position(current_position, position_threshold)
                # If no landmarks are detected
                else:
                    current_position = {'none': 10.0}
                    current_position_major = check_major_current_position(current_position, position_threshold)
                # Whether landmarks were detected or not:
                # update the position timer from the current and previous position
                if current_position_major == previous_position_major:
                    # Increase position_timer
                    position_timer += (1 / video_fps)
                else:
                    previous_position_major = current_position_major
                    position_timer = 0
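                # The timer advances by one frame duration per frame: at 30 fps,
                # 30 consecutive frames of the same pose add one second.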
                # Display current position on frame
                cv2.putText(
                    output_frame,
                    f"Pose: {current_position_major}",
                    (video_width // 50, banner_height // 3),  # coord
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.9 * (video_height / video_width),  # Font size
                    (0, 0, 0),  # color
                    1,  # Thin line
                    cv2.LINE_AA,)
                # Display current position timer on frame
                cv2.putText(
                    output_frame,
                    f"Duration: {int(position_timer)} seconds",
                    (video_width // 50, (2 * banner_height) // 3),  # coord
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.9 * (video_height / video_width),  # Font size
                    (0, 0, 0),  # color
                    1,  # Thin line
                    cv2.LINE_AA,)
                # Show output frame and append it to the output video
                cv2.imshow("Yoga position", output_frame)
                out_video.write(output_frame)
                # Add current_position (dict) and frame index to lists
                position_list.append(current_position)
                frame_list.append(frame_idx)
                # Output file for debug
                with open(results_classification_path_out, 'a') as f:
                    f.write(f'{frame_idx},{current_position}\n')
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                elif key == ord("r"):
                    current_position = {'none': 10.0}
                    print("Position reset!")
                frame_idx += 1
                pbar.update()
    finally:
        pose_tracker.close()
        video_cap.release()
        cv2.destroyAllWindows()
        # Close output video.
        out_video.release()
    return frame_list, position_list


if __name__ == "__main__":
    yoga_position_classifier()