# PrimitiveAnything / primitive_anything / primitive_dataset.py
# YulianSa's picture
# update
# 139b771
import copy
import json
import os
import numpy as np
from scipy.linalg import polar
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import Dataset
from .utils import exists
from .utils.logger import print_log
def create_dataset(cfg_dataset):
    """Build the dataset described by ``cfg_dataset``.

    Args:
        cfg_dataset: Mapping with a ``'name'`` entry that selects the dataset
            class through ``get_dataset``; every other entry is forwarded to
            that class as a keyword argument.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``'name'`` is missing or is not a registered dataset.
    """
    # Work on a shallow copy: popping 'name' from the caller's own config
    # would make a second call with the same config fail with KeyError.
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset
def get_dataset(name):
    """Return the dataset class registered under ``name``.

    Raises:
        KeyError: If ``name`` is not a registered dataset name.
    """
    registry = {
        'base': PrimitiveDataset,
    }
    return registry[name]
# Integer code assigned to each primitive shape name.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)
class PrimitiveDataset(Dataset):
    """Dataset serving one point-cloud tensor per file found in ``pc_dir``.

    Each item is a dict with a single ``'pc'`` entry built from the point
    positions combined with colors and/or normals, depending on
    ``pc_format``.
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=None,
                 range_rotation=None,
                 range_translation=None,
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        """Index the point-cloud directory and store the configuration.

        Args:
            pc_dir: Directory whose entries are the point-cloud files.
            bs_dir: Directory containing ``basic_shapes.json``.
            max_length: Stored as ``self.max_length``.
            range_scale: ``[lo, hi]`` pair; defaults to ``[0, 1]``.
            range_rotation: ``[lo, hi]`` pair; defaults to ``[-180, 180]``.
            range_translation: ``[lo, hi]`` pair; defaults to ``[-1, 1]``.
            rotation_type: Stored as ``self.rotation_type``.
            pc_format: One of ``'pc'``, ``'pn'`` or ``'pcn'``; validated when
                an item is fetched.

        Raises:
            FileNotFoundError: If ``pc_dir`` or ``basic_shapes.json`` is
                missing.
        """
        # os.listdir yields entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.data_filename = sorted(os.listdir(pc_dir))
        self.pc_dir = pc_dir
        self.max_length = max_length
        # Fresh lists per instance: list defaults in the signature would be
        # shared (and mutable) across every dataset instance.
        self.range_scale = [0, 1] if range_scale is None else range_scale
        self.range_rotation = (
            [-180, 180] if range_rotation is None else range_rotation
        )
        self.range_translation = (
            [-1, 1] if range_translation is None else range_translation
        )
        self.rotation_type = rotation_type
        self.pc_format = pc_format

        # The parsed content is not used in the visible code; the load is kept
        # so a missing or malformed file still fails at construction time.
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            basic_shapes = json.load(f)

        # Numeric type id -> primitive shape name.
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of entries found in ``pc_dir``."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Load the ``idx``-th point cloud and pack it into a tensor.

        Returns:
            dict: ``{'pc': tensor}`` where the tensor concatenates, along the
            last dimension, points+colors (``'pc'``), points+normals
            (``'pn'``) or points+colors+normals (``'pcn'``).

        Raises:
            ValueError: If ``self.pc_format`` is not a supported format.
        """
        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        # NOTE(review): `o3d` is not imported in the visible part of this
        # file; presumably `import open3d as o3d` is expected - confirm.
        pc = o3d.io.read_point_cloud(pc_file)

        model_data = {}

        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()
        if self.pc_format == 'pc':
            # NOTE(review): only this branch transposes its result; confirm
            # that consumers of 'pc' really expect the transposed layout.
            model_data['pc'] = torch.concatenate([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            model_data['pc'] = torch.concatenate([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')

        return model_data
def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's ``typeId`` to the shape it refers to.

    For every entry of ``config_data``, the first name in ``shapenames``
    whose ``[3:-4]`` slice occurs in the entry's ``bpPath`` is selected, and
    the fourth ``'_'``-separated token of that name becomes the mapped value.
    Entries without a matching name are left out.

    Args:
        shapenames: Re-iterable collection of shape file names.
        config_data: Mapping whose values carry ``'bpPath'`` and ``'typeId'``.

    Returns:
        dict: ``typeId`` -> shape name token.
    """
    mapping = {}
    for entry in config_data.values():
        # First shape name matching this entry, or None when nothing matches.
        matched = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if matched is not None:
            mapping[entry['typeId']] = matched.split('_')[3]
    return mapping
def check_valid_range(data, value_range):
    """Return whether every element of ``data`` lies in ``[lo, hi]``.

    Args:
        data: Array-like of numeric values.
        value_range: ``(lo, hi)`` pair with ``hi > lo``; both ends inclusive.

    Returns:
        ``True`` when all values satisfy ``lo <= value <= hi`` (vacuously
        true for empty input), otherwise ``False``.
    """
    lo, hi = value_range
    assert hi > lo
    values = np.asarray(data)
    # The upper bound must be tested against the data; comparing `hi` with
    # itself is always true and lets values above `hi` pass.
    return np.logical_and(values >= lo, values <= hi).all()
def quat_to_euler(quat, degree=True):
    """Convert a quaternion to 'XYZ' Euler angles.

    Args:
        quat: Quaternion(s) in the layout accepted by
            ``Rotation.from_quat`` (scalar-last by default).
        degree: Return degrees when true, radians otherwise.

    Returns:
        Euler angles for the uppercase (intrinsic) ``'XYZ'`` sequence.
    """
    rotation = Rotation.from_quat(quat)
    angles = rotation.as_euler('XYZ', degrees=degree)
    return angles
def quat_to_rotvec(quat, degree=True):
    """Convert a quaternion to a rotation vector (axis times angle).

    Args:
        quat: Quaternion(s) in the layout accepted by
            ``Rotation.from_quat`` (scalar-last by default).
        degree: Express the vector magnitude in degrees when true, radians
            otherwise.

    Returns:
        The rotation vector(s) for ``quat``.
    """
    rotation = Rotation.from_quat(quat)
    rotvec = rotation.as_rotvec(degrees=degree)
    return rotvec
def rotate_axis(euler):
    """Build a 4x4 homogeneous matrix holding only a rotation.

    Args:
        euler: Three angles in radians for the lowercase (extrinsic)
            ``'xyz'`` sequence of ``Rotation.from_euler``.

    Returns:
        A 4x4 array whose upper-left 3x3 block is the rotation matrix; the
        translation column is zero.
    """
    rotation_3x3 = Rotation.from_euler('xyz', euler).as_matrix()
    homogeneous = np.eye(4, 4)
    homogeneous[:3, :3] = rotation_3x3
    return homogeneous
def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, rotation (quaternion) and translation into a 4x4 matrix.

    Args:
        scale: Per-axis scale factors, broadcast against the 3x3 rotation
            matrix (so each factor multiplies one column).
        quat: Quaternion in the layout accepted by ``Rotation.from_quat``.
        translation: Three-component translation.

    Returns:
        A 4x4 homogeneous transform.
    """
    # Broadcasting multiplies column j of the rotation by scale[j].
    linear_part = Rotation.from_quat(quat).as_matrix() * scale
    matrix = np.eye(4)
    matrix[:3, :3] = linear_part
    matrix[:3, 3] = translation
    return matrix
def matrix_to_SRT_quat2(transform_matrix):
    """Split a 4x4 transform into scale, quaternion and translation.

    The upper-left 3x3 block is factored with a polar decomposition into a
    rotation factor and a stretch factor; the scale is the diagonal of the
    stretch factor.

    Args:
        transform_matrix: 4x4 array-like homogeneous transform.

    Returns:
        tuple: ``(scale, quat, translation)``.
    """
    matrix = np.array(transform_matrix)
    rotation_part, stretch_part = polar(matrix[:3, :3])
    return (
        np.diag(stretch_part),
        Rotation.from_matrix(rotation_part).as_quat(),
        matrix[:3, 3],
    )
def apply_transform_to_block(block, trans_aug):
    """Apply a 4x4 transform to a block's scale/rotation/location.

    The block's SRT values are composed into a matrix, left-multiplied by
    ``trans_aug`` and decomposed back into SRT form. If recomposing the
    decomposed values does not reproduce the combined matrix (``atol=1e-1``),
    the result is reported as a precision loss.

    Args:
        block: Dict whose ``'data'`` entry holds ``'scale'``, ``'rotation'``
            and ``'location'``. It is not modified.
        trans_aug: 4x4 transform applied on the left.

    Returns:
        tuple: ``(True, {})`` on precision loss, otherwise
        ``(False, new_block)`` where ``new_block`` is a deep copy of
        ``block`` with updated scale, rotation and location lists.
    """
    source = block['data']
    combined = trans_aug @ SRT_quat_to_matrix(
        source['scale'],
        source['rotation'],
        source['location'],
    )

    scale, quat, translation = matrix_to_SRT_quat2(combined)
    rebuilt = SRT_quat_to_matrix(scale, quat, translation)

    # The combined transform cannot be represented as pure SRT closely enough.
    if not np.allclose(combined, rebuilt, atol=1e-1):
        return True, {}

    updated = copy.deepcopy(block)
    target = updated['data']
    target['scale'] = scale.tolist()
    target['rotation'] = quat.tolist()
    target['location'] = translation.tolist()
    return False, updated