Spaces:
Configuration error
Configuration error
""" | |
Copyright (c) Microsoft Corporation. | |
Licensed under the MIT license. | |
""" | |
import cv2 | |
import math | |
import json | |
from PIL import Image | |
import os.path as op | |
import numpy as np | |
import code | |
from custom_mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile | |
from custom_mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml | |
from custom_mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa | |
import torch | |
import torchvision.transforms as transforms | |
class HandMeshTSVDataset(object):
    """Dataset of hand images and mesh annotations stored in TSV files.

    Each item is read from up to three row-aligned TSV files (image, label,
    height/width), optionally indirected through a line list, and returned
    as ``(img_key, normalized_image_tensor, meta_data)``.

    NOTE: ``__init__`` reads ``self.is_composite`` and ``self.root`` but does
    not set them; a subclass must assign both before calling this
    constructor.
    """

    def __init__(self, args, img_file, label_file=None, hw_file=None,
                 linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
        """Open the TSV files and set up augmentation constants.

        Args:
            args: options object; ``multiscale_inference`` (and ``rot`` /
                ``sc`` when it is enabled) are read by ``augm_params``.
            img_file: TSV holding the base64-encoded images.
            label_file: optional TSV holding JSON annotations.
            hw_file: optional TSV holding image height/width.
            linelist_file: optional file listing the rows to use.
            is_train: enable random augmentation and masking.
            cv2_output: if True, ``get_image`` returns the decoded image as
                float32 without the BGR->RGB conversion.
            scale_factor: accepted for interface compatibility; currently
                unused (see the note on ``self.scale_factor`` below).
        """
        self.args = args
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = self.get_tsv_file(img_file)
        self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
        self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)

        if self.is_composite:
            assert op.isfile(self.linelist_file)
            # Composite files already apply the line list internally, so the
            # dataset-level list is simply the identity mapping.
            self.line_list = list(range(self.hw_tsv.num_rows()))
        else:
            self.line_list = load_linelist_file(linelist_file)

        self.cv2_output = cv2_output
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
        self.is_train = is_train
        # NOTE(review): the ``scale_factor`` argument is ignored and 0.25 is
        # always used. Kept as-is because honouring the argument (default 1)
        # would change training augmentation -- confirm intent.
        # Bounding boxes are rescaled by a factor in
        # [1 - scale_factor, 1 + scale_factor].
        self.scale_factor = 0.25
        self.noise_factor = 0.4
        self.rot_factor = 90  # rotation is drawn from N(0, rot_factor) and clamped to +/- 2*rot_factor
        self.img_res = 224
        self.image_keys = self.prepare_image_keys()
        self.joints_definition = (
            'Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4',
            'Index_1', 'Index_2', 'Index_3', 'Index_4',
            'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4',
            'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4',
            'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
        self.root_index = self.joints_definition.index('Wrist')

    def get_tsv_file(self, tsv_file):
        """Return a TSV reader for ``tsv_file``, or None if it is falsy."""
        if not tsv_file:
            return None
        if self.is_composite:
            return CompositeTSVFile(tsv_file, self.linelist_file,
                                    root=self.root)
        tsv_path = find_file_path_in_yaml(tsv_file, self.root)
        return TSVFile(tsv_path)

    def get_valid_tsv(self):
        """Return the TSV to read keys from.

        Preference order is hw, label, then image (per the original comment,
        sorted by file size). The image TSV fallback mirrors
        ``get_img_key`` so that key preparation does not fail when only
        images are available.
        """
        if self.hw_tsv:
            return self.hw_tsv
        if self.label_tsv:
            return self.label_tsv
        return self.img_tsv

    def prepare_image_keys(self):
        """Return the list of keys of every row of the valid TSV."""
        tsv = self.get_valid_tsv()
        return [tsv.get_key(i) for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        """Return a dict mapping each key of the valid TSV to its row index."""
        tsv = self.get_valid_tsv()
        return {tsv.get_key(i): i for i in range(tsv.num_rows())}

    def augm_params(self):
        """Get augmentation parameters.

        Returns:
            Tuple ``(flip, pn, rot, sc)``: flip flag (always 0), per-channel
            pixel-noise multipliers, rotation in degrees and scale
            multiplier.
        """
        flip = 0  # flipping is never applied
        pn = np.ones(3)  # per channel pixel-noise

        # Truthiness test (instead of ``== True`` / ``== False``) so that a
        # value such as None cannot leave ``rot`` / ``sc`` unassigned.
        if self.args.multiscale_inference:
            rot = self.args.rot
            sc = self.args.sc
        else:
            rot = 0
            sc = 1.0

        if self.is_train:
            # Each channel is multiplied with a number
            # in the area [1-noise_factor, 1+noise_factor]
            pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)

            # The rotation is a number in the area [-2*rot_factor, 2*rot_factor]
            rot = min(2 * self.rot_factor,
                      max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))

            # The scale is multiplied with a number
            # in the area [1-scale_factor, 1+scale_factor]
            sc = min(1 + self.scale_factor,
                     max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))

            # but the rotation is zero with probability 3/5
            if np.random.uniform() <= 0.6:
                rot = 0

        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Crop, optionally flip, apply pixel noise and convert to CHW in [0, 1]."""
        rgb_img = crop(rgb_img, center, scale,
                       [self.img_res, self.img_res], rot=rot)
        # flip the image
        if flip:
            rgb_img = flip_img(rgb_img)
        # Channel-wise pixel noise, clamped to [0, 255]. Written back in
        # place so the array keeps whatever dtype ``crop`` produced.
        for channel in range(3):
            rgb_img[:, :, channel] = np.clip(rgb_img[:, :, channel] * pn[channel], 0.0, 255.0)
        # (3, img_res, img_res), float32, [0, 1]
        rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
        return rgb_img

    def j2d_processing(self, kp, center, scale, r, f):
        """Process gt 2D keypoints and apply all augmentation transforms.

        ``kp`` is modified in place (callers pass a copy); the last column
        is left untouched.
        """
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [self.img_res, self.img_res], rot=r)
        # convert to normalized coordinates in [-1, 1]
        kp[:, :-1] = 2. * kp[:, :-1] / self.img_res - 1.
        # flip the x coordinates
        if f:
            kp = flip_kp(kp)
        kp = kp.astype('float32')
        return kp

    def j3d_processing(self, S, r, f):
        """Process gt 3D keypoints and apply all augmentation transforms.

        ``S`` is modified in place (callers pass a copy); the last column
        is left untouched.
        """
        # in-plane rotation
        rot_mat = np.eye(3)
        if r != 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
        # Apply rot_mat to every row: out[k] = rot_mat @ S[k, :-1]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        # flip the x coordinates
        if f:
            S = flip_kp(S)
        S = S.astype('float32')
        return S

    def pose_processing(self, pose, r, f):
        """Process pose parameters and apply all augmentation transforms.

        The first three entries (presumably the global rotation in
        axis-angle form -- confirm against ``rot_aa``) are rotated by ``r``.
        The input array is not modified.
        """
        # astype copies, so the caller's array is left untouched
        pose = pose.astype('float32')
        pose[:3] = rot_aa(pose[:3], r)
        # flip the pose parameters
        if f:
            # re-cast in case flip_pose changes the dtype
            pose = flip_pose(pose).astype('float32')
        return pose

    def get_line_no(self, idx):
        """Map a dataset index to a TSV row number through the line list."""
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        """Decode and return the image for ``idx``.

        Returns the decoded array as float32 when ``cv2_output`` is set,
        otherwise the array after a BGR->RGB conversion.
        """
        line_no = self.get_line_no(idx)
        row = self.img_tsv[line_no]
        # use -1 to support old format with multiple columns.
        cv2_im = img_from_base64(row[-1])
        if self.cv2_output:
            return cv2_im.astype(np.float32, copy=True)
        cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
        return cv2_im

    def get_annotations(self, idx):
        """Return the decoded JSON annotations for ``idx`` ([] if no label TSV)."""
        if self.label_tsv is None:
            return []
        line_no = self.get_line_no(idx)
        row = self.label_tsv[line_no]
        return json.loads(row[1])

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def get_img_info(self, idx):
        """Return ``{"height": ..., "width": ...}`` for ``idx``.

        Returns None when no hw TSV is available.
        """
        if self.hw_tsv is None:
            return None
        line_no = self.get_line_no(idx)
        row = self.hw_tsv[line_no]
        try:
            # json string format with "height" and "width" being the keys
            return json.loads(row[1])[0]
        except ValueError:
            # list of strings representing height and width in order
            hw_str = row[1].split(' ')
            return {"height": int(hw_str[0]), "width": int(hw_str[1])}

    def get_img_key(self, idx):
        """Return the key (first column) of the row for ``idx``."""
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.hw_tsv:
            return self.hw_tsv[line_no][0]
        elif self.label_tsv:
            return self.label_tsv[line_no][0]
        else:
            return self.img_tsv[line_no][0]

    def __len__(self):
        if self.line_list is None:
            return self.img_tsv.num_rows()
        return len(self.line_list)

    def _random_mask(self, num_items, percent):
        """Return a (num_items, 1) float tensor of ones with random zeros.

        During training at most ``percent`` of the entries are zeroed;
        otherwise the mask is all ones.
        """
        mask = np.ones((num_items, 1))
        if self.is_train:
            pb = np.random.random_sample()
            masked_num = int(pb * percent * num_items)  # at most x% could be masked
            indices = np.random.choice(np.arange(num_items), replace=False, size=masked_num)
            mask[indices, :] = 0.0
        return torch.from_numpy(mask).float()

    def __getitem__(self, idx):
        """Load, augment and package one sample.

        Returns:
            Tuple ``(img_key, transformed_img, meta_data)`` where
            ``transformed_img`` is the normalized image tensor and
            ``meta_data`` is a dict with the un-normalized image, pose,
            betas, 2D/3D joints, availability flags, masks, scale and
            center.
        """
        img = self.get_image(idx)
        img_key = self.get_img_key(idx)
        annotations = self.get_annotations(idx)

        # only the first annotation entry of the row is used
        annotations = annotations[0]
        center = annotations['center']
        scale = annotations['scale']
        has_2d_joints = annotations['has_2d_joints']
        has_3d_joints = annotations['has_3d_joints']
        joints_2d = np.asarray(annotations['2d_joints'])
        joints_3d = np.asarray(annotations['3d_joints'])

        # drop a leading singleton/batch axis if present
        if joints_2d.ndim == 3:
            joints_2d = joints_2d[0]
        if joints_3d.ndim == 3:
            joints_3d = joints_3d[0]

        # Get SMPL parameters, if available
        has_smpl = np.asarray(annotations['has_smpl'])
        pose = np.asarray(annotations['pose'])
        betas = np.asarray(annotations['betas'])

        # Get augmentation parameters
        flip, pn, rot, sc = self.augm_params()

        # Process image
        img = self.rgb_processing(img, center, sc * scale, rot, flip, pn)
        img = torch.from_numpy(img).float()
        # Store image before normalization to use it in visualization
        transformed_img = self.normalize_img(img)

        # normalize 3d pose by aligning the wrist as the root (at origin)
        root_coord = joints_3d[self.root_index, :-1]
        joints_3d[:, :-1] = joints_3d[:, :-1] - root_coord[None, :]
        # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
        joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
        # 2d pose augmentation
        joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc * scale, rot, flip)

        ###################################
        # Masking percentage
        # We observe that 0% or 5% works better for 3D hand mesh
        # We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
        mvm_percent = 0.0  # or 0.05
        ###################################

        # 21 joints and 195 vertices; joints are drawn first so the random
        # stream is consumed in the same order as before.
        mjm_mask = self._random_mask(21, mvm_percent)
        mvm_mask = self._random_mask(195, mvm_percent)

        meta_data = {}
        meta_data['ori_img'] = img
        meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
        meta_data['betas'] = torch.from_numpy(betas).float()
        meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
        meta_data['has_3d_joints'] = has_3d_joints
        meta_data['has_smpl'] = has_smpl
        meta_data['mjm_mask'] = mjm_mask
        meta_data['mvm_mask'] = mvm_mask

        # Get 2D keypoints and apply augmentation transforms
        meta_data['has_2d_joints'] = has_2d_joints
        meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()

        meta_data['scale'] = float(sc * scale)
        meta_data['center'] = np.asarray(center).astype(np.float32)

        return img_key, transformed_img, meta_data
class HandMeshTSVYamlDataset(HandMeshTSVDataset):
    """TSV dataset configured from a YAML file for easy function call.

    The YAML is read for the keys ``img`` (required), ``hw`` (required in
    composite mode, optional otherwise), ``label``, ``linelist`` and
    ``composite``. File paths are resolved against the YAML file's
    directory.
    """

    def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
        self.cfg = load_from_yaml_file(yaml_file)
        # The base class reads is_composite and root during its __init__,
        # so both must be set before calling super().
        self.is_composite = self.cfg.get('composite', False)
        self.root = op.dirname(yaml_file)

        # Truthiness test, matching how the base class checks
        # ``self.is_composite`` (a value such as None counts as
        # non-composite in both places).
        if self.is_composite:
            # Composite entries are passed through unresolved; they are
            # handed to the composite TSV reader together with ``root``.
            img_file = self.cfg['img']
            hw_file = self.cfg['hw']
            label_file = self.cfg.get('label', None)
        else:
            img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
            label_file = find_file_path_in_yaml(self.cfg.get('label', None),
                                                self.root)
            hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)

        # The line list is resolved the same way in both modes.
        linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                               self.root)

        super(HandMeshTSVYamlDataset, self).__init__(
            args, img_file, label_file, hw_file, linelist_file, is_train,
            cv2_output=cv2_output, scale_factor=scale_factor)