import copy
import json
import os

import numpy as np
from scipy.linalg import polar
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import Dataset

from .utils import exists
from .utils.logger import print_log


def create_dataset(cfg_dataset):
    """Build a dataset instance from a config mapping.

    Args:
        cfg_dataset: Mapping with a ``'name'`` key selecting the dataset class
            (see :func:`get_dataset`). Every other key is forwarded to that
            class's constructor as a keyword argument.

    Returns:
        The constructed dataset.

    Raises:
        KeyError: If ``'name'`` is missing or is not a registered dataset.
    """
    # Work on a shallow copy: popping from the caller's mapping would strip
    # 'name' from their config and make a second call with it fail.
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset


def get_dataset(name):
    """Return the dataset class registered under *name*.

    Raises:
        KeyError: If *name* is not a registered dataset.
    """
    registry = {
        'base': PrimitiveDataset,
    }
    try:
        return registry[name]
    except KeyError:
        raise KeyError(
            f'unknown dataset {name!r}; expected one of {sorted(registry)}'
        ) from None


SHAPE_CODE = {
    'CubeBevel': 0,
    'SphereSharp': 1,
    'CylinderSharp': 2,
}


class PrimitiveDataset(Dataset):
    """Dataset that loads one point-cloud file per item from ``pc_dir``.

    Each item is a dict with a single ``'pc'`` tensor. The feature layout of
    that tensor depends on ``pc_format`` (see :meth:`__getitem__`).
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=[0, 1],
                 range_rotation=[-180, 180],
                 range_translation=[-1, 1],
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        """
        Args:
            pc_dir: Directory whose entries are the point-cloud files. Each
                directory entry becomes one dataset item.
            bs_dir: Directory containing ``basic_shapes.json``.
            max_length: Stored as ``self.max_length`` (not used in this class).
            range_scale: Stored as ``self.range_scale``.
            range_rotation: Stored as ``self.range_rotation``.
            range_translation: Stored as ``self.range_translation``.
            rotation_type: Stored as ``self.rotation_type``.
            pc_format: One of ``'pc'``, ``'pn'``, ``'pcn'``. It is validated
                lazily in :meth:`__getitem__`.
        """
        # os.listdir returns entries in arbitrary, filesystem-dependent
        # order; sort so that index -> file is reproducible across runs.
        self.data_filename = sorted(os.listdir(pc_dir))
        self.pc_dir = pc_dir
        self.max_length = max_length
        self.range_scale = range_scale
        self.range_rotation = range_rotation
        self.range_translation = range_translation
        self.rotation_type = rotation_type
        self.pc_format = pc_format

        # NOTE(review): basic_shapes is loaded but not used anywhere in this
        # class. Reading it still makes a missing/invalid JSON file fail at
        # construction time, so the read is kept.
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            basic_shapes = json.load(f)

        # Numeric type id -> primitive shape name (the names match the keys
        # of SHAPE_CODE).
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of files found in ``pc_dir``."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Load item *idx* and return ``{'pc': tensor}``.

        The points, colors and normals from the file are converted to
        float32 tensors and concatenated along the last dimension:

        * ``'pc'``  -> points + colors, then transposed
        * ``'pn'``  -> points + normals
        * ``'pcn'`` -> points + colors + normals

        Raises:
            ValueError: If ``self.pc_format`` is not one of the above.
        """
        # open3d was referenced here but never imported by this module, so
        # every call raised NameError. Import it lazily so the rest of the
        # module stays importable without open3d installed.
        import open3d as o3d

        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        pc = o3d.io.read_point_cloud(pc_file)

        # Presumably (N, 3) float64 arrays each; .float() casts to float32.
        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()

        model_data = {}
        if self.pc_format == 'pc':
            # NOTE(review): only this format is transposed (features first);
            # 'pn' and 'pcn' are returned untransposed. Confirm this
            # asymmetry is intended by the consumers before changing it.
            model_data['pc'] = torch.cat([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            model_data['pc'] = torch.cat([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            model_data['pc'] = torch.cat([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')

        return model_data


def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's ``typeId`` to a shape name.

    For every value ``info`` in *config_data*, the first shapename whose
    ``shapename[3:-4]`` slice occurs inside ``info['bpPath']`` wins. Its
    fourth ``'_'``-separated field becomes the mapped name. Entries with
    no matching shapename are omitted.

    Presumably the ``[3:-4]`` slice strips a fixed-width prefix and a
    4-character extension; verify against the real shapename format.
    """
    typeid_map = {}
    for info in config_data.values():
        for shapename in shapenames:
            if shapename[3:-4] in info['bpPath']:
                typeid_map[info['typeId']] = shapename.split('_')[3]
                break
    return typeid_map


def check_valid_range(data, value_range):
    """Return True if every element of *data* lies in ``[lo, hi]`` (inclusive).

    Args:
        data: Array-like of numbers.
        value_range: ``(lo, hi)`` pair with ``hi > lo``.

    Raises:
        AssertionError: If ``hi <= lo``.
    """
    lo, hi = value_range
    assert hi > lo
    data = np.asarray(data)
    # The previous version compared ``hi <= hi`` (always True), so the upper
    # bound was never enforced.
    return bool(np.logical_and(data >= lo, data <= hi).all())


def quat_to_euler(quat, degree=True):
    """Convert a quaternion (scipy ``from_quat`` order) to intrinsic XYZ Euler angles."""
    return Rotation.from_quat(quat).as_euler('XYZ', degrees=degree)


def quat_to_rotvec(quat, degree=True):
    """Convert a quaternion (scipy ``from_quat`` order) to a rotation vector."""
    return Rotation.from_quat(quat).as_rotvec(degrees=degree)


def rotate_axis(euler):
    """Return a 4x4 homogeneous transform containing only a rotation.

    *euler* is interpreted by ``Rotation.from_euler('xyz', ...)``, i.e.
    extrinsic x-y-z angles in radians.

    NOTE(review): this uses lowercase/extrinsic 'xyz' in radians, while
    :func:`quat_to_euler` produces uppercase/intrinsic 'XYZ', in degrees by
    default. The two are not inverses of each other; confirm this is intended.
    """
    trans = np.eye(4, 4)
    trans[:3, :3] = Rotation.from_euler('xyz', euler).as_matrix()
    return trans


def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, quaternion rotation and translation into a 4x4 matrix.

    The upper-left 3x3 block is ``R * scale``. With a length-3 *scale*,
    broadcasting scales the columns of R, which equals ``R @ diag(scale)``:
    scale is applied first, then rotation. The last column holds
    *translation*.
    """
    rotation_matrix = Rotation.from_quat(quat).as_matrix()
    transform_matrix = np.eye(4)
    transform_matrix[:3, :3] = rotation_matrix * scale
    transform_matrix[:3, 3] = translation
    return transform_matrix


def matrix_to_SRT_quat2(transform_matrix):
    """Decompose a 4x4 transform into ``(scale, quat, translation)``.

    Uses a right polar decomposition ``A = U @ P`` of the upper-left 3x3
    block. U is taken as the rotation, and the diagonal of the symmetric
    factor P is taken as the per-axis scale.

    The result is exact only when A has no shear. The decomposition may
    also be wrong if U is a reflection (e.g. negative scale) — TODO confirm.
    :func:`apply_transform_to_block` checks the round trip for this reason.
    """
    transform_matrix = np.array(transform_matrix)
    translation = transform_matrix[:3, 3]
    rotation_matrix, scale_matrix = polar(transform_matrix[:3, :3])
    quat = Rotation.from_matrix(rotation_matrix).as_quat()
    scale = np.diag(scale_matrix)
    return scale, quat, translation


def apply_transform_to_block(block, trans_aug):
    """Left-multiply *trans_aug* onto a block's SRT transform.

    Args:
        block: Dict with ``block['data']['scale']``, ``['rotation']``
            (quaternion) and ``['location']``.
        trans_aug: 4x4 transform applied on the left of the block's matrix.

    Returns:
        ``(precision_loss, new_block)``. If re-composing the decomposed
        SRT does not reproduce the product within ``atol=1e-1``, the result
        is ``(True, {})``. Otherwise it is ``(False, copy)``, where copy is a
        deep copy of *block* with the updated scale/rotation/location lists.
        The input block is never modified.
    """
    precision_loss = False
    trans = SRT_quat_to_matrix(
        block['data']['scale'],
        block['data']['rotation'],
        block['data']['location'],
    )
    trans = trans_aug @ trans

    scale, quat, translation = matrix_to_SRT_quat2(trans)
    # Round-trip check: SRT cannot represent shear, so verify that the
    # decomposition still describes the same transform.
    trans_rec = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(trans, trans_rec, atol=1e-1):
        precision_loss = True
        return precision_loss, {}

    new_block = copy.deepcopy(block)
    new_block['data']['scale'] = scale.tolist()
    new_block['data']['rotation'] = quat.tolist()
    new_block['data']['location'] = translation.tolist()
    return precision_loss, new_block