Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,597 Bytes
829e08b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import copy
import json
import os
import numpy as np
from scipy.linalg import polar
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import Dataset
from .utils import exists
from .utils.logger import print_log
def create_dataset(cfg_dataset):
    """Build a dataset instance from a config mapping.

    ``cfg_dataset`` must contain a ``'name'`` key that selects the dataset
    class via ``get_dataset``. Every other key is forwarded to that class's
    constructor as a keyword argument.

    The caller's mapping is not modified, so the same config can be reused
    to build several datasets.

    Returns:
        The constructed dataset.

    Raises:
        KeyError: if ``'name'`` is missing or names an unregistered dataset.
    """
    # Work on a shallow copy: popping 'name' from the caller's mapping would
    # mutate it and make a second call with the same config fail (KeyError).
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset
def get_dataset(name):
    """Return the dataset class registered under ``name``.

    Raises:
        KeyError: if ``name`` is not a registered dataset name.
    """
    registry = {'base': PrimitiveDataset}
    return registry[name]
# Integer code for each primitive shape name. These are the same names that
# PrimitiveDataset.typeid_map produces.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)
class PrimitiveDataset(Dataset):
    """Dataset over the point-cloud files contained in one directory.

    Each item is a dict with a single key ``'pc'``. Its value is a float tensor
    built from the cloud's points concatenated with its colors and/or normals,
    as selected by ``pc_format``.
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=(0, 1),
                 range_rotation=(-180, 180),
                 range_translation=(-1, 1),
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        """Index the files in ``pc_dir`` and load ``bs_dir/basic_shapes.json``.

        Args:
            pc_dir: directory whose entries are the point-cloud files.
            bs_dir: directory containing ``basic_shapes.json``.
            max_length, range_scale, range_rotation, range_translation,
            rotation_type: stored as attributes. They are not used within
                this class and are presumably consumed elsewhere.
            pc_format: one of ``'pc'``, ``'pn'`` or ``'pcn'``. It is validated
                lazily in ``__getitem__``.
        """
        # Sort so index -> file is reproducible; os.listdir order is arbitrary.
        self.data_filename = sorted(os.listdir(pc_dir))
        self.pc_dir = pc_dir
        self.max_length = max_length
        # Copy into per-instance lists so instances never share a mutable default.
        self.range_scale = list(range_scale)
        self.range_rotation = list(range_rotation)
        self.range_translation = list(range_translation)
        self.rotation_type = rotation_type
        self.pc_format = pc_format
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            # Previously loaded into an unused local; kept here for callers that need it.
            self.basic_shapes = json.load(f)
        # Numeric type IDs -> shape names (the keys of SHAPE_CODE).
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Read point cloud ``idx`` and return ``{'pc': tensor}``.

        Raises:
            ValueError: if ``self.pc_format`` is not ``'pc'``, ``'pn'`` or ``'pcn'``.
        """
        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        # NOTE(review): `o3d` (open3d) is not imported in this file's visible
        # import block; this raises NameError unless it is imported elsewhere.
        pc = o3d.io.read_point_cloud(pc_file)
        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()
        if self.pc_format == 'pc':
            # NOTE(review): only this format is transposed after concatenating
            # on the last dim; 'pn'/'pcn' are not. Confirm the asymmetry is intended.
            pc_tensor = torch.cat([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            pc_tensor = torch.cat([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            pc_tensor = torch.cat([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')
        return {'pc': pc_tensor}
def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's ``typeId`` to the short name of its matching shape.

    For every entry in ``config_data.values()``, the first shape name whose
    core ``name[3:-4]`` occurs inside ``entry['bpPath']`` is chosen. The
    returned value is the 4th ``'_'``-separated field of that full name.
    Entries with no matching shape name are skipped.

    NOTE(review): the ``[3:-4]`` slice presumably strips a 3-char prefix
    (e.g. ``'SM_'``) and a 4-char extension (e.g. ``'.fbx'``). Confirm this.
    """
    mapping = {}
    for entry in config_data.values():
        hit = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if hit is not None:
            mapping[entry['typeId']] = hit.split('_')[3]
    return mapping
def check_valid_range(data, value_range):
    """Return whether every element of ``data`` lies in the closed range ``value_range``.

    Args:
        data: array-like supporting elementwise comparison (e.g. a numpy array).
        value_range: ``(lo, hi)`` pair with ``hi > lo``.

    Returns:
        True iff ``lo <= x <= hi`` for every element ``x`` of ``data``.

    Raises:
        AssertionError: if ``hi <= lo``.
    """
    lo, hi = value_range
    assert hi > lo
    # Previously compared `hi <= hi` (always True), so the upper bound was never checked.
    return np.logical_and(data >= lo, data <= hi).all()
def quat_to_euler(quat, degree=True):
    """Convert scalar-last ``(x, y, z, w)`` quaternion(s) to intrinsic ``'XYZ'`` Euler angles.

    Angles are returned in degrees when ``degree`` is True, otherwise in radians.
    """
    rotation = Rotation.from_quat(quat)
    angles = rotation.as_euler('XYZ', degrees=degree)
    return angles
def quat_to_rotvec(quat, degree=True):
    """Convert scalar-last ``(x, y, z, w)`` quaternion(s) to rotation vector(s).

    Each vector is the rotation axis scaled by the rotation angle. The angle
    is in degrees when ``degree`` is True, otherwise in radians.
    """
    rotation = Rotation.from_quat(quat)
    rotvec = rotation.as_rotvec(degrees=degree)
    return rotvec
def rotate_axis(euler):
    """Return a 4x4 homogeneous transform that applies only a rotation.

    ``euler`` holds extrinsic ``'xyz'`` Euler angles in radians (scipy's
    default unit). The translation column is zero.
    """
    homogeneous = np.identity(4)
    homogeneous[:3, :3] = Rotation.from_euler('xyz', euler).as_matrix()
    return homogeneous
def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, quaternion rotation and translation into a 4x4 matrix.

    The upper-left 3x3 block is ``R @ diag(scale)``. Broadcasting multiplies
    column ``j`` of the rotation matrix by ``scale[j]``. ``quat`` is
    scalar-last ``(x, y, z, w)``, as expected by scipy's ``Rotation.from_quat``.
    """
    rot = Rotation.from_quat(quat).as_matrix()
    result = np.eye(4)
    result[:3, :3] = rot * scale
    result[:3, 3] = translation
    return result
def matrix_to_SRT_quat2(transform_matrix):  # Polar Decomposition
    """Split a 4x4 transform into ``(scale, quat, translation)`` via polar decomposition.

    The 3x3 block ``A`` is factored as ``A = U @ P`` (scipy ``polar``, right
    side). ``U`` becomes the scalar-last quaternion. ``diag(P)`` becomes the
    scale; off-diagonal entries of ``P`` are dropped. The translation is the
    first three rows of the last column.
    """
    mat = np.array(transform_matrix)
    offset = mat[:3, 3]
    unitary, stretch = polar(mat[:3, :3])
    return np.diag(stretch), Rotation.from_matrix(unitary).as_quat(), offset
def apply_transform_to_block(block, trans_aug, atol=1e-1):
    """Left-multiply ``trans_aug`` onto a block's scale/rotation/location transform.

    The block's ``data['scale']``, ``data['rotation']`` (scalar-last
    quaternion) and ``data['location']`` are composed into a 4x4 matrix ``T``.
    The product ``trans_aug @ T`` is then decomposed back into scale,
    rotation and translation. If re-composing those values does not
    reproduce the product within ``atol``, the result cannot be represented
    and is reported as a precision loss.

    Args:
        block: mapping with a ``'data'`` sub-mapping holding the keys above.
            It is never modified.
        trans_aug: 4x4 transform applied on the left.
        atol: absolute tolerance for the reconstruction check. The default
            ``1e-1`` preserves the original behaviour.

    Returns:
        ``(precision_loss, new_block)``. On success ``new_block`` is a deep
        copy of ``block`` with the three fields replaced by lists. On
        precision loss it is ``(True, {})``.
    """
    data = block['data']
    current = SRT_quat_to_matrix(data['scale'], data['rotation'], data['location'])
    combined = trans_aug @ current
    scale, quat, translation = matrix_to_SRT_quat2(combined)
    # The decomposition keeps only diag(P), so any off-diagonal (shear)
    # component of `combined` shows up as a reconstruction error here.
    reconstructed = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(combined, reconstructed, atol=atol):
        return True, {}
    new_block = copy.deepcopy(block)
    new_data = new_block['data']
    new_data['scale'] = scale.tolist()
    new_data['rotation'] = quat.tolist()
    new_data['location'] = translation.tolist()
    return False, new_block
|