1mpreccable committed (verified)
Commit 0ccc9b6 · Parent(s): 3c0113d

Upload 35 files
PoseClassification/__pycache__/bootstrap.cpython-312.pyc ADDED
Binary file (12.7 kB).
 
PoseClassification/__pycache__/pose_classifier.cpython-312.pyc ADDED
Binary file (8.15 kB).
 
PoseClassification/__pycache__/pose_embedding.cpython-312.pyc ADDED
Binary file (9.31 kB).
 
PoseClassification/__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.93 kB).
 
PoseClassification/__pycache__/visualize.cpython-312.pyc ADDED
Binary file (6.14 kB).
 
PoseClassification/bootstrap.py ADDED
@@ -0,0 +1,242 @@
+ import cv2
+ from matplotlib import pyplot as plt
+ import numpy as np
+ import os, csv
+ from PIL import Image, ImageDraw
+ import sys
+ import tqdm
+
+ from mediapipe.python.solutions import drawing_utils as mp_drawing
+ from mediapipe.python.solutions import pose as mp_pose
+
+ from PoseClassification.utils import show_image
+
+ class BootstrapHelper(object):
+     """Helps to bootstrap images and filter pose samples for classification."""
+
+     def __init__(self, images_in_folder, images_out_folder, csvs_out_folder):
+         self._images_in_folder = images_in_folder
+         self._images_out_folder = images_out_folder
+         self._csvs_out_folder = csvs_out_folder
+
+         # Get list of pose classes and print image statistics.
+         self._pose_class_names = sorted(
+             [n for n in os.listdir(self._images_in_folder) if not n.startswith(".")]
+         )
+
+     def bootstrap(self, per_pose_class_limit=None):
+         """Bootstraps images in a given folder.
+
+         Required layout of the input image folder (the output folder follows the same layout):
+             pushups_up/
+                 image_001.jpg
+                 image_002.jpg
+                 ...
+             pushups_down/
+                 image_001.jpg
+                 image_002.jpg
+                 ...
+             ...
+
+         Produced CSVs out folder:
+             pushups_up.csv
+             pushups_down.csv
+
+         Produced CSV structure with pose 3D landmarks:
+             sample_00001,x1,y1,z1,x2,y2,z2,....
+             sample_00002,x1,y1,z1,x2,y2,z2,....
+         """
+         # Create output folder for CSVs.
+         if not os.path.exists(self._csvs_out_folder):
+             os.makedirs(self._csvs_out_folder)
+
+         for pose_class_name in self._pose_class_names:
+             print("Bootstrapping ", pose_class_name, file=sys.stderr)
+
+             # Paths for the pose class.
+             images_in_folder = os.path.join(self._images_in_folder, pose_class_name)
+             images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
+             csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + ".csv")
+             if not os.path.exists(images_out_folder):
+                 os.makedirs(images_out_folder)
+
+             with open(csv_out_path, "w") as csv_out_file:
+                 csv_out_writer = csv.writer(
+                     csv_out_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
+                 )
+                 # Get list of images.
+                 image_names = sorted(
+                     [n for n in os.listdir(images_in_folder) if not n.startswith(".")]
+                 )
+                 if per_pose_class_limit is not None:
+                     image_names = image_names[:per_pose_class_limit]
+
+                 # Bootstrap every image.
+                 for image_name in tqdm.tqdm(image_names):
+                     # Load image.
+                     input_frame = cv2.imread(os.path.join(images_in_folder, image_name))
+                     input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
+
+                     # Initialize fresh pose tracker and run it.
+                     # with mp_pose.Pose(upper_body_only=False) as pose_tracker:
+                     with mp_pose.Pose() as pose_tracker:
+                         result = pose_tracker.process(image=input_frame)
+                         pose_landmarks = result.pose_landmarks
+
+                     # Save image with pose prediction (if pose was detected).
+                     output_frame = input_frame.copy()
+                     if pose_landmarks is not None:
+                         mp_drawing.draw_landmarks(
+                             image=output_frame,
+                             landmark_list=pose_landmarks,
+                             connections=mp_pose.POSE_CONNECTIONS,
+                         )
+                     output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
+                     cv2.imwrite(
+                         os.path.join(images_out_folder, image_name), output_frame
+                     )
+
+                     # Save landmarks if pose was detected.
+                     if pose_landmarks is not None:
+                         # Get landmarks.
+                         frame_height, frame_width = (
+                             output_frame.shape[0],
+                             output_frame.shape[1],
+                         )
+                         pose_landmarks = np.array(
+                             [
+                                 [
+                                     lmk.x * frame_width,
+                                     lmk.y * frame_height,
+                                     lmk.z * frame_width,
+                                 ]
+                                 for lmk in pose_landmarks.landmark
+                             ],
+                             dtype=np.float32,
+                         )
+                         assert pose_landmarks.shape == (
+                             33,
+                             3,
+                         ), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
+                         csv_out_writer.writerow(
+                             [image_name] + pose_landmarks.flatten().astype(str).tolist()
+                         )
+
+                     # Draw XZ projection and concatenate with the image.
+                     projection_xz = self._draw_xz_projection(
+                         output_frame=output_frame, pose_landmarks=pose_landmarks
+                     )
+                     output_frame = np.concatenate((output_frame, projection_xz), axis=1)
+
+     def _draw_xz_projection(self, output_frame, pose_landmarks, r=0.5, color="red"):
+         frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
+         img = Image.new("RGB", (frame_width, frame_height), color="white")
+
+         if pose_landmarks is None:
+             return np.asarray(img)
+
+         # Scale radius according to the image width.
+         r *= frame_width * 0.01
+
+         draw = ImageDraw.Draw(img)
+         for idx_1, idx_2 in mp_pose.POSE_CONNECTIONS:
+             # Flip Z and move hips center to the center of the image.
+             x1, y1, z1 = pose_landmarks[idx_1] * [1, 1, -1] + [0, 0, frame_height * 0.5]
+             x2, y2, z2 = pose_landmarks[idx_2] * [1, 1, -1] + [0, 0, frame_height * 0.5]
+
+             draw.ellipse([x1 - r, z1 - r, x1 + r, z1 + r], fill=color)
+             draw.ellipse([x2 - r, z2 - r, x2 + r, z2 + r], fill=color)
+             draw.line([x1, z1, x2, z2], width=int(r), fill=color)
+
+         return np.asarray(img)
+
+     def align_images_and_csvs(self, print_removed_items=False):
+         """Makes sure that image folders and CSVs have the same samples.
+
+         Leaves only the intersection of samples in both image folders and CSVs.
+         """
+         for pose_class_name in self._pose_class_names:
+             # Paths for the pose class.
+             images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
+             csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + ".csv")
+
+             # Read CSV into memory.
+             rows = []
+             with open(csv_out_path) as csv_out_file:
+                 csv_out_reader = csv.reader(csv_out_file, delimiter=",")
+                 for row in csv_out_reader:
+                     rows.append(row)
+
+             # Image names left in CSV.
+             image_names_in_csv = []
+
+             # Re-write the CSV removing lines without corresponding images.
+             with open(csv_out_path, "w") as csv_out_file:
+                 csv_out_writer = csv.writer(
+                     csv_out_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
+                 )
+                 for row in rows:
+                     image_name = row[0]
+                     image_path = os.path.join(images_out_folder, image_name)
+                     if os.path.exists(image_path):
+                         image_names_in_csv.append(image_name)
+                         csv_out_writer.writerow(row)
+                     elif print_removed_items:
+                         print("Removed image from CSV: ", image_path)
+
+             # Remove images without corresponding line in CSV.
+             for image_name in os.listdir(images_out_folder):
+                 if image_name not in image_names_in_csv:
+                     image_path = os.path.join(images_out_folder, image_name)
+                     os.remove(image_path)
+                     if print_removed_items:
+                         print("Removed image from folder: ", image_path)
+
+     def analyze_outliers(self, outliers):
+         """Classifies each sample against all others to find outliers.
+
+         If a sample is classified differently than its original class, it should
+         either be deleted or more similar samples should be added.
+         """
+         for outlier in outliers:
+             image_path = os.path.join(
+                 self._images_out_folder, outlier.sample.class_name, outlier.sample.name
+             )
+
+             print("Outlier")
+             print(" sample path = ", image_path)
+             print(" sample class = ", outlier.sample.class_name)
+             print(" detected class = ", outlier.detected_class)
+             print(" all classes = ", outlier.all_classes)
+
+             img = cv2.imread(image_path)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             show_image(img, figsize=(20, 20))
+
+     def remove_outliers(self, outliers):
+         """Removes outliers from the image folders."""
+         for outlier in outliers:
+             image_path = os.path.join(
+                 self._images_out_folder, outlier.sample.class_name, outlier.sample.name
+             )
+             os.remove(image_path)
+
+     def print_images_in_statistics(self):
+         """Prints statistics from the input image folder."""
+         self._print_images_statistics(self._images_in_folder, self._pose_class_names)
+
+     def print_images_out_statistics(self):
+         """Prints statistics from the output image folder."""
+         self._print_images_statistics(self._images_out_folder, self._pose_class_names)
+
+     def _print_images_statistics(self, images_folder, pose_class_names):
+         print("Number of images per pose class:")
+         for pose_class_name in pose_class_names:
+             n_images = len(
+                 [
+                     n
+                     for n in os.listdir(os.path.join(images_folder, pose_class_name))
+                     if not n.startswith(".")
+                 ]
+             )
+             print(" {}: {}".format(pose_class_name, n_images))
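For context, a minimal usage sketch of this helper; the folder names below are placeholders, not paths taken from this repository:

```python
# Hypothetical usage sketch for BootstrapHelper (paths are illustrative only).
from PoseClassification.bootstrap import BootstrapHelper

bootstrap_helper = BootstrapHelper(
    images_in_folder="yoga_poses",          # one sub-folder per pose class
    images_out_folder="yoga_poses_out",     # annotated copies of the images
    csvs_out_folder="yoga_poses_csvs_out",  # one CSV of 33x3 landmarks per class
)

# Check class statistics, extract landmarks, then keep only samples that exist
# both as images and as CSV rows.
bootstrap_helper.print_images_in_statistics()
bootstrap_helper.bootstrap(per_pose_class_limit=None)
bootstrap_helper.align_images_and_csvs(print_removed_items=False)
bootstrap_helper.print_images_out_statistics()
```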
PoseClassification/pose_classifier.py ADDED
@@ -0,0 +1,208 @@
+ import numpy as np
+ import os, csv
+
+
+ class PoseSample(object):
+     def __init__(self, name, landmarks, class_name, embedding):
+         self.name = name
+         self.landmarks = landmarks
+         self.class_name = class_name
+         self.embedding = embedding
+
+
+ class PoseSampleOutlier(object):
+     def __init__(self, sample, detected_class, all_classes):
+         self.sample = sample
+         self.detected_class = detected_class
+         self.all_classes = all_classes
+
+
+ class PoseClassifier(object):
+     """Classifies pose landmarks."""
+
+     def __init__(
+         self,
+         pose_samples_folder,
+         pose_embedder,
+         file_extension="csv",
+         file_separator=",",
+         n_landmarks=33,
+         n_dimensions=3,
+         top_n_by_max_distance=30,
+         top_n_by_mean_distance=10,
+         axes_weights=(1.0, 1.0, 0.2),
+     ):
+         self._pose_embedder = pose_embedder
+         self._n_landmarks = n_landmarks
+         self._n_dimensions = n_dimensions
+         self._top_n_by_max_distance = top_n_by_max_distance
+         self._top_n_by_mean_distance = top_n_by_mean_distance
+         self._axes_weights = axes_weights
+
+         self._pose_samples = self._load_pose_samples(
+             pose_samples_folder,
+             file_extension,
+             file_separator,
+             n_landmarks,
+             n_dimensions,
+             pose_embedder,
+         )
+
+     def _load_pose_samples(
+         self,
+         pose_samples_folder,
+         file_extension,
+         file_separator,
+         n_landmarks,
+         n_dimensions,
+         pose_embedder,
+     ):
+         """Loads pose samples from a given folder.
+
+         Required folder structure:
+             neutral_standing.csv
+             pushups_down.csv
+             pushups_up.csv
+             squats_down.csv
+             ...
+
+         Required CSV structure:
+             sample_00001,x1,y1,z1,x2,y2,z2,....
+             sample_00002,x1,y1,z1,x2,y2,z2,....
+             ...
+         """
+         # Each file in the folder represents one pose class.
+         file_names = [
+             name
+             for name in os.listdir(pose_samples_folder)
+             if name.endswith(file_extension)
+         ]
+
+         pose_samples = []
+         for file_name in file_names:
+             # Use file name as pose class name.
+             class_name = file_name[: -(len(file_extension) + 1)]
+
+             # Parse CSV.
+             with open(os.path.join(pose_samples_folder, file_name)) as csv_file:
+                 csv_reader = csv.reader(csv_file, delimiter=file_separator)
+                 for row in csv_reader:
+                     assert (
+                         len(row) == n_landmarks * n_dimensions + 1
+                     ), "Wrong number of values: {}".format(len(row))
+                     landmarks = np.array(row[1:], np.float32).reshape(
+                         [n_landmarks, n_dimensions]
+                     )
+                     pose_samples.append(
+                         PoseSample(
+                             name=row[0],
+                             landmarks=landmarks,
+                             class_name=class_name,
+                             embedding=pose_embedder(landmarks),
+                         )
+                     )
+
+         return pose_samples
+
+     def find_pose_sample_outliers(self):
+         """Classifies each sample against the entire database."""
+         # Find outliers in target poses
+         outliers = []
+         for sample in self._pose_samples:
+             # Find nearest poses for the target one.
+             pose_landmarks = sample.landmarks.copy()
+             pose_classification = self.__call__(pose_landmarks)
+             class_names = [
+                 class_name
+                 for class_name, count in pose_classification.items()
+                 if count == max(pose_classification.values())
+             ]
+
+             # Sample is an outlier if the nearest poses have a different class or if
+             # more than one pose class is detected as nearest.
+             if sample.class_name not in class_names or len(class_names) != 1:
+                 outliers.append(
+                     PoseSampleOutlier(sample, class_names, pose_classification)
+                 )
+
+         return outliers
+
+     def __call__(self, pose_landmarks):
+         """Classifies the given pose.
+
+         Classification is done in two stages:
+           * First we pick the top-N samples by MAX distance. This removes samples
+             that are almost the same as the given pose but have a few joints bent in
+             the other direction.
+           * Then we pick the top-N samples by MEAN distance. After outliers are removed
+             in the previous step, we can pick the samples that are closest on average.
+
+         Args:
+           pose_landmarks: NumPy array with 3D landmarks of shape (N, 3).
+
+         Returns:
+           Dictionary with count of nearest pose samples from the database. Sample:
+             {
+               'pushups_down': 8,
+               'pushups_up': 2,
+             }
+         """
+         # Check that provided and target poses have the same shape.
+         assert pose_landmarks.shape == (
+             self._n_landmarks,
+             self._n_dimensions,
+         ), "Unexpected shape: {}".format(pose_landmarks.shape)
+
+         # Get given pose embedding.
+         pose_embedding = self._pose_embedder(pose_landmarks)
+         flipped_pose_embedding = self._pose_embedder(
+             pose_landmarks * np.array([-1, 1, 1])
+         )
+
+         # Filter by max distance.
+         #
+         # This helps to remove outliers - poses that are almost the same as the
+         # given one but have one joint bent in another direction and actually
+         # represent a different pose class.
+         max_dist_heap = []
+         for sample_idx, sample in enumerate(self._pose_samples):
+             max_dist = min(
+                 np.max(np.abs(sample.embedding - pose_embedding) * self._axes_weights),
+                 np.max(
+                     np.abs(sample.embedding - flipped_pose_embedding)
+                     * self._axes_weights
+                 ),
+             )
+             max_dist_heap.append([max_dist, sample_idx])
+
+         max_dist_heap = sorted(max_dist_heap, key=lambda x: x[0])
+         max_dist_heap = max_dist_heap[: self._top_n_by_max_distance]
+
+         # Filter by mean distance.
+         #
+         # After removing outliers we can find the nearest pose by mean distance.
+         mean_dist_heap = []
+         for _, sample_idx in max_dist_heap:
+             sample = self._pose_samples[sample_idx]
+             mean_dist = min(
+                 np.mean(np.abs(sample.embedding - pose_embedding) * self._axes_weights),
+                 np.mean(
+                     np.abs(sample.embedding - flipped_pose_embedding)
+                     * self._axes_weights
+                 ),
+             )
+             mean_dist_heap.append([mean_dist, sample_idx])
+
+         mean_dist_heap = sorted(mean_dist_heap, key=lambda x: x[0])
+         mean_dist_heap = mean_dist_heap[: self._top_n_by_mean_distance]
+
+         # Collect results into map: (class_name -> n_samples)
+         class_names = [
+             self._pose_samples[sample_idx].class_name
+             for _, sample_idx in mean_dist_heap
+         ]
+         result = {
+             class_name: class_names.count(class_name) for class_name in set(class_names)
+         }
+
+         return result
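A minimal sketch of how this k-NN style classifier is meant to be combined with the embedder from `pose_embedding.py`; the CSV folder name and the landmark values are placeholders:

```python
# Hypothetical usage sketch: classify one frame's (33, 3) landmark array.
import numpy as np
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier

pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_samples_folder="yoga_poses_csvs_out",  # CSVs produced by BootstrapHelper
    pose_embedder=pose_embedder,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10,
)

# Landmarks must be in image coordinates, same convention as the CSV rows.
pose_landmarks = np.float32(np.random.uniform(0, 640, size=(33, 3)))
classification = pose_classifier(pose_landmarks)
print(classification)  # e.g. {'tree': 7, 'warrior_2': 3}
```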
PoseClassification/pose_embedding.py ADDED
@@ -0,0 +1,237 @@
+ import numpy as np
+ import math
+
+ class FullBodyPoseEmbedding(object):
+     """Converts 3D pose landmarks into 3D embedding."""
+
+     def __init__(self, torso_size_multiplier=2.5):
+         # Multiplier to apply to the torso to get minimal body size.
+         self._torso_size_multiplier = torso_size_multiplier
+
+         # Names of the landmarks as they appear in the prediction.
+         self._landmark_names = [
+             "nose",
+             "left_eye_inner",
+             "left_eye",
+             "left_eye_outer",
+             "right_eye_inner",
+             "right_eye",
+             "right_eye_outer",
+             "left_ear",
+             "right_ear",
+             "mouth_left",
+             "mouth_right",
+             "left_shoulder",
+             "right_shoulder",
+             "left_elbow",
+             "right_elbow",
+             "left_wrist",
+             "right_wrist",
+             "left_pinky_1",
+             "right_pinky_1",
+             "left_index_1",
+             "right_index_1",
+             "left_thumb_2",
+             "right_thumb_2",
+             "left_hip",
+             "right_hip",
+             "left_knee",
+             "right_knee",
+             "left_ankle",
+             "right_ankle",
+             "left_heel",
+             "right_heel",
+             "left_foot_index",
+             "right_foot_index",
+         ]
+
+     def __call__(self, landmarks):
+         """Normalizes pose landmarks and converts them to an embedding.
+
+         Args:
+           landmarks - NumPy array with 3D landmarks of shape (N, 3).
+
+         Result:
+           NumPy array with pose embedding of shape (M, 3) where `M` is the number of
+           pairwise distances defined in `_get_pose_distance_embedding`.
+         """
+         assert landmarks.shape[0] == len(
+             self._landmark_names
+         ), "Unexpected number of landmarks: {}".format(landmarks.shape[0])
+
+         # Get pose landmarks.
+         landmarks = np.copy(landmarks)
+
+         # Normalize landmarks.
+         landmarks = self._normalize_pose_landmarks(landmarks)
+
+         # Get embedding.
+         embedding = self._get_pose_distance_embedding(landmarks)
+
+         return embedding
+
+     def _normalize_pose_landmarks(self, landmarks):
+         """Normalizes landmarks translation and scale."""
+         landmarks = np.copy(landmarks)
+
+         # Normalize translation.
+         pose_center = self._get_pose_center(landmarks)
+         landmarks -= pose_center
+
+         # Normalize scale.
+         pose_size = self._get_pose_size(landmarks, self._torso_size_multiplier)
+         landmarks /= pose_size
+         # Multiplication by 100 is not required, but makes it easier to debug.
+         landmarks *= 100
+
+         return landmarks
+
+     def _get_pose_center(self, landmarks):
+         """Calculates pose center as point between hips."""
+         left_hip = landmarks[self._landmark_names.index("left_hip")]
+         right_hip = landmarks[self._landmark_names.index("right_hip")]
+         center = (left_hip + right_hip) * 0.5
+         return center
+
+     def _get_pose_size(self, landmarks, torso_size_multiplier):
+         """Calculates pose size.
+
+         It is the maximum of two values:
+           * Torso size multiplied by `torso_size_multiplier`
+           * Maximum distance from pose center to any pose landmark
+         """
+         # This approach uses only 2D landmarks to compute pose size.
+         landmarks = landmarks[:, :2]
+
+         # Hips center.
+         left_hip = landmarks[self._landmark_names.index("left_hip")]
+         right_hip = landmarks[self._landmark_names.index("right_hip")]
+         hips = (left_hip + right_hip) * 0.5
+
+         # Shoulders center.
+         left_shoulder = landmarks[self._landmark_names.index("left_shoulder")]
+         right_shoulder = landmarks[self._landmark_names.index("right_shoulder")]
+         shoulders = (left_shoulder + right_shoulder) * 0.5
+
+         # Torso size as the minimum body size.
+         torso_size = np.linalg.norm(shoulders - hips)
+
+         # Max dist to pose center.
+         pose_center = self._get_pose_center(landmarks)
+         max_dist = np.max(np.linalg.norm(landmarks - pose_center, axis=1))
+
+         return max(torso_size * torso_size_multiplier, max_dist)
+
+     def _get_pose_distance_embedding(self, landmarks):
+         """Converts pose landmarks into 3D embedding.
+
+         We use several pairwise 3D distances to form the pose embedding. All distances
+         include X and Y components with sign. We use different types of pairs to cover
+         different pose classes. Feel free to remove some or add new ones.
+
+         Args:
+           landmarks - NumPy array with 3D landmarks of shape (N, 3).
+
+         Result:
+           NumPy array with pose embedding of shape (M, 3) where `M` is the number of
+           pairwise distances.
+         """
+         embedding = np.array(
+             [
+                 # One joint.
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, "left_hip", "right_hip"),
+                     self._get_average_by_names(
+                         landmarks, "left_shoulder", "right_shoulder"
+                     ),
+                 ),
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_elbow"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_elbow"),
+                 self._get_distance_by_names(landmarks, "left_elbow", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_elbow", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_knee"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_knee"),
+                 self._get_distance_by_names(landmarks, "left_knee", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_knee", "right_ankle"),
+                 # Two joints.
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_ankle"),
+                 # Four joints.
+                 self._get_distance_by_names(landmarks, "left_hip", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_wrist"),
+                 # Five joints.
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_ankle"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_wrist"),
+                 # Cross body.
+                 self._get_distance_by_names(landmarks, "left_elbow", "right_elbow"),
+                 self._get_distance_by_names(landmarks, "left_knee", "right_knee"),
+                 self._get_distance_by_names(landmarks, "left_wrist", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_ankle", "right_ankle"),
+                 # Body bent direction.
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, 'left_wrist', 'left_ankle'),
+                     landmarks[self._landmark_names.index('left_hip')]),
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, 'right_wrist', 'right_ankle'),
+                     landmarks[self._landmark_names.index('right_hip')]),
+                 # Angle between landmarks - cf https://www.kaggle.com/code/venkatkumar001/yoga-pose-recognition-mediapipe
+                 # self._calculateAngle(landmarks, "left_hip", "left_knee", "left_ankle"),
+                 # self._calculateAngle(landmarks, "right_hip", "right_knee", "right_ankle"),
+                 # self._calculateAngle(landmarks, "left_shoulder", "left_elbow", "left_wrist"),
+                 # self._calculateAngle(landmarks, "right_shoulder", "right_elbow", "right_wrist")
+             ]
+         )
+         # print(embedding)
+         # print(embedding.shape)
+         # print(type(embedding))
+         # print(type(landmarks[self._landmark_names.index('right_hip')]))
+         # print(landmarks[self._landmark_names.index('right_hip')])
+         return embedding
+
+     def _get_average_by_names(self, landmarks, name_from, name_to):
+         lmk_from = landmarks[self._landmark_names.index(name_from)]
+         lmk_to = landmarks[self._landmark_names.index(name_to)]
+         return (lmk_from + lmk_to) * 0.5
+
+     def _get_distance_by_names(self, landmarks, name_from, name_to):
+         lmk_from = landmarks[self._landmark_names.index(name_from)]
+         lmk_to = landmarks[self._landmark_names.index(name_to)]
+         return self._get_distance(lmk_from, lmk_to)
+
+     def _get_distance(self, lmk_from, lmk_to):
+         return lmk_to - lmk_from
+
+     def _calculateAngle(self, landmarks, name1, name2, name3):
+         '''
+         Calculates the angle between three different landmarks.
+         Args:
+             name1: Name of the first landmark (its x, y and z coordinates are looked up).
+             name2: Name of the second (middle) landmark.
+             name3: Name of the third landmark.
+         Returns:
+             angle: The calculated angle between the three landmarks.
+
+         cf https://www.kaggle.com/code/venkatkumar001/yoga-pose-recognition-mediapipe
+         '''
+
+         # Get the required landmarks coordinates.
+         x1, y1, _ = landmarks[self._landmark_names.index(name1)]
+         x2, y2, _ = landmarks[self._landmark_names.index(name2)]
+         x3, y3, _ = landmarks[self._landmark_names.index(name3)]
+
+         # Calculate the angle between the three points.
+         angle = math.degrees(math.atan2(y3 - y2, x3 - x2) - math.atan2(y1 - y2, x1 - x2))
+
+         # Check if the angle is less than zero.
+         if angle < 0:
+             # Add 360 to the found angle.
+             angle += 360
+
+         # Return the calculated angle.
+         return angle
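As a quick sanity check of the `_calculateAngle` math above, here is the same formula restated on plain 2D points (a standalone sketch, not code from this file):

```python
import math

def angle_deg(p1, p2, p3):
    # Same formula as FullBodyPoseEmbedding._calculateAngle, with p2 as the vertex.
    x1, y1 = p1
    x2, y2 = p2
    x3, y3 = p3
    angle = math.degrees(math.atan2(y3 - y2, x3 - x2) - math.atan2(y1 - y2, x1 - x2))
    return angle + 360 if angle < 0 else angle

# Right angle at the middle point: one segment along +x, the other along +y.
print(angle_deg((1.0, 0.0), (0.0, 0.0), (0.0, 1.0)))  # 90.0
# Fully extended joint: the three points are collinear.
print(angle_deg((0.0, 2.0), (0.0, 1.0), (0.0, 0.0)))  # 180.0
```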
PoseClassification/pose_embedding_2.py ADDED
@@ -0,0 +1,246 @@
+ import numpy as np
+ import math
+
+ class FullBodyPoseEmbedding(object):
+     """Converts 3D pose landmarks into 3D embedding."""
+
+     def __init__(self, torso_size_multiplier=2.5):
+         # Multiplier to apply to the torso to get minimal body size.
+         self._torso_size_multiplier = torso_size_multiplier
+
+         # Names of the landmarks as they appear in the prediction.
+         self._landmark_names = [
+             "nose",
+             "left_eye_inner",
+             "left_eye",
+             "left_eye_outer",
+             "right_eye_inner",
+             "right_eye",
+             "right_eye_outer",
+             "left_ear",
+             "right_ear",
+             "mouth_left",
+             "mouth_right",
+             "left_shoulder",
+             "right_shoulder",
+             "left_elbow",
+             "right_elbow",
+             "left_wrist",
+             "right_wrist",
+             "left_pinky_1",
+             "right_pinky_1",
+             "left_index_1",
+             "right_index_1",
+             "left_thumb_2",
+             "right_thumb_2",
+             "left_hip",
+             "right_hip",
+             "left_knee",
+             "right_knee",
+             "left_ankle",
+             "right_ankle",
+             "left_heel",
+             "right_heel",
+             "left_foot_index",
+             "right_foot_index",
+         ]
+
+     def __call__(self, landmarks):
+         """Normalizes pose landmarks and converts them to an embedding.
+
+         Args:
+           landmarks - NumPy array with 3D landmarks of shape (N, 3).
+
+         Result:
+           NumPy array with pose embedding of shape (M, 3) where `M` is the number of
+           pairwise distances defined in `_get_pose_distance_embedding`.
+         """
+         assert landmarks.shape[0] == len(
+             self._landmark_names
+         ), "Unexpected number of landmarks: {}".format(landmarks.shape[0])
+
+         # Get pose landmarks.
+         landmarks = np.copy(landmarks)
+
+         # Normalize landmarks.
+         landmarks = self._normalize_pose_landmarks(landmarks)
+
+         # Get embedding.
+         embedding = self._get_pose_distance_embedding(landmarks)
+
+         # Add angle embedding
+         embedding_angle = self._get_pose_angle_embedding(landmarks)
+
+         assert embedding.shape == embedding_angle.shape, f"Error in embeddings shape : distance embed {embedding.shape} and angle {embedding_angle.shape}"
+
+         return embedding
+
+     def _normalize_pose_landmarks(self, landmarks):
+         """Normalizes landmarks translation and scale."""
+         landmarks = np.copy(landmarks)
+
+         # Normalize translation.
+         pose_center = self._get_pose_center(landmarks)
+         landmarks -= pose_center
+
+         # Normalize scale.
+         pose_size = self._get_pose_size(landmarks, self._torso_size_multiplier)
+         landmarks /= pose_size
+         # Multiplication by 100 is not required, but makes it easier to debug.
+         landmarks *= 100
+
+         return landmarks
+
+     def _get_pose_center(self, landmarks):
+         """Calculates pose center as point between hips."""
+         left_hip = landmarks[self._landmark_names.index("left_hip")]
+         right_hip = landmarks[self._landmark_names.index("right_hip")]
+         center = (left_hip + right_hip) * 0.5
+         return center
+
+     def _get_pose_size(self, landmarks, torso_size_multiplier):
+         """Calculates pose size.
+
+         It is the maximum of two values:
+           * Torso size multiplied by `torso_size_multiplier`
+           * Maximum distance from pose center to any pose landmark
+         """
+         # This approach uses only 2D landmarks to compute pose size.
+         landmarks = landmarks[:, :2]
+
+         # Hips center.
+         left_hip = landmarks[self._landmark_names.index("left_hip")]
+         right_hip = landmarks[self._landmark_names.index("right_hip")]
+         hips = (left_hip + right_hip) * 0.5
+
+         # Shoulders center.
+         left_shoulder = landmarks[self._landmark_names.index("left_shoulder")]
+         right_shoulder = landmarks[self._landmark_names.index("right_shoulder")]
+         shoulders = (left_shoulder + right_shoulder) * 0.5
+
+         # Torso size as the minimum body size.
+         torso_size = np.linalg.norm(shoulders - hips)
+
+         # Max dist to pose center.
+         pose_center = self._get_pose_center(landmarks)
+         max_dist = np.max(np.linalg.norm(landmarks - pose_center, axis=1))
+
+         return max(torso_size * torso_size_multiplier, max_dist)
+
+     def _get_pose_distance_embedding(self, landmarks):
+         """Converts pose landmarks into 3D embedding.
+
+         We use several pairwise 3D distances to form the pose embedding. All distances
+         include X and Y components with sign. We use different types of pairs to cover
+         different pose classes. Feel free to remove some or add new ones.
+
+         Args:
+           landmarks - NumPy array with 3D landmarks of shape (N, 3).
+
+         Result:
+           NumPy array with pose embedding of shape (M, 3) where `M` is the number of
+           pairwise distances.
+         """
+         embedding = np.array(
+             [
+                 # One joint.
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, "left_hip", "right_hip"),
+                     self._get_average_by_names(
+                         landmarks, "left_shoulder", "right_shoulder"
+                     ),
+                 ),
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_elbow"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_elbow"),
+                 self._get_distance_by_names(landmarks, "left_elbow", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_elbow", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_knee"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_knee"),
+                 self._get_distance_by_names(landmarks, "left_knee", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_knee", "right_ankle"),
+                 # Two joints.
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_ankle"),
+                 # Four joints.
+                 self._get_distance_by_names(landmarks, "left_hip", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_wrist"),
+                 # Five joints.
+                 self._get_distance_by_names(landmarks, "left_shoulder", "left_ankle"),
+                 self._get_distance_by_names(landmarks, "right_shoulder", "right_ankle"),
+                 self._get_distance_by_names(landmarks, "left_hip", "left_wrist"),
+                 self._get_distance_by_names(landmarks, "right_hip", "right_wrist"),
+                 # Cross body.
+                 self._get_distance_by_names(landmarks, "left_elbow", "right_elbow"),
+                 self._get_distance_by_names(landmarks, "left_knee", "right_knee"),
+                 self._get_distance_by_names(landmarks, "left_wrist", "right_wrist"),
+                 self._get_distance_by_names(landmarks, "left_ankle", "right_ankle"),
+                 # Body bent direction.
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, 'left_wrist', 'left_ankle'),
+                     landmarks[self._landmark_names.index('left_hip')]),
+                 self._get_distance(
+                     self._get_average_by_names(landmarks, 'right_wrist', 'right_ankle'),
+                     landmarks[self._landmark_names.index('right_hip')])
+             ]
+         )
+         # print(embedding)
+         # print(embedding.shape)
+         # print(type(embedding))
+         # print(type(landmarks[self._landmark_names.index('right_hip')]))
+         # print(landmarks[self._landmark_names.index('right_hip')])
+         return embedding
+
+     def _get_average_by_names(self, landmarks, name_from, name_to):
+         lmk_from = landmarks[self._landmark_names.index(name_from)]
+         lmk_to = landmarks[self._landmark_names.index(name_to)]
+         return (lmk_from + lmk_to) * 0.5
+
+     def _get_distance_by_names(self, landmarks, name_from, name_to):
+         lmk_from = landmarks[self._landmark_names.index(name_from)]
+         lmk_to = landmarks[self._landmark_names.index(name_to)]
+         return self._get_distance(lmk_from, lmk_to)
+
+     def _get_distance(self, lmk_from, lmk_to):
+         return lmk_to - lmk_from
+
+     def _get_pose_angle_embedding(self, landmarks):
+         embedding = [
+             # Angle between landmarks - cf https://www.kaggle.com/code/venkatkumar001/yoga-pose-recognition-mediapipe
+             self._calculateAngle(landmarks, "left_hip", "left_knee", "left_ankle"),
+             self._calculateAngle(landmarks, "right_hip", "right_knee", "right_ankle"),
+             self._calculateAngle(landmarks, "left_shoulder", "left_elbow", "left_wrist"),
+             self._calculateAngle(landmarks, "right_shoulder", "right_elbow", "right_wrist")
+         ]
+         return embedding
+
+     def _calculateAngle(self, landmarks, name1, name2, name3):
+         '''
+         Calculates the angle between three different landmarks.
+         Args:
+             name1: Name of the first landmark (its x, y and z coordinates are looked up).
+             name2: Name of the second (middle) landmark.
+             name3: Name of the third landmark.
+         Returns:
+             angle: The calculated angle between the three landmarks.
+
+         cf https://www.kaggle.com/code/venkatkumar001/yoga-pose-recognition-mediapipe
+         '''
+         # Get the required landmarks coordinates.
+         x1, y1, _ = landmarks[self._landmark_names.index(name1)]
+         x2, y2, _ = landmarks[self._landmark_names.index(name2)]
+         x3, y3, _ = landmarks[self._landmark_names.index(name3)]
+
+         # Calculate the angle between the three points.
+         angle = math.degrees(math.atan2(y3 - y2, x3 - x2) - math.atan2(y1 - y2, x1 - x2))
+
+         # Check if the angle is less than zero.
+         if angle < 0:
+             # Add 360 to the found angle.
+             angle += 360
+
+         # Return the calculated angle.
+         return angle
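Note that in this variant `_get_pose_angle_embedding` returns a plain 4-element list while `_get_pose_distance_embedding` returns a (25, 3) array, so the shape assertion in `__call__` cannot hold as written, and the angle features are currently not part of the returned embedding. If the intent is to fold the angles in, one possible (assumed, not from this repository) approach is to rescale them and append them as extra rows so the classifier still sees a 2D array:

```python
import numpy as np

def combine_distance_and_angle_features(distance_embedding, angles, scale=100.0 / 360.0):
    # distance_embedding: (M, 3) array of signed pairwise distances (values roughly in [-100, 100]).
    # angles: list of 4 joint angles in degrees from _get_pose_angle_embedding.
    # Rescale the angles into a comparable numeric range and append them as rows,
    # padding the unused y/z components with zeros.
    angle_rows = np.array([[a * scale, 0.0, 0.0] for a in angles], dtype=np.float32)
    return np.concatenate([distance_embedding.astype(np.float32), angle_rows], axis=0)
```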
PoseClassification/utils.py ADDED
@@ -0,0 +1,129 @@
+ from matplotlib import pyplot as plt
+ import numpy as np
+
+
+ def show_image(img, figsize=(10, 10)):
+     """Shows output PIL image."""
+     plt.figure(figsize=figsize)
+     plt.imshow(img)
+     plt.show()
+
+
+ class EMADictSmoothing(object):
+     """Smoothes pose classification. Exponential moving average (EMA)."""
+
+     def __init__(self, window_size=10, alpha=0.2):
+         self._window_size = window_size
+         self._alpha = alpha
+
+         self._data_in_window = []
+
+     def __call__(self, data):
+         """Smoothes given pose classification.
+
+         Smoothing is done by computing an Exponential Moving Average for every pose
+         class observed in the given time window. Missed pose classes are replaced
+         with 0.
+
+         Args:
+           data: Dictionary with pose classification. Sample:
+               {
+                 'pushups_down': 8,
+                 'pushups_up': 2,
+               }
+
+         Result:
+           Dictionary in the same format but with smoothed float values instead of
+           integer values. Sample:
+             {
+               'pushups_down': 8.3,
+               'pushups_up': 1.7,
+             }
+         """
+         # Add new data to the beginning of the window for simpler code.
+         self._data_in_window.insert(0, data)
+         self._data_in_window = self._data_in_window[: self._window_size]
+
+         # Get all keys.
+         keys = set([key for data in self._data_in_window for key, _ in data.items()])
+
+         # Get smoothed values.
+         smoothed_data = dict()
+         for key in keys:
+             factor = 1.0
+             top_sum = 0.0
+             bottom_sum = 0.0
+             for data in self._data_in_window:
+                 value = data[key] if key in data else 0.0
+
+                 top_sum += factor * value
+                 bottom_sum += factor
+
+                 # Update factor.
+                 factor *= 1.0 - self._alpha
+
+             smoothed_data[key] = top_sum / bottom_sum
+
+         return smoothed_data
+
+
+ class RepetitionCounter(object):
+     """Counts number of repetitions of given target pose class."""
+
+     def __init__(self, class_name, enter_threshold=6, exit_threshold=4):
+         self._class_name = class_name
+
+         # If pose counter passes given threshold, then we enter the pose.
+         self._enter_threshold = enter_threshold
+         self._exit_threshold = exit_threshold
+
+         # Either we are in given pose or not.
+         self._pose_entered = False
+
+         # Number of times we exited the pose.
+         self._n_repeats = 0
+
+     @property
+     def n_repeats(self):
+         return self._n_repeats
+
+     def reset(self):
+         self._n_repeats = 0
+
+     def __call__(self, pose_classification):
+         """Counts the number of repetitions that happened up to the given frame.
+
+         We use two thresholds. First you need to go above the higher one to enter
+         the pose, and then you need to go below the lower one to exit it. The gap
+         between the thresholds makes counting stable against prediction jittering
+         (which would cause wrong counts with only one threshold).
+
+         Args:
+           pose_classification: Pose classification dictionary on current frame.
+             Sample:
+               {
+                 'pushups_down': 8.3,
+                 'pushups_up': 1.7,
+               }
+
+         Returns:
+           Integer counter of repetitions.
+         """
+         # Get pose confidence.
+         pose_confidence = 0.0
+         if self._class_name in pose_classification:
+             pose_confidence = pose_classification[self._class_name]
+
+         # On the very first frame or if we were out of the pose, just check if we
+         # entered it on this frame and update the state.
+         if not self._pose_entered:
+             self._pose_entered = pose_confidence > self._enter_threshold
+             return self._n_repeats
+
+         # If we were in the pose and are exiting it, then increase the counter and
+         # update the state.
+         if pose_confidence < self._exit_threshold:
+             self._n_repeats += 1
+             self._pose_entered = False
+
+         return self._n_repeats
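A minimal sketch of how these two helpers chain together per frame; the class name and the toy confidence values below are illustrative only:

```python
# Hypothetical per-frame smoothing + counting sketch showing the two-threshold hysteresis.
from PoseClassification.utils import EMADictSmoothing, RepetitionCounter

smoother = EMADictSmoothing(window_size=10, alpha=0.2)
counter = RepetitionCounter("pushups_down", enter_threshold=6, exit_threshold=4)

raw_frames = [
    {"pushups_down": 10},  # smoothed confidence 10.0 -> above 6, enter the pose
    {"pushups_up": 10},    # smoothed 'pushups_down' ~4.4 -> still inside (hysteresis)
    {"pushups_up": 10},    # smoothed 'pushups_down' ~2.6 -> below 4, one rep counted
]

for frame_classification in raw_frames:
    smoothed = smoother(frame_classification)
    repeats = counter(smoothed)
print(repeats)  # 1
```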
PoseClassification/visualize.py ADDED
@@ -0,0 +1,139 @@
+ import io
+ from PIL import Image, ImageFont, ImageDraw
+ import requests
+ import matplotlib.pyplot as plt
+
+
+ class PoseClassificationVisualizer(object):
+     """Keeps track of classifications for every frame and renders them."""
+
+     def __init__(
+         self,
+         class_name,
+         plot_location_x=0.05,
+         plot_location_y=0.05,
+         plot_max_width=0.4,
+         plot_max_height=0.4,
+         plot_figsize=(9, 4),
+         plot_x_max=None,
+         plot_y_max=None,
+         counter_location_x=0.85,
+         counter_location_y=0.05,
+         counter_font_path="https://github.com/googlefonts/roboto/blob/main/src/hinted/Roboto-Regular.ttf?raw=true",
+         counter_font_color="red",
+         counter_font_size=0.15,
+     ):
+         self._class_name = class_name
+         self._plot_location_x = plot_location_x
+         self._plot_location_y = plot_location_y
+         self._plot_max_width = plot_max_width
+         self._plot_max_height = plot_max_height
+         self._plot_figsize = plot_figsize
+         self._plot_x_max = plot_x_max
+         self._plot_y_max = plot_y_max
+         self._counter_location_x = counter_location_x
+         self._counter_location_y = counter_location_y
+         self._counter_font_path = counter_font_path
+         self._counter_font_color = counter_font_color
+         self._counter_font_size = counter_font_size
+
+         self._counter_font = None
+
+         self._pose_classification_history = []
+         self._pose_classification_filtered_history = []
+
+     def __call__(
+         self,
+         frame,
+         pose_classification,
+         pose_classification_filtered,
+         repetitions_count,
+     ):
+         """Renders pose classification and counter until given frame."""
+         # Extend classification history.
+         self._pose_classification_history.append(pose_classification)
+         self._pose_classification_filtered_history.append(pose_classification_filtered)
+
+         # Output frame with classification plot and counter.
+         output_img = Image.fromarray(frame)
+
+         output_width = output_img.size[0]
+         output_height = output_img.size[1]
+
+         # Draw the plot.
+         img = self._plot_classification_history(output_width, output_height)
+         img.thumbnail(
+             (
+                 int(output_width * self._plot_max_width),
+                 int(output_height * self._plot_max_height),
+             ),
+             Image.LANCZOS,
+         )
+         output_img.paste(
+             img,
+             (
+                 int(output_width * self._plot_location_x),
+                 int(output_height * self._plot_location_y),
+             ),
+         )
+
+         # Draw the count.
+         output_img_draw = ImageDraw.Draw(output_img)
+         if self._counter_font is None:
+             font_size = int(output_height * self._counter_font_size)
+             font_request = requests.get(self._counter_font_path, allow_redirects=True)
+             self._counter_font = ImageFont.truetype(
+                 io.BytesIO(font_request.content), size=font_size
+             )
+         output_img_draw.text(
+             (
+                 output_width * self._counter_location_x,
+                 output_height * self._counter_location_y,
+             ),
+             str(repetitions_count),
+             font=self._counter_font,
+             fill=self._counter_font_color,
+         )
+
+         return output_img
+
+     def _plot_classification_history(self, output_width, output_height):
+         fig = plt.figure(figsize=self._plot_figsize)
+
+         for classification_history in [
+             self._pose_classification_history,
+             self._pose_classification_filtered_history,
+         ]:
+             y = []
+             for classification in classification_history:
+                 if classification is None:
+                     y.append(None)
+                 elif self._class_name in classification:
+                     y.append(classification[self._class_name])
+                 else:
+                     y.append(0)
+             plt.plot(y, linewidth=7)
+
+         plt.grid(axis="y", alpha=0.75)
+         plt.xlabel("Frame")
+         plt.ylabel("Confidence")
+         plt.title("Classification history for `{}`".format(self._class_name))
+         plt.legend(loc="upper right")
+
+         if self._plot_y_max is not None:
+             plt.ylim(top=self._plot_y_max)
+         if self._plot_x_max is not None:
+             plt.xlim(right=self._plot_x_max)
+
+         # Convert plot to image.
+         buf = io.BytesIO()
+         dpi = min(
+             output_width * self._plot_max_width / float(self._plot_figsize[0]),
+             output_height * self._plot_max_height / float(self._plot_figsize[1]),
+         )
+         fig.savefig(buf, dpi=dpi)
+         buf.seek(0)
+         img = Image.open(buf)
+         plt.close()
+
+         return img
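For orientation, a hedged end-to-end sketch tying the modules above together on a video file; the CSV folder, class name, thresholds and video path are placeholders and this is not the repository's `classify_video.py`:

```python
# Hypothetical per-frame pipeline: detect -> classify -> smooth -> count -> render.
import cv2
import numpy as np
from mediapipe.python.solutions import pose as mp_pose
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing, RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer

class_name = "tree"  # placeholder pose class
classifier = PoseClassifier("yoga_poses_csvs_out", FullBodyPoseEmbedding())
smoother = EMADictSmoothing(window_size=10, alpha=0.2)
counter = RepetitionCounter(class_name, enter_threshold=6, exit_threshold=4)
visualizer = PoseClassificationVisualizer(class_name=class_name)

cap = cv2.VideoCapture("data/videos/tree_vid_1.mp4")
with mp_pose.Pose() as tracker:
    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        result = tracker.process(image=frame_rgb)
        if result.pose_landmarks is None:
            continue
        h, w = frame_rgb.shape[0], frame_rgb.shape[1]
        landmarks = np.array(
            [[lmk.x * w, lmk.y * h, lmk.z * w] for lmk in result.pose_landmarks.landmark],
            dtype=np.float32,
        )
        classification = classifier(landmarks)
        smoothed = smoother(classification)
        reps = counter(smoothed)
        output = visualizer(frame_rgb, classification, smoothed, reps)  # PIL image
cap.release()
```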
README.md CHANGED
@@ -1,14 +1,191 @@
- ---
- title: YOGAI
- emoji: 🐠
- colorFrom: purple
- colorTo: green
- sdk: gradio
- sdk_version: 5.3.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: yoga app
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Projet ACV-2
+
+ ## Quick start
+
+ **Install uv**
+
+ > curl -LsSf https://astral.sh/uv/install.sh | sh
+
+ > uv self update
+
+ **Install and run the project**
+
+ > git clone git@github.com:LexouLam/projet-acv-2.git
+
+ > cd projet-acv-2
+
+ > uv venv --python 3.12
+
+ > uv sync
+
+ > source .venv/bin/activate
+
+ > uv run classify_video.py live --display
+
+ OR
+
+ > uv run classify_video.py data/videos/tree_vid_1.mp4 --display
+
+ ## Documentation
+
+ Team: Impredalam
+
+ GENERAL INFO:
+
+ Keep in mind that the project's "main" branch must never be broken: the code it contains must
+ always be runnable (barring an unanticipated bug, which will require a hot-fix).
+
+ ======
+ DAY 1
+ ======
+
+ The topic is human pose detection and classification for a home workout application.
+
+ ------------------------------------------------------------------------------
+ TODO DAY 1:
+
+ ---DONE--- 1. Find a name for the group and a name for the project.
+ ---DONE--- 2. Set up a git repository.
+ ---DONE--- 3. Explore the existing code base:
+ - pose detection/classification algorithm,
+ - repetition counting algorithm on a video stream.
+ ---DONE--- 4. Build an annotated dataset to "train" the algorithm with a few images of ourselves doing push-ups.
+ ---DONE--- 5. Prepare a video demonstrating the feasibility of such a project.
+ ---DONE--- 6. Optional day 1: a live demo + a structured git repo without notebooks.
+
+ ------------------------------------------------------------------------------
+ 1. stand-up ---DONE---
+
+ 2. take client requirements from the last delivery into account ---DONE---
+
+ 3. code iteration ---DONE---
+
+ 4. 4 pm: delivery to the client ---DONE---
+
+ 5. 5 pm: push-up contest ---POSTPONED---
+
+ ------------------------------------------------------------------------------
+ RESULT:
+
+ A working program that detects and counts push-ups, trained on images from the Internet plus our own photos, flipped horizontally to make the dataset larger and richer.
+
+ With a logo for our team.
+ ------------------------------------------------------------------------------
+
+ ======
+ DAY 2
+ ======
+
+ Develop the real project that the company will be able to use.
+ Subject: yoga class: classification of the classic poses.
+
+ The client wants a Python program runnable from the command line with a simple interface.
+
+ ------------------------------------------------------------------------------
+ TODO DAY 2:
+
+ ---DONE--- 1. Choose one of the four proposed subjects.
+ 2. Plan and distribute the tasks, structure the git project.
+ ---DONE--- 3. Build a dataset suited to the chosen subject (made by ourselves, or not?).
+ 4. Get out of the notebook, script with arguments.
+  https://docs.python.org/3/library/argparse.html
+ 5. Implement options (first of all, the ability to display information to debug the program easily).
+ 6. Minimal documentation to run the program.
+ 7. Optional day 2: packaging with poetry or equivalent + (very optional) functional/unit tests.
+  https://github.com/features/actions
+
+ ------------------------------------------------------------------------------
+ 1. 9 am: get familiar with the subject and project management (task creation/assignment)
+
+ 2. 9:45 am: start of the day's sprint
+
+ 3. 4 pm: delivery to the client
+
+ ------------------------------------------------------------------------------
+ RESULT:
+
+ test
+ ------------------------------------------------------------------------------
+
+ ======
+ DAY 3
+ ======
+
+ 1. stand-up,
+
+ 2. take client requirements from the last delivery into account,
+
+ 3. code iteration,
+
+ 4. delivery to the client
+
+
+
+ ## Environment set-up
+
+ **Install uv**
+
+ > curl -LsSf https://astral.sh/uv/install.sh | sh
+ > uv self update
+
+ **Create the environment**
+
+ > mkdir projet_acv_2
+ > cd projet_acv_2/
+ > uv init
+ > uv venv --python 3.12
+ > uv add numpy matplotlib plotly jupyter opencv-python mediapipe
+ > uv add tqdm requests pillow scikit-learn
+
+ **Create the git repo if it does not exist yet**
+
+ > touch .gitignore
+ > git init
+ > git add .
+ > git commit -m "start repo"
+ > git remote add origin git@github.com:LexouLam/projet-acv-2.git
+ > git push --set-upstream origin master
+ > git push
+
+ **Clone the git repo and initialize the environment**
+
+ > git clone git@github.com:LexouLam/projet-acv-2.git
+ > cd projet-acv-2
+ > uv sync
+
+ ## Arguments of the "classify_video.py" script
+
+ > classify_video.py arg1
+
+ Inputs
+
+ **arg1**: "path/to/video.mp4" or "live"
+
+ Outputs
+
+ None for now...
+
+
+
+
+ ## Model information
+
+ **Pose Landmark Model (BlazePose GHUM 3D)**
+ https://camo.githubusercontent.com/d3afebfc801ee1a094c28604c7a0eb25f8b9c9925f75b0fff4c8c8b4871c0d28/68747470733a2f2f6d65646961706970652e6465762f696d616765732f6d6f62696c652f706f73655f747261636b696e675f66756c6c5f626f64795f6c616e646d61726b732e706e67
+
+ GUIDE: https://github.com/google-ai-edge/mediapipe/blob/master/docs/solutions/pose.md
+
+ ![alt text](src/image.png)
+
+ Left shoulder (landmark 11)
+ Right shoulder (landmark 12)
+ Left elbow (landmark 13)
+ Right elbow (landmark 14)
+ Left wrist (landmark 15)
+ Right wrist (landmark 16)
+ Hips (landmarks 23 and 24)
+
README.md.old ADDED
@@ -0,0 +1,146 @@
1
+ # Projet ACV-2
2
+
3
+ Team : Impredalam
4
+
5
+ INFO GLOBALE:
6
+
7
+ Gardez en tête que la branche “main” du projet ne doit jamais être bugée, le code qu’elle contient doit
8
+ toujours pouvoir s’exécuter (sauf bug non anticipé qui nécessitera un “hot-fix”).
9
+
10
+ ======
11
+ JOUR 1
12
+ ======
13
+
14
+ Le thème abordé est la détection et la classification de pose humaine dans le cadre d’une application de
15
+ sport à domicile.
16
+
17
+ ------------------------------------------------------------------------------
18
+ TODO JOUR 1:
19
+
20
+ ---DONE--- 1. Trouver un nom pour votre groupe et un nom pour le projet.
21
+ ---DONE--- 2. Mettre en place un dépôt git
22
+ ---DONE--- 3. Explorer la base de code déjà existante
23
+ - algorithme de détection/classification de poses,
24
+ - algorithme de comptage sur un flux vidéo.
25
+ ---DONE--- 4. Constituer une base de données annotées pour « entraîner » l’algorithme avec quelques images de vous faisant des pompes.
26
+ ---DONE--- 5. Préparer une vidéo démontrant la faisabilité d’un tel projet.
27
+ ---DONE--- 6. Optionnel J1 : une démo live + un repo git structuré sans notebook.
28
+
29
+ ------------------------------------------------------------------------------
30
+ 1. stand-up ---DONE---
31
+
32
+ 2. Prise en compte des exigences client suite à
33
+ la dernière livraison, ---DONE---
34
+
35
+ 3. tération de code ---DONE---
36
+
37
+ 4. 16h : livraison au client ---DONE---
38
+
39
+ 5. 17h : concours de pompe ---POSTPONED---
40
+
41
+ ------------------------------------------------------------------------------
42
+ RESULTAT:
43
+
44
+ Programme founctionnelle, qui détecte les pompes et les compte, formé avec les images d'Internet, nos propres photos, inversées horizontalement pour rendre l'ensemble de données plus grand et plus riche.
45
+
46
+ Acev un logo de notre équipe
47
+ ------------------------------------------------------------------------------
48
+
49
+ ======
50
+ JOUR 2
51
+ ======
52
+
53
+ Développer le vrai projet qui pourra être utilisé par la société.
54
+ Sujet: cours de yoga : classification des positions classiques
55
+
56
+ Le client veut un programme python exécutable en ligne de commande avec une interface simple.
57
+
58
+ ------------------------------------------------------------------------------
59
+ TODO JOUR 2:
60
+
61
+ ---DONE--- 1. Choix d’un sujet parmi les quatres proposés.
62
+ 2. Planification et répartition des tâches, structuration du projet git.
63
+ ---DONE--- 3. Constitution d’une base de données adaptée au sujet choisi (réalisée vous-même, ou pas ?).
64
+ 4. Sortir du notebook, script avec arguments.
65
+  https://docs.python.org/3/library/argparse.html
66
+ 5. Implémentation des options (prioritairement, la possibilité d’afficher des informations pour débugger le programme facilement).
67
+ 6. Documentation minimale pour lancer le programme.
68
+ 7. Optionnel J2 : packagisation poetry ou équivalent + (très optionnel) tests fonctionnels/unitaires.
69
+  https://github.com/features/actions
70
+
71
+ ------------------------------------------------------------------------------
72
+ 1. 9h : prise en main du sujet et gestion de projet (création/répartition des tâches)
73
+
74
+ 2. 9h45 : début du sprint de la journée
75
+
76
+ 3. 16h : livraison au client
77
+
78
+ ------------------------------------------------------------------------------
79
+ RESULTAT:
80
+
81
+ test
82
+ ------------------------------------------------------------------------------
83
+
84
+ ======
85
+ JOUR 3
86
+ ======
87
+
88
+ 1. stand-up,
89
+
90
+ 2. prise en compte des exigences client suite à
91
+ la dernière livraison,
92
+
93
+ 3. tération de code,
94
+
95
+ 4. livraison au client
96
+
97
+
98
+
99
+ ## Set-up environnement
100
+
101
+ **installation uv**
102
+
103
+ > curl -LsSf https://astral.sh/uv/install.sh | sh
104
+ > uv self update
105
+
106
+ **création environnement**
107
+
108
+ > mkdir projet_acv_2
109
+ > cd projet_acv_2/
110
+ > uv init
111
+ > uv venv --python 3.12
112
+ > uv add numpy matplotlib plotly jupyter opencv-python mediapipe
113
+ > uv add tqdm requests pillow scikit-learn
114
+
115
+ **création repo git si non créé**
116
+
117
+ > touch .gitignore
118
+ > git init
119
+ > git add .
120
+ > git commit -m "start repo"
121
+ > git remote add origin git@github.com:LexouLam/projet-acv-2.git
122
+ > git push --set-upstream origin master
123
+ > git push
124
+
125
+ **clone the git repo and initialise the environment**
126
+
127
+ > git clone git@github.com:LexouLam/projet-acv-2.git
128
+ > cd projet-acv-2
129
+ > uv sync
130
+
131
+
132
+
133
+ **Pose Landmark Model (BlazePose GHUM 3D)**
134
+ https://camo.githubusercontent.com/d3afebfc801ee1a094c28604c7a0eb25f8b9c9925f75b0fff4c8c8b4871c0d28/68747470733a2f2f6d65646961706970652e6465762f696d616765732f6d6f62696c652f706f73655f747261636b696e675f66756c6c5f626f64795f6c616e646d61726b732e706e67
135
+
136
+ GUIDE: https://github.com/google-ai-edge/mediapipe/blob/master/docs/solutions/pose.md
137
+
138
+ ![alt text](src/image.png)
139
+
140
+ Left shoulder (landmark 11)
141
+ Right shoulder (landmark 12)
142
+ Left elbow (landmark 13)
143
+ Right elbow (landmark 14)
144
+ Left wrist (landmark 15)
145
+ Right wrist (landmark 16)
146
+ Hips (landmarks 23 and 24)
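+
+ These indices can be read straight from the MediaPipe result. A minimal sketch (assuming `results` is the object returned by `mp_pose.Pose().process(image_rgb)`; the helper name is ours):
+
+     import numpy as np
+
+     def left_elbow_angle(results):
+         # Landmarks: 11 = left shoulder, 13 = left elbow, 15 = left wrist
+         lm = results.pose_landmarks.landmark
+         shoulder = np.array([lm[11].x, lm[11].y])
+         elbow = np.array([lm[13].x, lm[13].y])
+         wrist = np.array([lm[15].x, lm[15].y])
+         v1, v2 = shoulder - elbow, wrist - elbow
+         cos = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
+         return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))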
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from interface_pages.home_page import home_page
3
+ from interface_pages.about_page import about_page
4
+ from interface_pages.yoga_position_from_stream import yoga_position_from_stream
5
+ from interface_pages.yoga_position_from_video import yoga_position_from_video
6
+
7
+ def main(page):
8
+ if page == "Home":
9
+ return home_page()
10
+ elif page == "About us":
11
+ return about_page()
12
+ elif page == "Yoga from stream":
13
+ return yoga_position_from_stream()
14
+ elif page == "Yoga from video":
15
+ return yoga_position_from_video()
16
+
17
+ def interface():
18
+ with gr.Blocks(css="static/styles.css") as demo:
19
+
20
+ # Layout with a Row to hold buttons and content
21
+ with gr.Row():
22
+ with gr.Column(scale=1, elem_classes=["menu-column"]):
23
+ # Vertical Navigation Buttons
24
+ home_button = gr.Button("Home", elem_classes=["menu-button"])
25
+ about_button = gr.Button("About us", elem_classes=["menu-button"])
26
+ yoga_stream_button = gr.Button("Yoga from stream", elem_classes=["menu-button"])
27
+ yoga_video_button = gr.Button("Yoga from video", elem_classes=["menu-button"])
28
+
29
+ # Create page contents
30
+ with gr.Column(elem_id="page-content") as page_content:
31
+ home_page_content = home_page()
32
+ about_page_content = about_page()
33
+ yoga_stream_content = yoga_position_from_stream()
34
+ yoga_video_content = yoga_position_from_video()
35
+
36
+ # Set initial visibility
37
+ home_page_content.visible = True
38
+ about_page_content.visible = False
39
+ yoga_stream_content.visible = False
40
+ yoga_video_content.visible = False
41
+
42
+ # Button click handlers
43
+ def show_page(page):
44
+ return [
45
+ gr.update(visible=(content == page))
46
+ for content in [
47
+ home_page_content,
48
+ about_page_content,
49
+ yoga_stream_content,
50
+ yoga_video_content,
51
+ ]
52
+ ]
53
+
54
+ home_button.click(
55
+ lambda: show_page(home_page_content),
56
+ outputs=[
57
+ home_page_content,
58
+ about_page_content,
59
+ yoga_stream_content,
60
+ yoga_video_content,
61
+ ],
62
+ )
63
+ about_button.click(
64
+ lambda: show_page(about_page_content),
65
+ outputs=[
66
+ home_page_content,
67
+ about_page_content,
68
+ yoga_stream_content,
69
+ yoga_video_content,
70
+ ],
71
+ )
72
+ yoga_stream_button.click(
73
+ lambda: show_page(yoga_stream_content),
74
+ outputs=[
75
+ home_page_content,
76
+ about_page_content,
77
+ yoga_stream_content,
78
+ yoga_video_content,
79
+ ],
80
+ )
81
+ yoga_video_button.click(
82
+ lambda: show_page(yoga_video_content),
83
+ outputs=[
84
+ home_page_content,
85
+ about_page_content,
86
+ yoga_stream_content,
87
+ yoga_video_content,
88
+ ],
89
+ )
90
+
91
+ return demo
92
+
93
+
94
+ if __name__ == "__main__":
95
+ interface().launch(share=True)
classify_video.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import cv2
4
+ import numpy as np
5
+ from rich.console import Console
6
+ from rich.panel import Panel
7
+ from rich.align import Align
8
+ from rich.layout import Layout
9
+ from pyfiglet import Figlet
10
+ import mediapipe as mp
11
+ from PoseClassification.pose_embedding import FullBodyPoseEmbedding
12
+ from PoseClassification.pose_classifier import PoseClassifier
13
+ from PoseClassification.utils import EMADictSmoothing
14
+ from PoseClassification.visualize import PoseClassificationVisualizer
15
+
16
+ # For cross-platform compatibility
17
+ try:
18
+ import msvcrt # Windows
19
+ except ImportError:
20
+ import termios # Unix-like
21
+ import tty
22
+
23
+
24
+ def getch():
25
+ if sys.platform == "win32":
26
+ return msvcrt.getch().decode("utf-8")
27
+ else:
28
+ fd = sys.stdin.fileno()
29
+ old_settings = termios.tcgetattr(fd)
30
+ try:
31
+ tty.setraw(sys.stdin.fileno())
32
+ ch = sys.stdin.read(1)
33
+ finally:
34
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
35
+ return ch
36
+
37
+
38
+ def create_ascii_title(text):
39
+ f = Figlet(font="isometric2")
40
+ return f.renderText(text)
41
+
42
+
43
+ def main(input_source, display=False, output_file=None):
44
+ console = Console()
45
+ layout = Layout()
46
+
47
+ # Create ASCII title
48
+ ascii_title = create_ascii_title("YOGAI")
49
+
50
+ # Create the layout
51
+ layout.split(
52
+ Layout(Panel(Align.center(ascii_title), border_style="bold blue"), size=15),
53
+ Layout(name="main"),
54
+ )
55
+
56
+ is_live = input_source == "live"
57
+ if is_live:
58
+ layout["main"].update(
59
+ Panel(
60
+ "Processing live video from camera",
61
+ title="Video Classification",
62
+ border_style="bold blue",
63
+ )
64
+ )
65
+ else:
66
+ layout["main"].update(
67
+ Panel(
68
+ f"Processing video: {input_source}",
69
+ title="Video Classification",
70
+ border_style="bold blue",
71
+ )
72
+ )
73
+
74
+ console.print(layout)
75
+
76
+ # Initialize pose tracker, embedder, and classifier
77
+ mp_pose = mp.solutions.pose
78
+ pose_tracker = mp_pose.Pose()
79
+ pose_embedder = FullBodyPoseEmbedding()
80
+ pose_classifier = PoseClassifier(
81
+ pose_samples_folder="data/yoga_poses_csvs_out",
82
+ pose_embedder=pose_embedder,
83
+ top_n_by_max_distance=30,
84
+ top_n_by_mean_distance=10,
85
+ )
86
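+ # EMADictSmoothing (project util) applies an exponential moving average over
+ # the last 10 frames (alpha=0.2) to the per-class scores, so a single
+ # misclassified frame does not make the displayed pose flicker.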
+ pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)
87
+
88
+ # Open the video source
89
+ if is_live:
90
+ video = cv2.VideoCapture(0)
91
+ fps = 30 # Assume 30 fps for live video
92
+ total_frames = float("inf") # Infinite frames for live video
93
+ else:
94
+ video = cv2.VideoCapture(input_source)
95
+ fps = video.get(cv2.CAP_PROP_FPS)
96
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
97
+
98
+ # Initialize pose timings (use lowercase for keys)
99
+ pose_timings = {
100
+ "chair": 0,
101
+ "cobra": 0,
102
+ "dog": 0,
103
+ "plank": 0,
104
+ "goddess": 0,
105
+ "tree": 0,
106
+ "warrior": 0,
107
+ "no pose detected": 0,
108
+ "fallen": 0,
109
+ }
110
+
111
+ frame_count = 0
112
+ while True:
113
+ ret, frame = video.read()
114
+ if not ret:
115
+ if is_live:
116
+ console.print(
117
+ "[bold red]Error reading from camera. Exiting...[/bold red]"
118
+ )
119
+ break
120
+
121
+ # Process the frame
122
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
123
+ result = pose_tracker.process(image=frame_rgb)
124
+
125
+ if result.pose_landmarks is not None:
126
+ # Draw landmarks on the frame
127
+ mp.solutions.drawing_utils.draw_landmarks(
128
+ frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS
129
+ )
130
+
131
+ frame_height, frame_width = frame.shape[0], frame.shape[1]
132
+ pose_landmarks = np.array(
133
+ [
134
+ [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
135
+ for lmk in result.pose_landmarks.landmark
136
+ ],
137
+ dtype=np.float32,
138
+ )
139
+
140
+ # Classify the pose
141
+ pose_classification = pose_classifier(pose_landmarks)
142
+ pose_classification_filtered = pose_classification_filter(
143
+ pose_classification
144
+ )
145
+
146
+ # Update pose timings (only for the pose with highest confidence)
147
+ max_pose = max(
148
+ pose_classification_filtered, key=pose_classification_filtered.get
149
+ ).lower()
150
+ pose_timings[max_pose] += 1 / fps
151
+ else:
152
+ pose_timings["no pose detected"] += 1 / fps
153
+
154
+ frame_count += 1
155
+ if frame_count % 30 == 0: # Update every 30 frames
156
+ panel_content = (
157
+ f"[bold]Chair:[/bold] {pose_timings['chair']:.2f}s\n"
158
+ f"[bold]Cobra:[/bold] {pose_timings['cobra']:.2f}s\n"
159
+ f"[bold]Dog:[/bold] {pose_timings['dog']:.2f}s\n"
160
+ f"[bold]Plank:[/bold] {pose_timings['plank']:.2f}s\n"
161
+ f"[bold]Goddess:[/bold] {pose_timings['goddess']:.2f}s\n"
162
+ f"[bold]Tree:[/bold] {pose_timings['tree']:.2f}s\n"
163
+ f"[bold]Warrior:[/bold] {pose_timings['warrior']:.2f}s\n"
164
+ f"---\n"
165
+ f"[bold]No pose detected:[/bold] {pose_timings['no pose detected']:.2f}s\n"
166
+ f"[bold]Fallen:[/bold] {pose_timings['fallen']:.2f}s"
167
+ )
168
+ if not is_live:
169
+ panel_content += f"\n\nProcessed {frame_count}/{total_frames} frames"
170
+
171
+ layout["main"].update(
172
+ Panel(
173
+ panel_content,
174
+ title="Classification Results",
175
+ border_style="bold green",
176
+ )
177
+ )
178
+ console.print(layout)
179
+
180
+ if display:
181
+ cv2.imshow("Video", frame)
182
+ if cv2.waitKey(1) & 0xFF == ord("q"):
183
+ break
184
+
185
+ video.release()
186
+ if display:
187
+ cv2.destroyAllWindows()
188
+
189
+ # Final results
190
+ final_panel_content = (
191
+ f"[bold]Chair:[/bold] {pose_timings['chair']:.2f}s\n"
192
+ f"[bold]Cobra:[/bold] {pose_timings['cobra']:.2f}s\n"
193
+ f"[bold]Dog:[/bold] {pose_timings['dog']:.2f}s\n"
194
+ f"[bold]Plank:[/bold] {pose_timings['plank']:.2f}s\n"
195
+ f"[bold]Goddess:[/bold] {pose_timings['goddess']:.2f}s\n"
196
+ f"[bold]Tree:[/bold] {pose_timings['tree']:.2f}s\n"
197
+ f"[bold]Warrior:[/bold] {pose_timings['warrior']:.2f}s\n"
198
+ f"---\n"
199
+ f"[bold]No pose detected:[/bold] {pose_timings['no pose detected']:.2f}s\n"
200
+ f"[bold]Fallen:[/bold] {pose_timings['fallen']:.2f}s"
201
+ )
202
+ layout["main"].update(
203
+ Panel(
204
+ final_panel_content,
205
+ title="Final Classification Results",
206
+ border_style="bold green",
207
+ )
208
+ )
209
+ console.print(layout)
210
+
211
+ if output_file:
212
+ console.print(f"[green]Output saved to: {output_file}[/green]")
213
+
214
+
215
+ if __name__ == "__main__":
216
+ parser = argparse.ArgumentParser(
217
+ description="Classify poses in a video file or from live camera."
218
+ )
219
+ parser.add_argument("input", help="Input video file or 'live' for camera feed")
220
+ parser.add_argument(
221
+ "--display", action="store_true", help="Display the video with detected poses"
222
+ )
223
+ parser.add_argument("--output", help="Output video file")
224
+
225
+ if len(sys.argv) == 1:
226
+ parser.print_help(sys.stderr)
227
+ sys.exit(1)
228
+
229
+ args = parser.parse_args()
230
+
231
+ main(args.input, args.display, args.output)
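+
+ # Example invocations (illustrative paths; data/yoga_poses_csvs_out must exist):
+ #   python classify_video.py path/to/session.mp4 --display
+ #   python classify_video.py live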
hello.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def main():
2
+ print("Hello from projet-acv-2!")
3
+
4
+
5
+ if __name__ == "__main__":
6
+ main()
interface_pages/__init__.py ADDED
File without changes
interface_pages/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (184 Bytes). View file
 
interface_pages/__pycache__/about_page.cpython-312.pyc ADDED
Binary file (438 Bytes). View file
 
interface_pages/__pycache__/home_page.cpython-312.pyc ADDED
Binary file (500 Bytes). View file
 
interface_pages/__pycache__/yoga_position_from_stream.cpython-312.pyc ADDED
Binary file (1.61 kB). View file
 
interface_pages/__pycache__/yoga_position_from_video.cpython-312.pyc ADDED
Binary file (613 Bytes). View file
 
interface_pages/about_page.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def about_page():
5
+ return gr.Markdown(
6
+ """
7
+ # About Us
8
+
9
+ WYOGAI — the BEST.
10
+ """
11
+ )
interface_pages/home_page.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def home_page():
5
+ return gr.Markdown(
6
+ """
7
+ # Welcome to YOGAI App!
8
+
9
+ This is your home page where you can explore different yoga practices.
10
+ """
11
+ )
interface_pages/yoga_position_from_stream.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def yoga_position_from_stream():
5
+ def download_video(video_path):
6
+ if video_path:
7
+ return video_path
8
+ return None
9
+
10
+ with gr.Column() as yoga_stream:
11
+ gr.Markdown("# Yoga from Stream")
12
+ gr.Markdown(
13
+ "Stream live yoga sessions and practice along with our expert instructors."
14
+ )
15
+ video_feed = gr.Video(source="webcam", streaming=True, interactive=True)
16
+ download_button = gr.Button("Download Recorded Video")
17
+ video_output = gr.Video()
18
+
19
+ download_button.click(
20
+ download_video,
21
+ inputs=[video_feed], # Changed from video_output to video_feed
22
+ outputs=[gr.File()],
23
+ )
24
+
25
+ return yoga_stream
26
+
27
+
28
+ if __name__ == "__main__":
29
+ with gr.Blocks() as demo:
30
+ yoga_position_from_stream()
31
+ demo.launch()
interface_pages/yoga_position_from_video.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def yoga_position_from_video():
5
+ return gr.Markdown(
6
+ """
7
+ # Yoga from Video
8
+
9
+ Watch pre-recorded yoga sessions and practice at your convenience.
10
+
11
+ Select a video below:
12
+
13
+ - Beginner Yoga
14
+ - Advanced Techniques
15
+ - Restorative Yoga
16
+ """
17
+ )
pushups_counter.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import cv2
3
+ import numpy as np
4
+ from mediapipe.python.solutions import drawing_utils as mp_drawing
5
+ import mediapipe as mp
6
+ from PoseClassification.pose_embedding import FullBodyPoseEmbedding
7
+ from PoseClassification.pose_classifier import PoseClassifier
8
+ from PoseClassification.utils import EMADictSmoothing
9
+ from PoseClassification.utils import RepetitionCounter
10
+ from PoseClassification.visualize import PoseClassificationVisualizer
11
+
12
+ mp_pose = mp.solutions.pose
13
+ pose_tracker = mp_pose.Pose()
14
+
15
+ pose_samples_folder = "data/fitness_poses_csvs_out"
16
+ class_name = "pushups_down"
17
+
18
+ pose_embedder = FullBodyPoseEmbedding()
19
+
20
+ pose_classifier = PoseClassifier(
21
+ pose_samples_folder=pose_samples_folder,
22
+ pose_embedder=pose_embedder,
23
+ top_n_by_max_distance=30,
24
+ top_n_by_mean_distance=10,
25
+ )
26
+
27
+ pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)
28
+
29
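+ # The counter enters the "pushups_down" pose when its smoothed score rises above
+ # enter_threshold and leaves it when the score drops below exit_threshold;
+ # each full enter/exit cycle is counted as one repetition (the hysteresis
+ # avoids double counting near the decision boundary).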
+ repetition_counter = RepetitionCounter(
30
+ class_name=class_name, enter_threshold=6, exit_threshold=4
31
+ )
32
+
33
+ pose_classification_visualizer = PoseClassificationVisualizer(
34
+ class_name=class_name, plot_x_max=1000, plot_y_max=10
35
+ )
36
+
37
+ video_cap = cv2.VideoCapture(0)
38
+ video_fps = 30
39
+ video_width = 1280
40
+ video_height = 720
41
+ video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, video_width)
42
+ video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height)
43
+
44
+ frame_idx = 0
45
+ output_frame = None
46
+
47
+ try:
48
+ with tqdm.tqdm(position=0, leave=True) as pbar:
49
+ while True:
50
+ success, input_frame = video_cap.read()
51
+ if not success:
52
+ print("Unable to read input video frame, breaking!")
53
+ break
54
+
55
+ # Run pose tracker
56
+ input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
57
+ result = pose_tracker.process(image=input_frame_rgb)
58
+ pose_landmarks = result.pose_landmarks
59
+
60
+ # Prepare the output frame
61
+ output_frame = input_frame.copy()
62
+
63
+ # Add a white banner on top
64
+ banner_height = 180
65
+ output_frame[0:banner_height, :] = (255, 255, 255) # White color
66
+
67
+ # Load the logo image
68
+ logo = cv2.imread("src/logo_impredalam.jpg")
69
+ logo_height, logo_width = logo.shape[:2]
70
+ logo = cv2.resize(
71
+ logo, (logo_width // 3, logo_height // 3)
72
+ ) # Resize to 1/3 scale
73
+
74
+ # Overlay the logo on the upper right corner
75
+ output_frame[0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :] = (
76
+ logo
77
+ )
78
+ if pose_landmarks is not None:
79
+ mp_drawing.draw_landmarks(
80
+ image=output_frame,
81
+ landmark_list=pose_landmarks,
82
+ connections=mp_pose.POSE_CONNECTIONS,
83
+ )
84
+
85
+ # Get landmarks
86
+ frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
87
+ pose_landmarks = np.array(
88
+ [
89
+ [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
90
+ for lmk in pose_landmarks.landmark
91
+ ],
92
+ dtype=np.float32,
93
+ )
94
+ assert pose_landmarks.shape == (
95
+ 33,
96
+ 3,
97
+ ), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
98
+
99
+ # Classify the pose on the current frame
100
+ pose_classification = pose_classifier(pose_landmarks)
101
+
102
+ # Smooth classification using EMA
103
+ pose_classification_filtered = pose_classification_filter(
104
+ pose_classification
105
+ )
106
+
107
+ # Count repetitions
108
+ repetitions_count = repetition_counter(pose_classification_filtered)
109
+
110
+ # Display repetitions count on the frame
111
+ cv2.putText(
112
+ output_frame,
113
+ f"Push-Ups: {repetitions_count}",
114
+ (10, 30),
115
+ cv2.FONT_HERSHEY_SIMPLEX,
116
+ 1,
117
+ (0, 0, 0),
118
+ 2,
119
+ cv2.LINE_AA,
120
+ )
121
+ # Display classified pose on the frame
122
+ cv2.putText(
123
+ output_frame,
124
+ f"Pose: {pose_classification}",
125
+ (10, 70),
126
+ cv2.FONT_HERSHEY_SIMPLEX,
127
+ 1.2, # Smaller font size
128
+ (0, 0, 0),
129
+ 1, # Thinner line
130
+ cv2.LINE_AA,
131
+ )
132
+ else:
133
+ # If no landmarks are detected, still display the last count
134
+ repetitions_count = repetition_counter.n_repeats
135
+ cv2.putText(
136
+ output_frame,
137
+ f"Push-Ups: {repetitions_count}",
138
+ (10, 30),
139
+ cv2.FONT_HERSHEY_SIMPLEX,
140
+ 1,
141
+ (0, 255, 0),
142
+ 2,
143
+ cv2.LINE_AA,
144
+ )
145
+
146
+ cv2.imshow("Push-Up Counter", output_frame)
147
+
148
+ key = cv2.waitKey(1) & 0xFF
149
+ if key == ord("q"):
150
+ break
151
+ elif key == ord("r"):
152
+ repetition_counter.reset()
153
+ print("Counter reset!")
154
+
155
+ frame_idx += 1
156
+ pbar.update()
157
+
158
+ finally:
159
+
160
+ pose_tracker.close()
161
+ video_cap.release()
162
+ cv2.destroyAllWindows()
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "projet-acv-2"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "argparse>=1.4.0",
9
+ "ffmpeg>=1.4",
10
+ "gradio>=3.36.1",
11
+ "jupyter>=1.1.1",
12
+ "matplotlib>=3.9.2",
13
+ "mediapipe>=0.10.15",
14
+ "numpy>=1.26.4",
15
+ "opencv-python>=4.10.0.84",
16
+ "pillow>=11.0.0",
17
+ "plotly>=5.24.1",
18
+ "pyfiglet>=1.0.2",
19
+ "requests>=2.32.3",
20
+ "rich>=13.9.2",
21
+ "scikit-learn>=1.5.2",
22
+ "streamlit>=1.9.0",
23
+ "tqdm>=4.66.5",
24
+ ]
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ argparse>=1.4.0
2
+ ffmpeg>=1.4
3
+ gradio>=3.36.1
4
+ jupyter>=1.1.1
5
+ matplotlib>=3.9.2
6
+ mediapipe>=0.10.15
7
+ numpy>=1.26.4
8
+ opencv-python>=4.10.0.84
9
+ pillow>=11.0.0
10
+ plotly>=5.24.1
11
+ pyfiglet>=1.0.2
12
+ requests>=2.32.3
13
+ rich>=13.9.2
14
+ scikit-learn>=1.5.2
15
+ streamlit>=1.9.0
16
+ tqdm>=4.66.5
src/image.png ADDED
src/logo_impredalam.jpg ADDED
static/styles.css ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .menu-column {
2
+ background-color: #4CAF50; /* Background color of the menu */
3
+ padding: 20px; /* Padding around the menu */
4
+ height: 100vh; /* Full height for the menu */
5
+ }
6
+
7
+ .menu-button {
8
+ color: white; /* Text color for the buttons */
9
+ background-color: transparent; /* Transparent background */
10
+ border: none; /* No border */
11
+ padding: 10px 15px; /* Padding for the buttons */
12
+ width: 100%; /* Full width for buttons */
13
+ text-align: left; /* Align text to the left */
14
+ cursor: pointer; /* Pointer cursor on hover */
15
+ transition: background-color 0.3s; /* Smooth transition */
16
+ }
17
+
18
+ .menu-button:hover {
19
+ background-color: rgba(255, 255, 255, 0.2); /* Light hover effect */
20
+ }
21
+
22
+ .gradio-container {
23
+ margin-top: 0; /* Remove top margin to allow for full height */
24
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
yoga_position.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import cv2
3
+ import numpy as np
4
+ import re
5
+ import os
6
+ from mediapipe.python.solutions import drawing_utils as mp_drawing
7
+ import mediapipe as mp
8
+ from PoseClassification.pose_embedding import FullBodyPoseEmbedding
9
+ from PoseClassification.pose_classifier import PoseClassifier
10
+ from PoseClassification.utils import EMADictSmoothing
11
+ # from PoseClassification.utils import RepetitionCounter
12
+ from PoseClassification.visualize import PoseClassificationVisualizer
13
+ import argparse
14
+ from PoseClassification.utils import show_image
15
+
16
+ def main():
17
+ #Load arguments
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("video_path", help="string video path in")
20
+ args = parser.parse_args()
21
+
22
+ video_path_in = args.video_path
23
+ direct_video=False
24
+ if video_path_in=="live":
25
+ video_path_in='data/live.mp4'
26
+ direct_video=True
27
+
28
+ video_path_out = re.sub(r'.mp4', r'_classified_video.mp4', video_path_in)
29
+ results_classification_path_out = re.sub(r'.mp4', r'_classified_results.csv', video_path_in)
30
+
31
+
32
+ # Case 1: live video stream from the webcam (not the main use case for now)
33
+ if direct_video :
34
+ video_cap = cv2.VideoCapture(0)
35
+ video_fps = 30
36
+ video_width = 1280
37
+ video_height = 720
38
+
39
+ class_name='tree'
40
+
41
+ # Initialize tracker, classifier and current position.
42
+ # Initialize tracker.
43
+ mp_pose = mp.solutions.pose
44
+ pose_tracker = mp_pose.Pose()
45
+ # Folder with pose class CSVs. That should be the same folder you used while
46
+ # building classifier to output CSVs.
47
+ pose_samples_folder = 'data/yoga_poses_csvs_out'
48
+ # Initialize embedder.
49
+ pose_embedder = FullBodyPoseEmbedding()
50
+ # Initialize classifier.
51
+ # Check that you are using the same parameters as during bootstrapping.
52
+ pose_classifier = PoseClassifier(
53
+ pose_samples_folder=pose_samples_folder,
54
+ pose_embedder=pose_embedder,
55
+ top_n_by_max_distance=30,
56
+ top_n_by_mean_distance=10)
57
+
58
+ # Initialize list of results
59
+ position_list=[]
60
+ frame_list=[]
61
+
62
+ # Initialize EMA smoothing.
63
+ pose_classification_filter = EMADictSmoothing(
64
+ window_size=10,
65
+ alpha=0.2)
66
+
67
+ # Initialize renderer.
68
+ pose_classification_visualizer = PoseClassificationVisualizer(
69
+ class_name=class_name,
70
+ plot_x_max=1000,
71
+ # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
72
+ plot_y_max=10)
73
+
74
+ # Open output video.
75
+ out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
76
+
77
+ # Initialize list of results
78
+ frame_idx = 0
79
+ current_position = {"none":10.0}
80
+
81
+ output_frame = None
82
+ try:
83
+ with tqdm.tqdm(position=0, leave=True) as pbar:
84
+ while True:
85
+ # At every iteration, record the current_position value and the frame_idx
86
+ position_list.append(current_position)
87
+ frame_list.append(frame_idx)
88
+
89
+ # Write both values out to the results file as we go
90
+ with open(results_classification_path_out, 'a') as f:
91
+ f.write(f'{frame_idx};{current_position}\n')
92
+
93
+ success, input_frame = video_cap.read()
94
+ if not success:
95
+ print("Unable to read input video frame, breaking!")
96
+ break
97
+
98
+ # Run pose tracker
99
+ input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
100
+ result = pose_tracker.process(image=input_frame_rgb)
101
+ pose_landmarks = result.pose_landmarks
102
+
103
+ # Prepare the output frame
104
+ output_frame = input_frame.copy()
105
+
106
+ # Add a white banner on top
107
+ banner_height = 180
108
+ output_frame[0:banner_height, :] = (255, 255, 255) # White color
109
+
110
+ # Load the logo image
111
+ logo = cv2.imread("src/logo_impredalam.jpg")
112
+ logo_height, logo_width = logo.shape[:2]
113
+ logo = cv2.resize(
114
+ logo, (logo_width // 3, logo_height // 3)
115
+ ) # Resize to 1/3 scale
116
+
117
+ # Overlay the logo on the upper right corner
118
+ output_frame[0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :] = (
119
+ logo
120
+ )
121
+ if pose_landmarks is not None:
122
+ mp_drawing.draw_landmarks(
123
+ image=output_frame,
124
+ landmark_list=pose_landmarks,
125
+ connections=mp_pose.POSE_CONNECTIONS,
126
+ )
127
+
128
+ # Get landmarks
129
+ frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
130
+ pose_landmarks = np.array(
131
+ [
132
+ [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
133
+ for lmk in pose_landmarks.landmark
134
+ ],
135
+ dtype=np.float32,
136
+ )
137
+ assert pose_landmarks.shape == (
138
+ 33,
139
+ 3,
140
+ ), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
141
+
142
+ # Classify the pose on the current frame
143
+ pose_classification = pose_classifier(pose_landmarks)
144
+
145
+ # Smooth classification using EMA
146
+ pose_classification_filtered = pose_classification_filter(pose_classification)
147
+ current_position=pose_classification_filtered
148
+
149
+ # Count repetitions
150
+ # repetitions_count = repetition_counter(pose_classification_filtered)
151
+
152
+ # Display repetitions count on the frame
153
+ # cv2.putText(
154
+ # output_frame,
155
+ # f"Push-Ups: {repetitions_count}",
156
+ # (10, 30),
157
+ # cv2.FONT_HERSHEY_SIMPLEX,
158
+ # 1,
159
+ # (0, 0, 0),
160
+ # 2,
161
+ # cv2.LINE_AA,
162
+ # )
163
+ # Display classified pose on the frame
164
+ cv2.putText(
165
+ output_frame,
166
+ f"Pose: {current_position}",
167
+ (10, 70),
168
+ cv2.FONT_HERSHEY_SIMPLEX,
169
+ 1.2, # Smaller font size
170
+ (0, 0, 0),
171
+ 1, # Thinner line
172
+ cv2.LINE_AA,
173
+ )
174
+ else:
175
+ # If no landmarks are detected, still display the last count
176
+ # repetitions_count = repetition_counter.n_repeats
177
+ # cv2.putText(
178
+ # output_frame,
179
+ # f"Push-Ups: {repetitions_count}",
180
+ # (10, 30),
181
+ # cv2.FONT_HERSHEY_SIMPLEX,
182
+ # 1,
183
+ # (0, 255, 0),
184
+ # 2,
185
+ # cv2.LINE_AA,
186
+ # )
187
+ current_position={'None':10.0}
188
+ cv2.putText(
189
+ output_frame,
190
+ f"Pose: {current_position}",
191
+ (10, 70),
192
+ cv2.FONT_HERSHEY_SIMPLEX,
193
+ 1.2, # Smaller font size
194
+ (0, 0, 0),
195
+ 1, # Thinner line
196
+ cv2.LINE_AA,
197
+ )
198
+
199
+ cv2.imshow("Yoga position classification", output_frame)
200
+
201
+ key = cv2.waitKey(1) & 0xFF
202
+ if key == ord("q"):
203
+ break
204
+ elif key == ord("r"):
205
+ # repetition_counter.reset()
206
+ print("Counter reset!")
207
+
208
+ frame_idx += 1
209
+ pbar.update()
210
+
211
+ finally:
212
+
213
+ pose_tracker.close()
214
+ video_cap.release()
215
+ cv2.destroyAllWindows()
216
+
217
+ # Case 2: recorded video given by video_path_in
218
+ else:
219
+ assert type(video_path_in)==str, "Error in video path format, not a string. Abort."
220
+ # Open video and get video parameters and check if video is OK
221
+ video_cap = cv2.VideoCapture(video_path_in)
222
+ video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
223
+ video_fps = video_cap.get(cv2.CAP_PROP_FPS)
224
+ video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
225
+ video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
226
+ assert type(video_n_frames)==float, 'Error in input video frames type. Abort.'
227
+ assert video_n_frames>0.0, 'Error in input video frames number : no frame. Abort.'
228
+
229
+ class_name='tree'
230
+
231
+ # Initialize tracker, classifier and current position.
232
+ # Initialize tracker.
233
+ mp_pose = mp.solutions.pose
234
+ pose_tracker = mp_pose.Pose()
235
+ # Folder with pose class CSVs. That should be the same folder you used while
236
+ # building classifier to output CSVs.
237
+ pose_samples_folder = 'data/yoga_poses_csvs_out'
238
+ # Initialize embedder.
239
+ pose_embedder = FullBodyPoseEmbedding()
240
+ # Initialize classifier.
241
+ # Check that you are using the same parameters as during bootstrapping.
242
+ pose_classifier = PoseClassifier(
243
+ pose_samples_folder=pose_samples_folder,
244
+ pose_embedder=pose_embedder,
245
+ top_n_by_max_distance=30,
246
+ top_n_by_mean_distance=10)
247
+
248
+ # Initialize list of results
249
+ position_list=[]
250
+ frame_list=[]
251
+
252
+ # Initialize EMA smoothing.
253
+ pose_classification_filter = EMADictSmoothing(
254
+ window_size=10,
255
+ alpha=0.2)
256
+
257
+ # Initialize renderer.
258
+ pose_classification_visualizer = PoseClassificationVisualizer(
259
+ class_name=class_name,
260
+ plot_x_max=video_n_frames,
261
+ # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
262
+ plot_y_max=10)
263
+
264
+ # Open output video.
265
+ out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
266
+
267
+ # Initialize list of results
268
+ frame_idx = 0
269
+ current_position = {"none":10.0}
270
+
271
+ output_frame = None
272
+ with tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar:
273
+ while True:
274
+ # At every iteration, record the current_position value and the frame_idx
275
+ position_list.append(current_position)
276
+ frame_list.append(frame_idx)
277
+
278
+ # Write both values out to the results file as we go
279
+ with open(results_classification_path_out, 'a') as f:
280
+ f.write(f'{frame_idx};{current_position}\n')
281
+
282
+ # Get next frame of the video.
283
+ success, input_frame = video_cap.read()
284
+ if not success:
285
+ print("unable to read input video frame, breaking!")
286
+ break
287
+
288
+ # Run pose tracker.
289
+ input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
290
+ result = pose_tracker.process(image=input_frame)
291
+ pose_landmarks = result.pose_landmarks
292
+
293
+ # Draw pose prediction.
294
+ output_frame = input_frame.copy()
295
+ if pose_landmarks is not None:
296
+ mp_drawing.draw_landmarks(
297
+ image=output_frame,
298
+ landmark_list=pose_landmarks,
299
+ connections=mp_pose.POSE_CONNECTIONS)
300
+
301
+ if pose_landmarks is not None:
302
+ # Get landmarks.
303
+ frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
304
+ pose_landmarks = np.array([[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
305
+ for lmk in pose_landmarks.landmark], dtype=np.float32)
306
+ assert pose_landmarks.shape == (33, 3), 'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)
307
+
308
+ # Classify the pose on the current frame.
309
+ pose_classification = pose_classifier(pose_landmarks)
310
+
311
+ # Smooth classification using EMA.
312
+ pose_classification_filtered = pose_classification_filter(pose_classification)
313
+
314
+ current_position=pose_classification_filtered
315
+ # Count repetitions.
316
+ # repetitions_count = repetition_counter(pose_classification_filtered)
317
+ else:
318
+ # No pose => no classification on current frame.
319
+ pose_classification = None
320
+
321
+ # Still add an empty classification to the filter to maintain correct
322
+ # smoothing for future frames.
323
+ pose_classification_filtered = pose_classification_filter(dict())
324
+ pose_classification_filtered = None
325
+
326
+ current_position='None'
327
+ # Don't update the counter presuming that person is 'frozen'. Just
328
+ # take the latest repetitions count.
329
+ # repetitions_count = repetition_counter.n_repeats
330
+
331
+ # Draw classification plot and repetition counter.
332
+ output_frame = pose_classification_visualizer(
333
+ frame=output_frame,
334
+ pose_classification=pose_classification,
335
+ pose_classification_filtered=pose_classification_filtered,
336
+ repetitions_count='0'
337
+ )
338
+
339
+ # Save the output frame.
340
+ out_video.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))
341
+
342
+ # Show intermediate frames of the video to track progress.
343
+ if frame_idx % 50 == 0:
344
+ show_image(output_frame)
345
+
346
+ frame_idx += 1
347
+ pbar.update()
348
+
349
+ # Close output video.
350
+ out_video.release()
351
+
352
+ # Release MediaPipe resources.
353
+ pose_tracker.close()
354
+
355
+ # Show the last frame of the video.
356
+ if output_frame is not None:
357
+ show_image(output_frame)
358
+
359
+ video_cap.release()
360
+
361
+
362
+
363
+
364
+ return current_position #string between ['Chair', 'Cobra', 'Dog', 'Goddess', 'Plank', 'Tree', 'Warrior', 'None' = nonfallen, 'Fall']
365
+
366
+ # mp_pose = mp.solutions.pose
367
+ # pose_tracker = mp_pose.Pose()
368
+
369
+ # pose_samples_folder = "data/yoga_poses_csvs_out"
370
+ # class_name = "tree"
371
+
372
+ # pose_embedder = FullBodyPoseEmbedding()
373
+
374
+ # pose_classifier = PoseClassifier(
375
+ # pose_samples_folder=pose_samples_folder,
376
+ # pose_embedder=pose_embedder,
377
+ # top_n_by_max_distance=30,
378
+ # top_n_by_mean_distance=10,
379
+ # )
380
+
381
+ # pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)
382
+
383
+ # repetition_counter = RepetitionCounter(
384
+ # class_name=class_name, enter_threshold=6, exit_threshold=4
385
+ # )
386
+
387
+ # pose_classification_visualizer = PoseClassificationVisualizer(
388
+ # class_name=class_name, plot_x_max=1000, plot_y_max=10
389
+ # )
390
+
391
+ # video_cap = cv2.VideoCapture(0)
392
+ # video_fps = 30
393
+ # video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
394
+ # video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
395
+
396
+ # frame_idx = 0
397
+ # output_frame = None
398
+
399
+ # try:
400
+ # with tqdm.tqdm(position=0, leave=True) as pbar:
401
+ # while True:
402
+ # success, input_frame = video_cap.read()
403
+ # if not success:
404
+ # print("Unable to read input video frame, breaking!")
405
+ # break
406
+
407
+ # # Run pose tracker
408
+ # input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
409
+ # result = pose_tracker.process(image=input_frame_rgb)
410
+ # pose_landmarks = result.pose_landmarks
411
+
412
+ # # Prepare the output frame
413
+ # output_frame = input_frame.copy()
414
+ # if pose_landmarks is not None:
415
+ # mp_drawing.draw_landmarks(
416
+ # image=output_frame,
417
+ # landmark_list=pose_landmarks,
418
+ # connections=mp_pose.POSE_CONNECTIONS,
419
+ # )
420
+
421
+ # # Get landmarks
422
+ # frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
423
+ # pose_landmarks = np.array(
424
+ # [
425
+ # [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
426
+ # for lmk in pose_landmarks.landmark
427
+ # ],
428
+ # dtype=np.float32,
429
+ # )
430
+ # assert pose_landmarks.shape == (
431
+ # 33,
432
+ # 3,
433
+ # ), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
434
+
435
+ # # Classify the pose on the current frame
436
+ # pose_classification = pose_classifier(pose_landmarks)
437
+
438
+ # # Smooth classification using EMA
439
+ # pose_classification_filtered = pose_classification_filter(
440
+ # pose_classification
441
+ # )
442
+
443
+ # # Count repetitions
444
+ # # repetitions_count = repetition_counter(pose_classification_filtered)
445
+
446
+ # # Display repetitions count on the frame
447
+ # # cv2.putText(
448
+ # # output_frame,
449
+ # # f"Push-Ups: {repetitions_count}",
450
+ # # (10, 30),
451
+ # # cv2.FONT_HERSHEY_SIMPLEX,
452
+ # # 1,
453
+ # # (0, 255, 0),
454
+ # # 2,
455
+ # # cv2.LINE_AA,
456
+ # # )
457
+
458
+ # # Display classified pose on the frame
459
+ # cv2.putText(
460
+ # output_frame,
461
+ # f"Pose: {pose_classification}",
462
+ # (10, 70),
463
+ # cv2.FONT_HERSHEY_SIMPLEX,
464
+ # 1,
465
+ # (255, 0, 0),
466
+ # 2,
467
+ # cv2.LINE_AA,
468
+ # )
469
+ # else:
470
+ # # If no landmarks are detected, still display the last count
471
+ # # repetitions_count = repetition_counter.n_repeats
472
+ # # cv2.putText(
473
+ # # output_frame,
474
+ # # f"Push-Ups: {repetitions_count}",
475
+ # # (10, 30),
476
+ # # cv2.FONT_HERSHEY_SIMPLEX,
477
+ # # 1,
478
+ # # (0, 255, 0),
479
+ # # 2,
480
+ # # cv2.LINE_AA,
481
+ # # )
482
+ # # If no landmarks are detected, still display the last classified pose
483
+ # # Display classified pose on the frame
484
+ # cv2.putText(
485
+ # output_frame,
486
+ # f"Pose: {pose_classification}",
487
+ # (10, 70),
488
+ # cv2.FONT_HERSHEY_SIMPLEX,
489
+ # 1,
490
+ # (255, 0, 0),
491
+ # 2,
492
+ # cv2.LINE_AA,
493
+ # )
494
+
495
+ # cv2.imshow("Yoga pose classification", output_frame)
496
+
497
+ # key = cv2.waitKey(1) & 0xFF
498
+ # if key == ord("q"):
499
+ # break
500
+ # elif key == ord("r"):
501
+ # # repetition_counter.reset()
502
+ # print("Counter reset!")
503
+
504
+ # frame_idx += 1
505
+ # pbar.update()
506
+
507
+ # finally:
508
+
509
+ # pose_tracker.close()
510
+ # video_cap.release()
511
+ # cv2.destroyAllWindows()
512
+
513
+
514
+ if __name__ == "__main__":
515
+ main()
yoga_position_gradio.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import cv2
3
+ import numpy as np
4
+ import re
5
+ import os
6
+ from mediapipe.python.solutions import drawing_utils as mp_drawing
7
+ import mediapipe as mp
8
+ from PoseClassification.pose_embedding import FullBodyPoseEmbedding
9
+ from PoseClassification.pose_classifier import PoseClassifier
10
+ from PoseClassification.utils import EMADictSmoothing
11
+ # from PoseClassification.utils import RepetitionCounter
12
+ from PoseClassification.visualize import PoseClassificationVisualizer
13
+ import argparse
14
+ from PoseClassification.utils import show_image
15
+
16
+
17
+ def check_major_current_position(positions_detected:dict, threshold_position) -> str:
18
+ '''
19
+ Return the dominant position among those detected in the frame, or 'none'
20
+
21
+ INPUTS
22
+ positions_detected :
23
+ dict of positions given by position classifier and pose_classification_filtered
24
+ {'pose1':8.0, 'pose2':2.0}
25
+ threshold_position :
26
+ values strictly below are considered "none" position
27
+
28
+ OUTPUT
29
+ major_position :
30
+ string with position (classes from classifier and "none")
31
+
32
+ '''
33
+ if max(positions_detected.values())<float(threshold_position):
34
+ major_position='none'
35
+ else:
36
+ major_position=max(positions_detected, key=positions_detected.get)
37
+ return major_position
38
+
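+ # Example (hypothetical scores): with threshold_position=8.0,
+ # check_major_current_position({'tree': 8.7, 'chair': 1.3}, 8.0) returns 'tree',
+ # while check_major_current_position({'tree': 5.0, 'chair': 5.0}, 8.0) returns 'none'.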
39
+
40
+ def yoga_position_classifier():
41
+ #Load arguments
42
+ parser = argparse.ArgumentParser()
43
+
44
+ parser.add_argument("video_path", help="string video path in")
45
+ args = parser.parse_args()
46
+
47
+ video_path_in = args.video_path
48
+ direct_video=False
49
+
50
+ if video_path_in=="live":
51
+ video_path_in='data/live.mp4'
52
+ direct_video=True
53
+
54
+ video_path_out = re.sub(r'.mp4', r'_classified_video.mp4', video_path_in)
55
+ results_classification_path_out = re.sub(r'.mp4', r'_classified_results.csv', video_path_in)
56
+
57
+
58
+ # Initialize tracker, classifier and current position.
59
+ # Initialize tracker.
60
+ mp_pose = mp.solutions.pose
61
+ pose_tracker = mp_pose.Pose()
62
+ # Folder with pose class CSVs. That should be the same folder you used while
63
+ # building classifier to output CSVs.
64
+ pose_samples_folder = 'data/yoga_poses_csvs_out'
65
+ # Initialize embedder.
66
+ pose_embedder = FullBodyPoseEmbedding()
67
+ # Initialize classifier.
68
+ # Check that you are using the same parameters as during bootstrapping.
69
+ pose_classifier = PoseClassifier(
70
+ pose_samples_folder=pose_samples_folder,
71
+ pose_embedder=pose_embedder,
72
+ top_n_by_max_distance=30,
73
+ top_n_by_mean_distance=10)
74
+
75
+
76
+ # Initialize EMA smoothing.
77
+ pose_classification_filter = EMADictSmoothing(
78
+ window_size=10,
79
+ alpha=0.2)
80
+
81
+
82
+ # Initialize list of results
83
+ position_list=[]
84
+ frame_list=[]
85
+
86
+ # Case 1: live video stream from the webcam
87
+ if direct_video :
88
+ video_cap = cv2.VideoCapture(0)
89
+ # Case 2: video file given by its path
90
+ else :
91
+ assert type(video_path_in)==str, "Error in video path format, not a string. Abort."
92
+ # Open video and get video parameters and check if video is OK
93
+ video_cap = cv2.VideoCapture(video_path_in)
94
+ video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
95
+ assert type(video_n_frames)==float, 'Error in input video frames type. Abort.'
96
+ assert video_n_frames>0.0, 'Error in input video frames number : no frame. Abort.'
97
+
98
+ video_fps = video_cap.get(cv2.CAP_PROP_FPS)
99
+ video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
100
+ video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
101
+
102
+ class_names=['chair', 'cobra', 'dog', 'goddess', 'plank', 'tree', 'warrior', 'none']
103
+ position_threshold = 8.0
104
+
105
+ # Open output video.
106
+ out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))
107
+
108
+ # Initialize results
109
+ frame_idx = 0
110
+ current_position = {"none":10.0}
111
+ output_frame = None
112
+
113
+ position_timer = 0
114
+ previous_position_major = 'none'
115
+
116
+ try:
117
+ with tqdm.tqdm(position=0, leave=True) as pbar:
118
+ while True:
119
+ # Get current time from the beginning of the video
120
+ time_sec = float(frame_idx*(1/video_fps))
121
+
122
+ # Get current major position (str)
123
+ current_position_major = check_major_current_position(current_position, position_threshold)
124
+
125
+ success, input_frame = video_cap.read()
126
+ if not success:
127
+ print("Unable to read input video frame, breaking!")
128
+ break
129
+
130
+ # Run pose tracker
131
+ input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
132
+ result = pose_tracker.process(image=input_frame_rgb)
133
+ pose_landmarks = result.pose_landmarks
134
+
135
+ # Prepare the output frame
136
+ output_frame = input_frame.copy()
137
+
138
+ # Add a white banner on top
139
+ banner_height = int(video_height//10)
140
+ output_frame[0:banner_height, :] = (255, 255, 255) # White color
141
+
142
+ # Load the logo image
143
+ logo = cv2.imread("src/logo_impredalam.jpg")
144
+ logo_height, logo_width = logo.shape[:2]
145
+ logo_height_rescaled = banner_height
146
+ logo_width_rescaled = int((logo_width*logo_height_rescaled)// logo_height )
147
+ logo = cv2.resize(logo, (logo_width_rescaled, logo_height_rescaled)) # Resize to banner scale
148
+
149
+ # Overlay the logo on the upper right corner
150
+ output_frame[0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :] = (logo)
151
+
152
+ # If landmarks are detected
153
+ if pose_landmarks is not None:
154
+ mp_drawing.draw_landmarks(
155
+ image=output_frame,
156
+ landmark_list=pose_landmarks,
157
+ connections=mp_pose.POSE_CONNECTIONS,)
158
+
159
+ # Get landmarks
160
+ frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
161
+ pose_landmarks = np.array(
162
+ [
163
+ [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
164
+ for lmk in pose_landmarks.landmark
165
+ ],
166
+ dtype=np.float32,)
167
+ assert pose_landmarks.shape == (33,3,), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
168
+
169
+ # Classify the pose on the current frame
170
+ pose_classification = pose_classifier(pose_landmarks)
171
+
172
+ # Smooth classification using EMA
173
+ pose_classification_filtered = pose_classification_filter(pose_classification)
174
+ current_position=pose_classification_filtered
175
+ current_position_major=check_major_current_position(current_position, position_threshold)
176
+
177
+ # If no landmarks are detected
178
+ else:
179
+
180
+ current_position={'none':10.0}
181
+ current_position_major=check_major_current_position(current_position, position_threshold)
182
+
183
+
184
+ # If landmarks or no landmarks detected :
185
+
186
+ # Compute position timer according to current and previous position
187
+ if current_position_major==previous_position_major:
188
+ #increase position_timer
189
+ position_timer+=(1/video_fps)
190
+ else:
191
+ previous_position_major=current_position_major
192
+ position_timer=0
193
+
194
+ # Display current position on frame
195
+ cv2.putText(
196
+ output_frame,
197
+ f"Pose: {current_position_major}",
198
+ (int(0+(1//50*video_width)), int(0+banner_height//3)), #coord
199
+ cv2.FONT_HERSHEY_SIMPLEX,
200
+ float(0.9*(video_height/video_width)), # Font size
201
+ (0, 0, 0), #color
202
+ 1, # Thinner line
203
+ cv2.LINE_AA,)
204
+
205
+ # Display current position timer on frame
206
+ cv2.putText(
207
+ output_frame,
208
+ f"Duration: {int(position_timer)} seconds",
209
+ (int(0+(1//50*video_width)), int(0+(2*banner_height)//3)), #coord
210
+ cv2.FONT_HERSHEY_SIMPLEX,
211
+ float(0.9*(video_height/video_width)), # Font size
212
+ (0, 0, 0), #color
213
+ 1, # Thinner line
214
+ cv2.LINE_AA,)
215
+
216
+ # Show output frame
217
+ cv2.imshow("Yoga position", output_frame)
218
+
219
+ # Add current_position (dict) and frame index to list (output file for debug)
220
+ position_list.append(current_position)
221
+ frame_list.append(frame_idx)
222
+ # Output file for debug
223
+ with open(results_classification_path_out, 'a') as f:
224
+ f.write(f'{frame_idx},{current_position}\n')
225
+
226
+ key = cv2.waitKey(1) & 0xFF
227
+ if key == ord("q"):
228
+ break
229
+ elif key == ord("r"):
230
+ current_position = {'none':10.0}
231
+ print("Position reset !")
232
+
233
+ frame_idx += 1
234
+ pbar.update()
235
+
236
+ finally:
237
+ pose_tracker.close()
238
+ video_cap.release()
239
+ cv2.destroyAllWindows()
240
+ # Close output video.
241
+ out_video.release()
242
+
243
+ return frame_list, position_list
244
+
245
+
246
+
247
+ if __name__ == "__main__":
248
+ yoga_position_classifier()