Spaces:
Sleeping
Sleeping
import cv2 | |
from matplotlib import pyplot as plt | |
import numpy as np | |
import os, csv | |
from PIL import Image, ImageDraw | |
import sys | |
import tqdm | |
from mediapipe.python.solutions import drawing_utils as mp_drawing | |
from mediapipe.python.solutions import pose as mp_pose | |
from PoseClassification.utils import show_image | |
class BootstrapHelper(object):
    """Helps to bootstrap images and filter pose samples for classification.

    The helper works on three folders:
      * an input folder with one sub-folder of images per pose class,
      * an output folder that receives the same images with the detected
        pose drawn on top,
      * a CSV folder that receives one `<pose_class>.csv` per pose class.
    """

    def __init__(self, images_in_folder, images_out_folder, csvs_out_folder):
        """Stores the folder locations and discovers the pose classes.

        Args:
          images_in_folder: Folder with one sub-folder of images per pose class.
          images_out_folder: Folder where annotated images are written.
          csvs_out_folder: Folder where per-class landmark CSVs are written.
        """
        self._images_in_folder = images_in_folder
        self._images_out_folder = images_out_folder
        self._csvs_out_folder = csvs_out_folder

        # Every visible sub-folder of the input folder is one pose class.
        # Plain files (READMEs, archives, ...) are ignored: they cannot hold
        # images and would make `os.listdir` fail later on.
        self._pose_class_names = [
            n
            for n in self._list_visible(self._images_in_folder)
            if os.path.isdir(os.path.join(self._images_in_folder, n))
        ]

    @staticmethod
    def _list_visible(folder):
        """Returns the sorted entry names of `folder` that don't start with '.'."""
        return sorted(n for n in os.listdir(folder) if not n.startswith("."))

    def bootstrap(self, per_pose_class_limit=None):
        """Bootstraps images in a given folder.

        Required image in folder (same use for image out folder):
          pushups_up/
            image_001.jpg
            image_002.jpg
            ...
          pushups_down/
            image_001.jpg
            image_002.jpg
            ...
          ...

        Produced CSVs out folder:
          pushups_up.csv
          pushups_down.csv

        Produced CSV structure with pose 3D landmarks:
          sample_00001,x1,y1,z1,x2,y2,z2,....
          sample_00002,x1,y1,z1,x2,y2,z2,....

        Images that cannot be read are skipped (a message is printed to
        stderr). Images without a detected pose are still copied to the output
        folder but get no CSV row.

        Args:
          per_pose_class_limit: If not None, only the first N images (in sorted
            name order) of every pose class are processed.
        """
        # Create output folder for CSVs.
        os.makedirs(self._csvs_out_folder, exist_ok=True)

        for pose_class_name in self._pose_class_names:
            print("Bootstrapping ", pose_class_name, file=sys.stderr)

            # Paths for the pose class.
            images_in_folder = os.path.join(self._images_in_folder, pose_class_name)
            images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
            csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + ".csv")
            os.makedirs(images_out_folder, exist_ok=True)

            # Get list of images.
            image_names = self._list_visible(images_in_folder)
            if per_pose_class_limit is not None:
                image_names = image_names[:per_pose_class_limit]

            # newline="" lets the csv module control line endings; without it
            # Windows produces "\r\r\n" which reads back as blank rows.
            with open(csv_out_path, "w", newline="") as csv_out_file:
                csv_out_writer = csv.writer(
                    csv_out_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
                )
                # Bootstrap every image.
                for image_name in tqdm.tqdm(image_names):
                    row = self._bootstrap_image(
                        images_in_folder, images_out_folder, image_name
                    )
                    if row is not None:
                        csv_out_writer.writerow(row)

    def _bootstrap_image(self, images_in_folder, images_out_folder, image_name):
        """Runs pose detection on one image and saves the annotated copy.

        Args:
          images_in_folder: Folder the image is read from.
          images_out_folder: Folder the annotated image is written to.
          image_name: File name of the image inside both folders.

        Returns:
          The CSV row `[image_name, x1, y1, z1, ...]` (values as strings) when
          a pose was detected, otherwise None.
        """
        # Load image. `cv2.imread` signals failure by returning None.
        image_path = os.path.join(images_in_folder, image_name)
        input_frame = cv2.imread(image_path)
        if input_frame is None:
            print("Skipping unreadable image: ", image_path, file=sys.stderr)
            return None
        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)

        # Initialize fresh pose tracker and run it.
        with mp_pose.Pose() as pose_tracker:
            result = pose_tracker.process(image=input_frame)
            pose_landmarks = result.pose_landmarks

        # Save image with pose prediction (if pose was detected).
        output_frame = input_frame.copy()
        if pose_landmarks is not None:
            mp_drawing.draw_landmarks(
                image=output_frame,
                landmark_list=pose_landmarks,
                connections=mp_pose.POSE_CONNECTIONS,
            )
        output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(images_out_folder, image_name), output_frame)

        if pose_landmarks is None:
            return None

        # Scale landmark coordinates by the frame size: x and z by the frame
        # width, y by the frame height.
        frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
        landmarks = np.array(
            [
                [
                    lmk.x * frame_width,
                    lmk.y * frame_height,
                    lmk.z * frame_width,
                ]
                for lmk in pose_landmarks.landmark
            ],
            dtype=np.float32,
        )
        assert landmarks.shape == (
            33,
            3,
        ), "Unexpected landmarks shape: {}".format(landmarks.shape)
        return [image_name] + landmarks.flatten().astype(str).tolist()

    def _draw_xz_projection(self, output_frame, pose_landmarks, r=0.5, color="red"):
        """Draws the pose connections projected onto the XZ plane.

        Args:
          output_frame: Image array; only its height and width are used, to
            size the returned canvas.
          pose_landmarks: Indexable collection of (x, y, z) landmark triples,
            or None to get an empty canvas.
          r: Point radius, as a percentage of the frame width.
          color: Fill color for points and lines.

        Returns:
          A white RGB image as a NumPy array with the same height and width as
          `output_frame`, with the projection drawn on it.
        """
        frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
        img = Image.new("RGB", (frame_width, frame_height), color="white")

        if pose_landmarks is None:
            return np.asarray(img)

        # Scale radius according to the image width.
        r *= frame_width * 0.01

        draw = ImageDraw.Draw(img)
        for idx_1, idx_2 in mp_pose.POSE_CONNECTIONS:
            # Flip Z and move hips center to the center of the image.
            x1, y1, z1 = pose_landmarks[idx_1] * [1, 1, -1] + [0, 0, frame_height * 0.5]
            x2, y2, z2 = pose_landmarks[idx_2] * [1, 1, -1] + [0, 0, frame_height * 0.5]

            draw.ellipse([x1 - r, z1 - r, x1 + r, z1 + r], fill=color)
            draw.ellipse([x2 - r, z2 - r, x2 + r, z2 + r], fill=color)
            draw.line([x1, z1, x2, z2], width=int(r), fill=color)

        return np.asarray(img)

    def align_images_and_csvs(self, print_removed_items=False):
        """Makes sure that image folders and CSVs have the same samples.

        Leaves only the intersection of samples present in both the output
        image folders and the CSVs: CSV rows without an image are dropped and
        images without a CSV row are deleted from disk.

        Args:
          print_removed_items: If True, prints every removed CSV row / image.
        """
        for pose_class_name in self._pose_class_names:
            # Paths for the pose class.
            images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
            csv_out_path = os.path.join(self._csvs_out_folder, pose_class_name + ".csv")

            # Read CSV into memory. Blank lines parse as empty rows; drop them
            # so that `row[0]` below is always valid.
            with open(csv_out_path, newline="") as csv_out_file:
                rows = [
                    row for row in csv.reader(csv_out_file, delimiter=",") if row
                ]

            # Image names left in CSV (a set keeps the lookup below cheap).
            image_names_in_csv = set()

            # Re-write the CSV removing lines without corresponding images.
            with open(csv_out_path, "w", newline="") as csv_out_file:
                csv_out_writer = csv.writer(
                    csv_out_file, delimiter=",", quoting=csv.QUOTE_MINIMAL
                )
                for row in rows:
                    image_name = row[0]
                    image_path = os.path.join(images_out_folder, image_name)
                    if os.path.exists(image_path):
                        image_names_in_csv.add(image_name)
                        csv_out_writer.writerow(row)
                    elif print_removed_items:
                        print("Removed image from CSV: ", image_path)

            # Remove images without corresponding line in CSV.
            for image_name in os.listdir(images_out_folder):
                if image_name not in image_names_in_csv:
                    image_path = os.path.join(images_out_folder, image_name)
                    os.remove(image_path)
                    if print_removed_items:
                        print("Removed image from folder: ", image_path)

    def analyze_outliers(self, outliers):
        """Prints details of every given outlier and shows its output image.

        If a sample is classified differently than its original class, it
        should either be deleted or more similar samples should be added.

        Args:
          outliers: Iterable of objects exposing `sample.class_name`,
            `sample.name`, `detected_class` and `all_classes`.
        """
        for outlier in outliers:
            image_path = os.path.join(
                self._images_out_folder, outlier.sample.class_name, outlier.sample.name
            )

            print("Outlier")
            print(" sample path = ", image_path)
            print(" sample class = ", outlier.sample.class_name)
            print(" detected class = ", outlier.detected_class)
            print(" all classes = ", outlier.all_classes)

            # `cv2.imread` returns None when the file is missing or unreadable;
            # report it instead of crashing in the color conversion.
            img = cv2.imread(image_path)
            if img is None:
                print("Could not read image: ", image_path, file=sys.stderr)
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            show_image(img, figsize=(20, 20))

    def remove_outliers(self, outliers):
        """Removes outliers from the image folders.

        Args:
          outliers: Iterable of objects exposing `sample.class_name` and
            `sample.name`; the matching file in the output image folder is
            deleted.
        """
        for outlier in outliers:
            image_path = os.path.join(
                self._images_out_folder, outlier.sample.class_name, outlier.sample.name
            )
            os.remove(image_path)

    def print_images_in_statistics(self):
        """Prints statistics from the input image folder."""
        self._print_images_statistics(self._images_in_folder, self._pose_class_names)

    def print_images_out_statistics(self):
        """Prints statistics from the output image folder."""
        self._print_images_statistics(self._images_out_folder, self._pose_class_names)

    def _print_images_statistics(self, images_folder, pose_class_names):
        """Prints the number of non-hidden entries in every pose class folder.

        Args:
          images_folder: Folder containing one sub-folder per pose class.
          pose_class_names: Names of the sub-folders to count.
        """
        print("Number of images per pose class:")
        for pose_class_name in pose_class_names:
            n_images = len(
                self._list_visible(os.path.join(images_folder, pose_class_name))
            )
            print(" {}: {}".format(pose_class_name, n_images))