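# NOTE: the header below is repository-page text that was captured together
# with this file. It is not part of the program and is kept only as a comment.
#   YOGAI / PoseClassification / pose_classifier.py
#   1mpreccable's picture
#   Upload 35 files
#   0ccc9b6 verified
#   raw
#   history blame
#   7.27 kB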
import numpy as np
import os, csv
class PoseSample(object):
    """One labelled pose example held in the classifier's database.

    Attributes:
      name: Identifier of the sample (the first column of its CSV row).
      landmarks: Landmark coordinates of the sample.
      class_name: Name of the pose class the sample belongs to.
      embedding: Embedding computed from `landmarks`.
    """

    def __init__(self, name, landmarks, class_name, embedding):
        self.name = name
        self.landmarks = landmarks
        self.class_name = class_name
        self.embedding = embedding

    def __repr__(self):
        # Landmarks/embedding are omitted on purpose: they are arrays and would
        # drown the identifying fields when a sample is printed while debugging.
        return "{}(name={!r}, class_name={!r})".format(
            type(self).__name__, self.name, self.class_name
        )
class PoseSampleOutlier(object):
    """A database sample whose own classification disagrees with its label.

    Attributes:
      sample: The offending sample.
      detected_class: Class name(s) that received the highest vote count.
      all_classes: Full classification result (class name -> vote count).
    """

    def __init__(self, sample, detected_class, all_classes):
        self.sample = sample
        self.detected_class = detected_class
        self.all_classes = all_classes

    def __repr__(self):
        return "{}(sample={!r}, detected_class={!r}, all_classes={!r})".format(
            type(self).__name__, self.sample, self.detected_class, self.all_classes
        )
class PoseClassifier(object):
    """Classifies pose landmarks by voting among the nearest stored samples.

    The database is a list of samples loaded from per-class CSV files. A query
    pose is embedded with `pose_embedder`, compared against every stored
    embedding, and the class names of the closest samples are counted.

    Args (constructor):
      pose_samples_folder: Folder containing one CSV file per pose class.
      pose_embedder: Callable turning a landmarks array into an embedding
        array. It is applied to every loaded sample and to every query.
      file_extension: Extension of the sample files, with or without the
        leading dot.
      file_separator: Column delimiter used in the sample files.
      n_landmarks: Number of landmarks per pose.
      n_dimensions: Number of coordinates per landmark.
      top_n_by_max_distance: How many samples survive the first (max
        distance) filtering stage.
      top_n_by_mean_distance: How many samples survive the second (mean
        distance) stage and therefore vote.
      axes_weights: Weights multiplied element-wise into the absolute
        embedding difference; must broadcast against the embedding.
    """

    def __init__(
        self,
        pose_samples_folder,
        pose_embedder,
        file_extension="csv",
        file_separator=",",
        n_landmarks=33,
        n_dimensions=3,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10,
        axes_weights=(1.0, 1.0, 0.2),
    ):
        self._pose_embedder = pose_embedder
        self._n_landmarks = n_landmarks
        self._n_dimensions = n_dimensions
        self._top_n_by_max_distance = top_n_by_max_distance
        self._top_n_by_mean_distance = top_n_by_mean_distance
        self._axes_weights = axes_weights

        self._pose_samples = self._load_pose_samples(
            pose_samples_folder,
            file_extension,
            file_separator,
            n_landmarks,
            n_dimensions,
            pose_embedder,
        )

    def _load_pose_samples(
        self,
        pose_samples_folder,
        file_extension,
        file_separator,
        n_landmarks,
        n_dimensions,
        pose_embedder,
    ):
        """Loads pose samples from a given folder.

        Required folder structure:
          neutral_standing.csv
          pushups_down.csv
          pushups_up.csv
          squats_down.csv
          ...

        Required CSV structure:
          sample_00001,x1,y1,z1,x2,y2,z2,....
          sample_00002,x1,y1,z1,x2,y2,z2,....
          ...

        Blank rows are skipped.

        Returns:
          List of `PoseSample`, one per non-empty CSV row.

        Raises:
          AssertionError: If a row does not hold exactly
            `n_landmarks * n_dimensions + 1` values.
        """
        # Match on ".<ext>" rather than the bare extension so that the class
        # name (file name minus suffix) is always cut at the dot. Stripping
        # leading dots lets callers pass either "csv" or ".csv".
        suffix = "." + file_extension.lstrip(".")

        # Each file in the folder represents one pose class. os.listdir() has
        # no defined order, so sort to make sample indices (and therefore
        # tie-breaking between equally distant samples) reproducible.
        file_names = sorted(
            name for name in os.listdir(pose_samples_folder) if name.endswith(suffix)
        )

        expected_row_length = n_landmarks * n_dimensions + 1
        pose_samples = []
        for file_name in file_names:
            # Use file name as pose class name.
            class_name = file_name[: -len(suffix)]
            file_path = os.path.join(pose_samples_folder, file_name)

            # newline="" is what the csv module asks for so that it can handle
            # line endings itself.
            with open(file_path, newline="") as csv_file:
                for row in csv.reader(csv_file, delimiter=file_separator):
                    if not row:
                        # csv.reader yields [] for blank lines (e.g. a trailing
                        # newline); they carry no sample.
                        continue
                    # Raised explicitly (not `assert`) so the check survives -O.
                    if len(row) != expected_row_length:
                        raise AssertionError(
                            "Wrong number of values: {}".format(len(row))
                        )
                    landmarks = np.array(row[1:], np.float32).reshape(
                        [n_landmarks, n_dimensions]
                    )
                    pose_samples.append(
                        PoseSample(
                            name=row[0],
                            landmarks=landmarks,
                            class_name=class_name,
                            embedding=pose_embedder(landmarks),
                        )
                    )

        return pose_samples

    def find_pose_sample_outliers(self):
        """Classifies each sample against the entire database.

        Returns:
          List of `PoseSampleOutlier` for every stored sample whose top-voted
          class is not exactly its own class.
        """
        outliers = []
        for sample in self._pose_samples:
            # Classify a copy so the stored landmarks can never be altered.
            pose_classification = self(sample.landmarks.copy())

            # Computed once per sample instead of once per dictionary entry.
            # An empty classification (possible when a top-N limit is 0)
            # yields no winning class, which marks the sample as an outlier.
            best_count = (
                max(pose_classification.values()) if pose_classification else 0
            )
            class_names = [
                class_name
                for class_name, count in pose_classification.items()
                if count == best_count
            ]

            # Sample is an outlier if nearest poses have different class or
            # more than one pose class is detected as nearest.
            if sample.class_name not in class_names or len(class_names) != 1:
                outliers.append(
                    PoseSampleOutlier(sample, class_names, pose_classification)
                )

        return outliers

    @staticmethod
    def _mirror_invariant_distance(
        reduce_fn, sample_embedding, pose_embedding, flipped_pose_embedding, weights
    ):
        """Returns the smaller of the weighted distances to a pose and its mirror.

        `reduce_fn` (e.g. `np.max` or `np.mean`) collapses the weighted
        absolute difference between embeddings into a single number.
        """
        return min(
            reduce_fn(np.abs(sample_embedding - pose_embedding) * weights),
            reduce_fn(np.abs(sample_embedding - flipped_pose_embedding) * weights),
        )

    def __call__(self, pose_landmarks):
        """Classifies given pose.

        Classification is done in two stages:
          * First we pick top-N samples by MAX distance. It allows to remove
            samples that are almost the same as the given pose, but have a few
            joints bent in the other direction.
          * Then we pick top-N samples by MEAN distance. After outliers are
            removed on the previous step, we can pick samples that are closest
            on average.

        Every distance is taken as the minimum over the pose and a copy of the
        pose whose first coordinate is negated.

        Args:
          pose_landmarks: NumPy array of landmarks with shape
            (n_landmarks, n_dimensions).

        Returns:
          Dictionary with count of nearest pose samples from the database.
          Sample:
            {
              'pushups_down': 8,
              'pushups_up': 2,
            }

        Raises:
          AssertionError: If `pose_landmarks` does not have the expected shape.
        """
        # Check that provided and target poses have the same shape. Raised
        # explicitly (not `assert`) so the check survives -O.
        expected_shape = (self._n_landmarks, self._n_dimensions)
        if pose_landmarks.shape != expected_shape:
            raise AssertionError(
                "Unexpected shape: {}".format(pose_landmarks.shape)
            )

        # Mirror vector: -1 for the first coordinate, +1 for the rest. Built
        # from n_dimensions (instead of a fixed 3-vector) so that any
        # configured dimensionality broadcasts; integer dtype keeps the same
        # type promotion as an integer literal vector.
        mirror = np.ones(self._n_dimensions, dtype=int)
        mirror[:1] = -1

        # Get given pose embedding, plus the embedding of its mirror image.
        pose_embedding = self._pose_embedder(pose_landmarks)
        flipped_pose_embedding = self._pose_embedder(pose_landmarks * mirror)

        # Convert the weights once rather than on every multiplication.
        weights = np.asarray(self._axes_weights)

        # Filter by max distance.
        #
        # That helps to remove outliers - poses that are almost the same as
        # the given one, but have one joint bent into another direction and
        # actually represent a different pose class.
        by_max_distance = []
        for sample_idx, sample in enumerate(self._pose_samples):
            max_dist = self._mirror_invariant_distance(
                np.max,
                sample.embedding,
                pose_embedding,
                flipped_pose_embedding,
                weights,
            )
            by_max_distance.append((max_dist, sample_idx))
        # sorted() is stable, so equally distant samples keep database order.
        by_max_distance = sorted(by_max_distance, key=lambda pair: pair[0])
        by_max_distance = by_max_distance[: self._top_n_by_max_distance]

        # Filter by mean distance.
        #
        # After removing outliers we can find the nearest pose by mean
        # distance.
        by_mean_distance = []
        for _, sample_idx in by_max_distance:
            mean_dist = self._mirror_invariant_distance(
                np.mean,
                self._pose_samples[sample_idx].embedding,
                pose_embedding,
                flipped_pose_embedding,
                weights,
            )
            by_mean_distance.append((mean_dist, sample_idx))
        by_mean_distance = sorted(by_mean_distance, key=lambda pair: pair[0])
        by_mean_distance = by_mean_distance[: self._top_n_by_mean_distance]

        # Collect results into map: (class_name -> n_samples), in
        # nearest-first order of first appearance.
        result = {}
        for _, sample_idx in by_mean_distance:
            class_name = self._pose_samples[sample_idx].class_name
            result[class_name] = result.get(class_name, 0) + 1

        return result