File size: 10,351 Bytes
8133633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import math
import os
import json
import re
import cv2
from dataclasses import dataclass, field

import random
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image

from craftsman.utils.typing import *

def fit_bounding_box(img, mask, marign_pix_dis, background_color, target_size=512):
    """Crop ``img`` to the foreground of ``mask`` and center it on a square canvas.

    The tight bounding box of the non-zero mask pixels is cut out of ``img``,
    resized with its aspect ratio preserved so that its longer side spans
    ``target_size - 2 * marign_pix_dis`` pixels, and pasted at the center of a
    ``target_size x target_size`` canvas filled with ``background_color``.

    Args:
        img: (H, W, 3) image, a CPU torch tensor or array (sliced as [row, col, channel]).
        mask: (H, W) or (H, W, C) mask; with C channels the last one is used.
            Any non-zero value counts as foreground.
        marign_pix_dis: margin in pixels kept between the resized crop and the
            canvas border.
        background_color: fill color for the canvas, broadcastable to 3 channels
            and on the same value scale as ``img``.
        target_size: side length of the output canvas (512 was the previously
            hard-coded value and remains the default).

    Returns:
        torch.Tensor of shape (target_size, target_size, 3).

    Raises:
        ValueError: if ``mask`` has no foreground pixel, or the margin leaves no
            room for the crop on the canvas.
    """
    alpha = np.asarray(mask)
    if alpha.ndim == 3:
        # (H, W, C) mask: use the last channel (the only one for an (H, W, 1) mask).
        alpha = alpha[..., -1]
    rows = np.flatnonzero(alpha.any(axis=1))
    cols = np.flatnonzero(alpha.any(axis=0))
    if rows.size == 0:
        raise ValueError("fit_bounding_box: mask has no foreground pixels")

    inner = target_size - 2 * marign_pix_dis
    if inner <= 0:
        raise ValueError(
            f"fit_bounding_box: margin {marign_pix_dis} leaves no room on a "
            f"{target_size}px canvas"
        )

    # bottom/right are exclusive so the last foreground row/column stays in the crop.
    top, bottom = int(rows[0]), int(rows[-1]) + 1
    left, right = int(cols[0]), int(cols[-1]) + 1
    box_h, box_w = bottom - top, right - left

    # The longer side fills the inner area; the shorter one is scaled proportionally
    # (never exceeds `inner`, never collapses to 0).
    if box_h > box_w:
        new_h, new_w = inner, max(1, round(inner * box_w / box_h))
    else:
        new_h, new_w = max(1, round(inner * box_h / box_w)), inner

    crop = np.ascontiguousarray(np.asarray(img)[top:bottom, left:right])
    fill = np.asarray(background_color)
    # Promote so an integer background never truncates a float image.
    canvas = np.empty(
        (target_size, target_size, 3), dtype=np.result_type(crop.dtype, fill.dtype)
    )
    canvas[...] = fill

    # Exact centering; always inside the canvas because new_h, new_w <= inner.
    y0 = (target_size - new_h) // 2
    x0 = (target_size - new_w) // 2
    # cv2.resize takes the destination size as (width, height).
    canvas[y0:y0 + new_h, x0:x0 + new_w] = cv2.resize(crop, (new_w, new_h))

    return torch.from_numpy(canvas)
        
@dataclass
class BaseDataModuleConfig:
    """Options consumed by ``BaseDataset`` for geometry and image loading."""

    local_dir: Optional[str] = None
    root_dir: Optional[str] = None       # directory holding the `{split}.json` uid lists read by BaseDataset

    ################################# Geometry part #################################
    load_geometry: bool = True           # whether to load geometry data
    geo_data_type: str = "occupancy"     # occupancy, sdf
    geo_data_path: str = ""              # path to the geometry data
    # for occupancy and sdf data
    n_samples: int = 4096                # number of points in input point cloud
    upsample_ratio: int = 1              # upsample ratio for input point cloud
    sampling_strategy: str = "random"    # sampling strategy for input point cloud
    scale: float = 1.0                   # scale of the input point cloud and target supervision
    load_supervision: bool = True        # whether to load supervision
    supervision_type: str = "occupancy"  # occupancy, sdf, tsdf
    tsdf_threshold: float = 0.05         # threshold for truncating sdf values, used when input is sdf
    n_supervision: int = 10000           # number of points in supervision

    ################################# Image part #################################
    load_image: bool = False             # whether to load images
    image_data_path: str = ""            # path to the image data
    image_type: str = "rgb"              # rgb, normal
    # background color composited behind the image alpha; (-1, -1, -1) requests a random color
    background_color: Tuple[float, float, float] = field(
            default_factory=lambda: (0.5, 0.5, 0.5)
        )
    idx: Optional[List[int]] = None      # candidate view indices; one is drawn at random per sample
    n_views: int = 1                     # number of views
    marign_pix_dis: int = 30             # margin (pixels) around the fitted bounding box


class BaseDataset(Dataset):
    """Map-style dataset yielding surface point clouds, shape supervision and/or an image.

    Items are addressed by uids read from ``{cfg.root_dir}/{split}.json``; which parts
    are loaded is controlled by ``cfg.load_geometry``, ``cfg.load_supervision`` and
    ``cfg.load_image``.
    """

    # Upper bound on consecutive failed loads in ``__getitem__`` before giving up.
    max_getitem_retries: int = 100

    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: BaseDataModuleConfig = cfg
        self.split = split

        with open(f'{cfg.root_dir}/{split}.json', encoding="utf-8") as f:
            self.uids = json.load(f)
        print(f"Loaded {len(self.uids)} {split} uids")

    def __len__(self):
        return len(self.uids)

    def _sdf_npz_path(self, index: int) -> str:
        """Return the .npz path holding sdf-format data for ``self.uids[index]``.

        Shared by the surface and supervision loaders so both read the same file.
        """
        uid = self.uids[index]
        # uids starting with ".." are resolved under geo_data_path; any other uid is
        # used as a path prefix as-is.
        if uid.startswith(".."):
            return f'{self.cfg.geo_data_path}/{uid}.npz'
        return f'{uid}.npz'

    def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        """Load and subsample the input surface point cloud of ``self.uids[index]``.

        Returns:
            ``{"uid": last path component of the uid, "surface": float32 array}``;
            columns 0-2 of ``surface`` are xyz multiplied by ``cfg.scale``, the
            remaining columns are kept as stored (normals for occupancy data).

        Raises:
            NotImplementedError: unknown ``geo_data_type`` or ``sampling_strategy``.
        """
        if self.cfg.geo_data_type == "occupancy":
            # for input point cloud, using Objaverse-MIX data
            with np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}/pointcloud.npz') as pointcloud:
                points = np.asarray(pointcloud['points']) * 2  # range from -1 to 1
                normals = np.asarray(pointcloud['normals'])
            surface = np.concatenate([points, normals], axis=1)
        elif self.cfg.geo_data_type == "sdf":
            # for sdf data with our own format
            with np.load(self._sdf_npz_path(index)) as data:
                surface = data["surface"]
        else:
            raise NotImplementedError(f"Data type {self.cfg.geo_data_type} not implemented")

        if self.cfg.sampling_strategy == "random":
            # random subset without replacement; raises if the cloud has too few points
            rng = np.random.default_rng()
            ind = rng.choice(surface.shape[0], self.cfg.upsample_ratio * self.cfg.n_samples, replace=False)
            surface = surface[ind]
        elif self.cfg.sampling_strategy == "fps":
            import fpsample
            # NOTE: fps draws n_samples points; upsample_ratio is only applied by "random".
            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(surface[:, :3], self.cfg.n_samples, h=5)
            surface = surface[kdline_fps_samples_idx]
        else:
            raise NotImplementedError(f"sampling strategy {self.cfg.sampling_strategy} not implemented")

        # rescale xyz to the target scale (fancy indexing above already copied the data)
        surface[:, :3] = surface[:, :3] * self.cfg.scale
        return {
            "uid": self.uids[index].split('/')[-1],
            "surface": surface.astype(np.float32),
        }

    def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        """Load ``cfg.n_supervision`` query points with their supervision targets.

        The data layout follows ``cfg.geo_data_type`` (the same setting used for the
        input surface). Returns ``"rand_points"`` (float32, scaled by ``cfg.scale``)
        plus ``"occupancies"`` or ``"sdf"`` depending on ``cfg.supervision_type``.

        Raises:
            NotImplementedError: unknown data or supervision type.
            ValueError: non-occupancy supervision requested for occupancy data.
        """
        geo_data_type = self.cfg.geo_data_type
        if geo_data_type == "occupancy":
            with np.load(f'{self.cfg.geo_data_path}/{self.uids[index]}/points.npz') as points:
                rand_points = np.asarray(points['points']) * 2  # range from -1.1 to 1.1
                # occupancies are stored bit-packed; unpacking may pad the tail to a multiple of 8
                occupancies = np.unpackbits(np.asarray(points['occupancies']))
        elif geo_data_type == "sdf":
            with np.load(self._sdf_npz_path(index)) as data:
                rand_points = data['rand_points']
                sdfs = data['sdfs']
        else:
            raise NotImplementedError(f"Data type {geo_data_type} not implemented")

        rng = np.random.default_rng()
        ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False)
        ret = {"rand_points": (rand_points[ind] * self.cfg.scale).astype(np.float32)}

        if geo_data_type == "occupancy":
            if self.cfg.supervision_type != "occupancy":
                raise ValueError("Only occupancy supervision is supported for occupancy data")
            ret["occupancies"] = occupancies[ind].astype(np.float32)
        else:
            sdf_values = sdfs[ind].flatten()
            if self.cfg.supervision_type == "sdf":
                ret["sdf"] = sdf_values.astype(np.float32)
            elif self.cfg.supervision_type == "occupancy":
                # NOTE(review): sdf < 1e-3 maps to 0 and everything else to 1; confirm this
                # matches the sign convention of the stored sdfs and of the occupancy data.
                ret["occupancies"] = np.where(sdf_values < 1e-3, 0, 1).astype(np.float32)
            elif self.cfg.supervision_type == "tsdf":
                threshold = self.cfg.tsdf_threshold
                ret["sdf"] = sdf_values.astype(np.float32).clip(-threshold, threshold) / threshold
            else:
                raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented")

        return ret

    def _load_image(self, index: int) -> Dict[str, Any]:
        """Load one randomly chosen view of ``self.uids[index]``.

        Returns ``"sel_image_idx"`` (view index drawn from ``cfg.idx``), ``"image"``
        (RGB alpha-composited over the background color, optionally re-framed by
        ``fit_bounding_box``) and ``"mask"`` (the PNG alpha channel, (H, W, 1)).

        Raises:
            ValueError: ``cfg.n_views != 1`` or ``cfg.idx`` is empty/None.
            NotImplementedError: unknown ``cfg.image_type``.
        """
        def _load_single_image(img_path, background_color, marign_pix_dis=None):
            # RGBA scaled to [0, 1]
            img = torch.from_numpy(
                np.asarray(Image.fromarray(imageio.v2.imread(img_path)).convert("RGBA")) / 255.0
            ).float()
            mask: Float[Tensor, "H W 1"] = img[:, :, -1:]
            image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + background_color[None, None, :] * (1 - mask)
            if marign_pix_dis is not None:
                # NOTE(review): the mask is returned un-cropped, so it no longer aligns with the
                # re-framed image; confirm downstream consumers expect the original mask.
                image = fit_bounding_box(image, mask, marign_pix_dis, background_color)
            return image, mask

        if list(self.cfg.background_color) == [-1, -1, -1]:
            # random background on the same [0, 1] scale as the loaded image
            background_color = torch.randint(0, 256, (3,)) / 255.0
        else:
            background_color = torch.as_tensor(self.cfg.background_color)

        image_type = self.cfg.image_type
        if image_type not in ("rgb", "normal"):
            raise NotImplementedError(f"Image type {image_type} not implemented")
        if self.cfg.n_views != 1:
            raise ValueError("Only single view is supported for single image")
        if not self.cfg.idx:
            raise ValueError("cfg.idx must list at least one view index to sample from")

        sel_idx = random.choice(self.cfg.idx)
        # images live under <image_data_path>/<last two uid components>/<view:04d>_<type>.png
        img_path = (
            f'{self.cfg.image_data_path}/'
            + "/".join(self.uids[index].split('/')[-2:])
            + f"/{sel_idx:04d}_{image_type}.png"
        )
        ret = {"sel_image_idx": sel_idx}
        ret["image"], ret["mask"] = _load_single_image(img_path, background_color, self.cfg.marign_pix_dis)
        return ret

    def _get_data(self, index):
        """Assemble the sample dict for ``index`` according to the ``load_*`` flags."""
        ret = {"uid": self.uids[index]}
        if self.cfg.load_geometry:
            if self.cfg.geo_data_type not in ("occupancy", "sdf"):
                raise NotImplementedError(f"Geo data type {self.cfg.geo_data_type} not implemented")
            # the shape loader's "uid" (last path component) replaces the full uid above
            ret.update(self._load_shape_from_occupancy_or_sdf(index))
            if self.cfg.load_supervision:
                ret.update(self._load_shape_supervision_occupancy_or_sdf(index))
        if self.cfg.load_image:
            ret.update(self._load_image(index))
        return ret

    def __getitem__(self, index):
        """Return sample ``index``; on failure log it and retry with a random index.

        Raises:
            RuntimeError: after ``max_getitem_retries`` consecutive failures, chained
                to the last underlying error (instead of recursing to the stack limit).
        """
        last_error = None
        for _ in range(self.max_getitem_retries):
            try:
                return self._get_data(index)
            except Exception as e:
                print(f"Error in {self.uids[index]}: {e}")
                last_error = e
                index = np.random.randint(len(self))
        raise RuntimeError(
            f"Failed to load a valid sample after {self.max_getitem_retries} attempts"
        ) from last_error

    def collate(self, batch):
        """Collate a list of sample dicts with PyTorch's default collate function."""
        return torch.utils.data.default_collate(batch)