# Spaces: Running on Zero
import copy | |
import json | |
import os | |
import numpy as np | |
from scipy.linalg import polar | |
from scipy.spatial.transform import Rotation | |
import torch | |
from torch.utils.data import Dataset | |
from .utils import exists | |
from .utils.logger import print_log | |
def create_dataset(cfg_dataset):
    """Instantiate a dataset from a config mapping.

    The mapping must contain a ``'name'`` key selecting the dataset class via
    ``get_dataset``; every other entry is forwarded to that class's
    constructor as a keyword argument.

    Args:
        cfg_dataset: Mapping with a ``'name'`` key plus constructor kwargs.
            It is not modified.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``'name'`` is missing or is not a registered dataset.
    """
    # Pop from a shallow copy so the caller's config keeps its 'name' key and
    # can be reused (popping in place made a second call raise KeyError).
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset
def get_dataset(name):
    """Look up a dataset class by its registry name.

    Args:
        name: Registry key; ``'base'`` is the only registered name.

    Returns:
        The dataset class (not an instance) registered under ``name``.

    Raises:
        KeyError: If ``name`` is not registered.
    """
    registry = {'base': PrimitiveDataset}
    return registry[name]
# Integer code assigned to each primitive shape name.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)
class PrimitiveDataset(Dataset):
    """Dataset of point-cloud files read from a single directory.

    Every entry of ``pc_dir`` is one sample. Each item is a dict with the
    single key ``'pc'``, holding a float tensor built from the file's points
    plus colors and/or normals, depending on ``pc_format``.

    Args:
        pc_dir: Directory whose entries are the point-cloud files.
        bs_dir: Directory that must contain ``basic_shapes.json``.
        max_length: Stored as ``self.max_length``; not used in this class.
        range_scale, range_rotation, range_translation: ``[lo, hi]`` pairs
            stored (as fresh lists) on the instance; not used in this class.
        rotation_type: Stored as ``self.rotation_type``; not used in this class.
        pc_format: ``'pc'`` (points + colors), ``'pn'`` (points + normals) or
            ``'pcn'`` (points + colors + normals). Validated in ``__getitem__``.
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=[0, 1],
                 range_rotation=[-180, 180],
                 range_translation=[-1, 1],
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is deterministic across runs and machines.
        self.data_filename = sorted(os.listdir(pc_dir))
        self.pc_dir = pc_dir
        self.max_length = max_length
        # The defaults are mutable lists shared by every call; copy them so
        # mutating one instance's range cannot leak into later instances.
        self.range_scale = list(range_scale)
        self.range_rotation = list(range_rotation)
        self.range_translation = list(range_translation)
        self.rotation_type = rotation_type
        self.pc_format = pc_format
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            # NOTE(review): parsed but never used; typeid_map below is
            # hard-coded instead. Confirm which source is intended.
            basic_shapes = json.load(f)
        # Hard-coded mapping from asset type id to primitive shape name.
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of files found in ``pc_dir``."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Read point-cloud file ``idx`` and return ``{'pc': tensor}``.

        Raises:
            ValueError: If ``self.pc_format`` is not 'pc', 'pn' or 'pcn'.
        """
        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        # NOTE(review): `o3d` is not imported in this module's visible import
        # block; presumably `import open3d as o3d` is needed. Confirm.
        pc = o3d.io.read_point_cloud(pc_file)
        model_data = {}
        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()
        if self.pc_format == 'pc':
            # NOTE(review): only this format is transposed (.T); 'pn' and
            # 'pcn' are returned untransposed. Confirm consumers expect this.
            model_data['pc'] = torch.concatenate([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            model_data['pc'] = torch.concatenate([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')
        return model_data
def get_typeid_shapename_mapping(shapenames, config_data):
    """Build a ``typeId -> shape name`` mapping by matching file names to configs.

    For each value of ``config_data`` (in iteration order), the first name in
    ``shapenames`` whose core ``name[3:-4]`` is a substring of the entry's
    ``'bpPath'`` is selected. The mapped value is the fourth ``'_'``-separated
    token of that full name. Entries with no matching name are skipped. If two
    entries share a ``typeId``, the later one wins.

    NOTE(review): the ``[3:-4]`` slice presumably strips a 3-character prefix
    and a 4-character extension; confirm against the real file names.
    """
    mapping = {}
    for entry in config_data.values():
        hit = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if hit is not None:
            mapping[entry['typeId']] = hit.split('_')[3]
    return mapping
def check_valid_range(data, value_range):
    """Return whether every element of ``data`` lies in ``[lo, hi]`` (inclusive).

    Args:
        data: Array-like supporting element-wise ``>=`` / ``<=`` with scalars
            (e.g. a numpy array).
        value_range: ``(lo, hi)`` pair with ``hi > lo``.

    Returns:
        numpy boolean: True iff ``lo <= x <= hi`` for all elements.

    Raises:
        AssertionError: If ``hi <= lo``.
    """
    lo, hi = value_range
    assert hi > lo, f'invalid range: {value_range}'
    # Upper bound must compare the data against hi; the previous `hi <= hi`
    # was always true, so out-of-range high values were accepted.
    return np.logical_and(data >= lo, data <= hi).all()
def quat_to_euler(quat, degree=True):
    """Convert quaternion(s) to intrinsic ``'XYZ'`` Euler angles.

    Args:
        quat: Quaternion(s) in the layout ``Rotation.from_quat`` accepts.
        degree: If True (default) return degrees, otherwise radians.
    """
    rotation = Rotation.from_quat(quat)
    return rotation.as_euler('XYZ', degrees=degree)
def quat_to_rotvec(quat, degree=True):
    """Convert quaternion(s) to rotation vector(s) (axis * angle).

    Args:
        quat: Quaternion(s) in the layout ``Rotation.from_quat`` accepts.
        degree: If True (default) the angle magnitude is in degrees,
            otherwise radians.
    """
    rotation = Rotation.from_quat(quat)
    return rotation.as_rotvec(degrees=degree)
def rotate_axis(euler):
    """Return a 4x4 homogeneous matrix for the given Euler angles.

    The 3x3 block is ``Rotation.from_euler('xyz', euler)`` (lowercase = extrinsic
    axes, angles in radians since ``degrees`` is not passed). The translation
    column is zero and the bottom row is ``[0, 0, 0, 1]``.

    NOTE(review): this uses extrinsic 'xyz' in radians, whereas quat_to_euler
    produces intrinsic 'XYZ' in degrees by default. Confirm that callers
    convert accordingly.
    """
    rot3 = Rotation.from_euler('xyz', euler).as_matrix()
    homogeneous = np.identity(4)
    homogeneous[:3, :3] = rot3
    return homogeneous
def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, quaternion and translation into a 4x4 affine matrix.

    The 3x3 block is ``R * scale``. Because ``scale`` broadcasts along the last
    axis, a per-axis scale multiplies R's columns, i.e. ``R @ diag(scale)``.
    A scalar scale is uniform. Column 3 (rows 0-2) is ``translation``.
    """
    mat = np.eye(4)
    mat[:3, :3] = Rotation.from_quat(quat).as_matrix() * scale
    mat[:3, 3] = translation
    return mat
def matrix_to_SRT_quat2(transform_matrix):
    """Split a 4x4 affine matrix into (scale, quaternion, translation).

    Uses polar decomposition: the upper-left 3x3 block A is factored as
    ``A = U @ P`` (scipy's default right-side polar). U is converted to a
    quaternion, and the diagonal of P becomes the per-axis scale. Off-diagonal
    entries of P (shear) are discarded, so the result only round-trips when A
    has no shear (and presumably no reflection).

    Returns:
        ``(scale, quat, translation)`` as numpy arrays. ``translation`` is
        rows 0-2 of column 3 of the input.
    """
    mat = np.array(transform_matrix)
    unitary, stretch = polar(mat[:3, :3])
    scale = np.diag(stretch)
    quat = Rotation.from_matrix(unitary).as_quat()
    return scale, quat, mat[:3, 3]
def apply_transform_to_block(block, trans_aug, atol=1e-1):
    """Left-multiply a block's scale/rotation/location transform by ``trans_aug``.

    The block's SRT fields are composed into a 4x4 matrix M. The product
    ``trans_aug @ M`` is decomposed back into scale / quaternion / translation,
    and that decomposition is re-composed to check it reproduces the product.

    Args:
        block: Dict with ``block['data']['scale']``, ``['rotation']`` (a
            quaternion, (x, y, z, w) order as scipy's ``Rotation.from_quat``
            expects by default) and ``['location']``.
        trans_aug: 4x4 matrix applied on the left.
        atol: Absolute tolerance for the round-trip check (default 0.1).

    Returns:
        ``(precision_loss, new_block)``. If the re-composed matrix differs from
        the product by more than ``atol``, returns ``(True, {})``. Otherwise it
        returns ``(False, new_block)``, where ``new_block`` is a deep copy of
        ``block`` with the three fields replaced by plain lists. The input
        ``block`` is never modified.
    """
    data = block['data']
    trans = SRT_quat_to_matrix(data['scale'], data['rotation'], data['location'])
    trans = trans_aug @ trans
    scale, quat, translation = matrix_to_SRT_quat2(trans)
    # The decomposition keeps only the diagonal of the polar stretch factor,
    # so shear introduced by trans_aug is lost; detect that by recomposing.
    trans_rec = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(trans, trans_rec, atol=atol):
        return True, {}
    new_block = copy.deepcopy(block)
    new_block['data']['scale'] = scale.tolist()
    new_block['data']['rotation'] = quat.tolist()
    new_block['data']['location'] = translation.tolist()
    return False, new_block