ICON / lib /common /cloth_extraction.py
Yuliang's picture
done
2d5f249
raw history blame
No virus
6.25 kB
import numpy as np
import json
import os
import itertools
import trimesh
from matplotlib.path import Path
from collections import Counter
from sklearn.neighbors import KNeighborsClassifier
def load_segmentation(path, shape):
    """
    Load the garment polygons stored in a segmentation json file
    Arguments:
        path: path to the segmentation json file
        shape: not used by this function; kept so existing callers keep working
    Returns:
        Returns a list with one dictionary per top-level entry whose key starts with 'item'.
        Each dictionary holds:
            'type': the entry's 'category_name'
            'type_id': the entry's 'category_id'
            'coordinates': list of (n, 2) arrays, one [x, y] row per polygon point
    """
    with open(path) as json_file:
        annotation = json.load(json_file)

    segmentations = []
    for key, val in annotation.items():
        # Only the 'item*' entries describe garments; everything else is skipped
        if not key.startswith('item'):
            continue

        # Each item can have multiple polygons.
        # A polygon is stored flat as [x1, y1, x2, y2, ...]; split it into [x, y] rows
        coordinates = [
            np.vstack((polygon[::2], polygon[1::2])).T
            for polygon in val['segmentation']
        ]

        segmentations.append({
            'type': val['category_name'],
            'type_id': val['category_id'],
            'coordinates': coordinates,
        })

    return segmentations
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours to use
    Returns:
        Returns a dictionary containing the bodypart and the corresponding recon vertex indices
    """
    # The bodypart -> smpl vertex indices table is stored next to this module
    seg_path = os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')
    with open(seg_path) as seg_file:
        smpl_vert_segmentation = json.load(seg_file)

    # Build one label per smpl vertex.
    # NOTE(review): vertices that appear in no bodypart list keep the label None
    # - presumably the json covers every smpl vertex; confirm
    n = smpl.vertices.shape[0]
    y = np.full(n, None, dtype=object)
    for key, val in smpl_vert_segmentation.items():
        y[val] = key

    # Every recon vertex gets the label voted by its k closest smpl vertices
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)
    y_pred = classifier.predict(recon.vertices)

    return {
        key: list(np.flatnonzero(y_pred == key).astype(int))
        for key in smpl_vert_segmentation
    }
def _body_parts_to_exclude(type_id):
    """
    Get the bodyparts that should be cut away from a garment of the given category
    Arguments:
        type_id: category id of the garment (https://github.com/switchablenorms/DeepFashion2)
    Returns:
        Returns a list of bodypart names (keys of the smpl vertex segmentation)
    """
    # Bodyparts that are most likely not part of the segmentation but might intersect
    # (e.g. hand in front of torso)
    parts = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
             'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand']

    # Short sleeve clothes
    if type_id in (1, 3, 10):
        parts += ['leftForeArm', 'rightForeArm']
    # No sleeves at all or lower body clothes
    elif type_id in (5, 6, 8, 9, 12, 13):
        parts += ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    # Shorts
    elif type_id == 7:
        parts += ['leftLeg', 'rightLeg',
                  'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    return parts


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2d segmentation coordinates
    Arguments:
        recon: fully clothed mesh
        segmentation: dictionary with the keys
            'coord_normalized': list of polygons, each an (n, 2) sequence of 2D points
                (NDC according to the original documentation)
            'type_id': category id of the garment
        K: intrinsic matrix of the projection (only K[:3, :3] is used)
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: smpl mesh used to label the bodyparts of recon
    Returns:
        Returns a submesh using the segmentation coordinates.
        Returns None if smpl is None or if no face is left after the filtering
    """
    seg_coord = segmentation['coord_normalized']
    mesh = trimesh.Trimesh(recon.vertices, recon.faces)

    # Build the 3x4 projection matrix and its pseudo inverse
    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic
    P_inv = np.linalg.pinv(P)

    # Each segmentation can contain multiple polygons
    # We need to check them separately and collect every vertex hit by any of them
    num_verts = recon.vertices.shape[0]
    seg_mask = np.zeros(num_verts, dtype=bool)
    for polygon in seg_coord:
        n = len(polygon)
        coords_h = np.hstack((polygon, np.ones((n, 1))))
        # Apply the inverse projection on homogeneous 2D coordinates to get the corresponding 3d coordinates
        XYZ = (P_inv @ coords_h[:, :, None])[:, :, 0]
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]

        # The containment test only looks at x and y, the depth of a vertex is ignored
        outline = Path(XYZ[:, :2])
        seg_mask |= outline.contains_points(recon.vertices[:, :2])

    # Without the smpl body there are no bodypart labels to filter with, so nothing is extracted
    if smpl is None:
        return None

    recon_labels = smpl_to_recon_labels(recon, smpl)
    label_mask = np.zeros(num_verts, dtype=bool)
    for part in _body_parts_to_exclude(segmentation['type_id']):
        label_mask[recon_labels[part]] = True

    # Remove points that belong to other bodyparts:
    # a vertex is kept if it lies inside a polygon and is not labelled as a bodypart to remove
    keep_verts = seg_mask & ~label_mask
    all_indices = np.flatnonzero(keep_verts)

    # Keep every face that references at least one kept vertex
    face_mask = np.isin(recon.faces, all_indices).any(axis=1)
    if not face_mask.any():
        return None

    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()
    return mesh