File size: 4,597 Bytes
829e08b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import copy
import json
import os

import numpy as np  
from scipy.linalg import polar
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import Dataset

from .utils import exists
from .utils.logger import print_log


def create_dataset(cfg_dataset):
    """Instantiate the dataset described by ``cfg_dataset``.

    ``cfg_dataset['name']`` selects the dataset class through ``get_dataset``;
    every remaining entry is forwarded as a keyword argument to that class's
    constructor. The caller's mapping is left unmodified.

    Args:
        cfg_dataset: mapping with a ``'name'`` key plus constructor kwargs.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if ``'name'`` is missing or is not a registered dataset.
    """
    # Shallow copy: popping 'name' straight from the caller's config would
    # mutate it, so a second call with the same config would hit KeyError.
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset

def get_dataset(name):
    """Look up a dataset class by its registry name.

    Args:
        name: registry key; currently only ``'base'`` is registered.

    Returns:
        The dataset class (not an instance).

    Raises:
        KeyError: if ``name`` is not registered.
    """
    registry = {'base': PrimitiveDataset}
    return registry[name]


# Integer code assigned to each supported primitive shape name.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)


class PrimitiveDataset(Dataset):
    """Dataset of point clouds read from ``pc_dir``, one sample per file.

    Each item is a dict ``{'pc': tensor}`` whose per-point features are chosen
    by ``pc_format``:

    * ``'pc'``  -- points + colors, transposed (see note in ``__getitem__``)
    * ``'pn'``  -- points + normals
    * ``'pcn'`` -- points + colors + normals

    ``bs_dir`` must contain ``basic_shapes.json``; it is loaded eagerly and
    kept on ``self.basic_shapes``. ``max_length``, the ``range_*`` settings
    and ``rotation_type`` are stored but not read anywhere in this class.
    """

    # Accepted values of ``pc_format``.
    _PC_FORMATS = ('pc', 'pn', 'pcn')

    def __init__(
        self,
        pc_dir,
        bs_dir,
        max_length=144,
        range_scale=[0, 1],
        range_rotation=[-180, 180],
        range_translation=[-1, 1],
        rotation_type='euler',
        pc_format='pc',
    ):
        # Sort so the index -> file mapping is stable across runs and
        # filesystems (os.listdir returns entries in arbitrary order).
        self.data_filename = sorted(os.listdir(pc_dir))

        self.pc_dir = pc_dir
        self.max_length = max_length
        # Copy the range arguments: the default lists are single objects
        # shared by every call, so storing them directly would let a mutation
        # on one instance leak into all instances built with the defaults.
        self.range_scale = copy.copy(range_scale)
        self.range_rotation = copy.copy(range_rotation)
        self.range_translation = copy.copy(range_translation)
        self.rotation_type = rotation_type
        self.pc_format = pc_format

        # Previously loaded into a local and discarded; keep it reachable.
        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            self.basic_shapes = json.load(f)

        # Numeric type id -> shape name (names match the SHAPE_CODE keys).
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Number of files found in ``pc_dir`` at construction time."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Load point cloud ``idx`` and return ``{'pc': features}``.

        Raises:
            ValueError: if ``self.pc_format`` is not 'pc', 'pn' or 'pcn'.
        """
        # Validate before touching the disk so a bad config fails fast.
        if self.pc_format not in self._PC_FORMATS:
            raise ValueError(f'invalid pc_format: {self.pc_format}')

        # open3d was referenced here without ever being imported (NameError
        # on first access). Import it locally so merely importing this module
        # does not require open3d.
        import open3d as o3d

        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        pc = o3d.io.read_point_cloud(pc_file)

        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()

        if self.pc_format == 'pc':
            # NOTE(review): only this format is transposed, so its two dims
            # are swapped relative to 'pn'/'pcn'. Kept for backward
            # compatibility -- confirm consumers expect this asymmetry.
            features = torch.cat([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            features = torch.cat([points, normals], dim=-1)
        else:  # 'pcn'
            features = torch.cat([points, colors, normals], dim=-1)

        return {'pc': features}


def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's ``typeId`` to a short shape name.

    For every value in ``config_data``, the first name in ``shapenames``
    whose slice ``[3:-4]`` occurs in the entry's ``'bpPath'`` is chosen, and
    the entry's ``'typeId'`` is mapped to the 4th ``'_'``-separated token of
    that full name. Entries with no matching name are skipped; a later entry
    with the same ``typeId`` overwrites an earlier one.

    NOTE(review): the ``[3:-4]`` slice presumably strips a 3-character prefix
    and a 4-character extension -- confirm against the real file names.
    """
    mapping = {}
    for entry in config_data.values():
        matched = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if matched is not None:
            mapping[entry['typeId']] = matched.split('_')[3]
    return mapping


def check_valid_range(data, value_range):
    """Return whether every element of ``data`` lies in ``[lo, hi]``.

    Args:
        data: array-like of numbers (converted with ``np.asarray``).
        value_range: ``(lo, hi)`` pair with ``hi > lo``; both ends inclusive.

    Returns:
        numpy bool: True if all elements satisfy ``lo <= x <= hi``
        (vacuously True for empty ``data``).

    Raises:
        AssertionError: if ``hi <= lo``.
    """
    lo, hi = value_range
    assert hi > lo, f'invalid range: {value_range}'
    values = np.asarray(data)
    # Bug fix: the upper bound used to read ``hi <= hi`` (always True), so
    # values above ``hi`` were never rejected.
    return np.logical_and(values >= lo, values <= hi).all()


def quat_to_euler(quat, degree=True):
    """Convert a quaternion to intrinsic X-Y-Z Euler angles.

    Args:
        quat: quaternion(s) accepted by ``Rotation.from_quat`` (scipy's
            default scalar-last ``(x, y, z, w)`` order).
        degree: return degrees when True, radians otherwise.

    Returns:
        numpy array of Euler angles for the ``'XYZ'`` (intrinsic) sequence.
    """
    rot = Rotation.from_quat(quat)
    angles = rot.as_euler('XYZ', degrees=degree)
    return angles


def quat_to_rotvec(quat, degree=True):
    """Convert a quaternion to a rotation vector (unit axis times angle).

    Args:
        quat: quaternion(s) accepted by ``Rotation.from_quat``.
        degree: express the vector's magnitude in degrees when True,
            radians otherwise.

    Returns:
        numpy array holding the rotation vector(s).
    """
    rotation = Rotation.from_quat(quat)
    rotvec = rotation.as_rotvec(degrees=degree)
    return rotvec


def rotate_axis(euler):
    """Return a 4x4 homogeneous matrix that rotates by ``euler``.

    ``euler`` goes to ``Rotation.from_euler('xyz', ...)`` with scipy's
    default ``degrees=False``, i.e. extrinsic x-y-z angles in radians. The
    translation column is zero.
    """
    homogeneous = np.identity(4)
    homogeneous[:3, :3] = Rotation.from_euler('xyz', euler).as_matrix()
    return homogeneous


def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, rotation quaternion and translation into a 4x4 matrix.

    The upper-left 3x3 block is ``R * scale``. With a length-3 ``scale``,
    broadcasting multiplies column ``j`` of ``R`` by ``scale[j]``, i.e.
    ``R @ diag(scale)``, so scaling is applied before rotation. The last
    column holds ``translation``.

    Args:
        scale: scalar or length-3 per-axis scale.
        quat: quaternion accepted by ``Rotation.from_quat``.
        translation: length-3 translation.

    Returns:
        4x4 float numpy array.
    """
    mat = np.identity(4)
    mat[:3, :3] = Rotation.from_quat(quat).as_matrix() * scale
    mat[:3, 3] = translation
    return mat


def matrix_to_SRT_quat2(transform_matrix):  # Polar Decomposition
    """Split a 4x4 transform into (scale, quaternion, translation).

    The 3x3 block is factored with ``scipy.linalg.polar`` as ``U @ P``.
    ``U`` becomes the quaternion, and the diagonal of ``P`` becomes the scale.
    Off-diagonal terms of ``P`` (shear) are dropped, so the result only
    reproduces the input when the block is a rotation times a diagonal
    scale.

    Args:
        transform_matrix: 4x4 array-like; copied via ``np.array``.

    Returns:
        ``(scale, quat, translation)`` as numpy arrays.
    """
    mat = np.array(transform_matrix)
    translation = mat[:3, 3]
    unitary, positive = polar(mat[:3, :3])
    return np.diag(positive), Rotation.from_matrix(unitary).as_quat(), translation


def apply_transform_to_block(block, trans_aug, atol=1e-1):
    """Left-multiply a block's SRT transform by ``trans_aug``.

    ``block['data']`` must hold ``'scale'``, ``'rotation'`` (a quaternion as
    accepted by ``Rotation.from_quat``) and ``'location'``. The composed 4x4
    matrix is decomposed back into scale / quaternion / translation. If that
    decomposition cannot rebuild the composed matrix to within ``atol``, the
    transform is rejected. This happens, for example, when the result has
    shear, which the diagonal scale cannot represent.

    Args:
        block: dict with a ``'data'`` sub-dict; never modified.
        trans_aug: 4x4 matrix applied on the left of the block's transform.
        atol: absolute tolerance of the reconstruction check (default 1e-1,
            the previously hard-coded value).

    Returns:
        ``(precision_loss, new_block)``: ``(True, {})`` when the
        reconstruction check fails. Otherwise ``(False, copy)``, where
        ``copy`` is a deep copy of ``block`` whose scale/rotation/location
        are replaced by plain lists.
    """
    trans = trans_aug @ SRT_quat_to_matrix(
        block['data']['scale'],
        block['data']['rotation'],
        block['data']['location'],
    )
    scale, quat, translation = matrix_to_SRT_quat2(trans)

    # Round-trip check: the decomposition keeps only the diagonal of the
    # scale factor, so verify the parts still rebuild the composed matrix.
    trans_rec = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(trans, trans_rec, atol=atol):
        return True, {}

    # Deep copy so the caller's nested 'data' dict is left untouched.
    new_block = copy.deepcopy(block)
    new_block['data']['scale'] = scale.tolist()
    new_block['data']['rotation'] = quat.tolist()
    new_block['data']['location'] = translation.tolist()
    return False, new_block