# -*- coding: utf-8 -*-

import os
import time
import numpy as np
import warnings
import random
from omegaconf.listconfig import ListConfig
from webdataset import pipelinefilter
import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode
from torchvision.transforms.transforms import _interpolation_modes_from_int
from typing import Optional, Sequence

from michelangelo.utils import instantiate_from_config


def _uid_buffer_pick(buf_dict, rng):
    """Remove and return one random sample from ``buf_dict``.

    A uid is chosen uniformly among the buffered uids, then a sample is chosen
    uniformly within that uid's list, so every uid has the same chance of being
    picked regardless of how many samples it currently holds. Empty uid lists
    are dropped from the dict.
    """
    uid_keys = list(buf_dict.keys())
    selected_uid = rng.choice(uid_keys)
    buf = buf_dict[selected_uid]

    # O(1) removal: overwrite the picked slot with the last element, then pop.
    k = rng.randint(0, len(buf) - 1)
    sample = buf[k]
    buf[k] = buf[-1]
    buf.pop()

    if len(buf) == 0:
        del buf_dict[selected_uid]

    return sample


def _add_to_buf_dict(buf_dict, sample):
    """Append ``sample`` to the list for its uid and return ``buf_dict``.

    The uid is the part of ``sample["__key__"]`` before the last ``_``
    (presumably keys look like ``<uid>_<sample_id>`` — confirm against the
    dataset writer). A key without ``_`` is used whole as its own uid.
    """
    key = sample["__key__"]
    uid = key.rsplit("_", 1)[0]
    buf_dict.setdefault(uid, []).append(sample)
    return buf_dict


def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle the data in the stream, grouping the buffer by sample uid.

    This uses a buffer of size `bufsize`. Shuffling at startup is less random;
    this is traded off against yielding samples quickly.

    data: iterable of samples (dicts with a ``"__key__"`` entry)
    bufsize: buffer size for shuffling
    initial: number of buffered samples required before the first yield
    rng: either random module or random.Random instance
    handler: accepted for signature compatibility; unused
    returns: iterator
    """
    if rng is None:
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
    # The loop below also pulls extra elements with next(), which needs an
    # iterator; iter() is a no-op for iterators and makes lists/tuples work.
    data = iter(data)
    initial = min(initial, bufsize)
    buf_dict = dict()
    current_samples = 0
    for sample in data:
        _add_to_buf_dict(buf_dict, sample)
        current_samples += 1

        # Fill the buffer twice as fast as we drain it until it is full.
        if current_samples < bufsize:
            try:
                _add_to_buf_dict(buf_dict, next(data))  # skipcq: PYL-R1708
                current_samples += 1
            except StopIteration:
                pass
        if current_samples >= initial:
            current_samples -= 1
            yield _uid_buffer_pick(buf_dict, rng)

    # Drain whatever is left once the input is exhausted.
    while current_samples > 0:
        current_samples -= 1
        yield _uid_buffer_pick(buf_dict, rng)


uid_shuffle = pipelinefilter(_uid_shuffle)


def _choice_indices(rng, population, num):
    """Return ``num`` random indices into ``range(population)``.

    Indices are drawn without replacement whenever ``population >= num``; when
    there are fewer rows than requested they are drawn with replacement rather
    than raising.
    """
    return rng.choice(population, num, replace=population < num)


def _sample_geo_points(rng, sample, num_volume_samples, num_near_samples):
    """Subsample volume and near query points and attach their labels.

    Reads ``vol_points``/``vol_label`` and ``near_points``/``near_label`` from
    ``sample``. Each label array (indexed as 1-D) is appended as a trailing
    column to its points, and the volume rows are stacked above the near rows.
    """
    def _pick(points, labels, num):
        ind = _choice_indices(rng, points.shape[0], num)
        return np.concatenate([points[ind], labels[ind][:, np.newaxis]], axis=1)

    vol_points_labels = _pick(sample["vol_points"], sample["vol_label"], num_volume_samples)
    near_points_labels = _pick(sample["near_points"], sample["near_label"], num_near_samples)

    # concat sampled volume and near points
    return np.concatenate([vol_points_labels, near_points_labels], axis=0)


class RandomSample(object):
    """Randomly subsample the surface and labelled query points of a sample.

    Returns a new dict with only ``"surface"`` and ``"geo_points"``.

    Args:
        num_volume_samples: number of rows drawn from ``sample["vol_points"]``.
        num_near_samples: number of rows drawn from ``sample["near_points"]``.
        num_surface_samples: number of surface rows to draw. ``None`` (default)
            keeps every surface row, in a random order.
    """

    def __init__(self,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024,
                 num_surface_samples: Optional[int] = None):
        super().__init__()

        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples
        self.num_surface_samples = num_surface_samples

    def __call__(self, sample):
        rng = np.random.default_rng()

        # 1. sample surface input. An explicit index array is required:
        # rng.choice(n, replace=False) without a size returns one scalar index
        # and would collapse the surface to a single point.
        total_surface = sample["surface"]
        num_total = total_surface.shape[0]
        if self.num_surface_samples is None:
            ind = rng.permutation(num_total)
        else:
            ind = _choice_indices(rng, num_total, self.num_surface_samples)
        surface = total_surface[ind]

        # 2. sample volume/near geometric points
        geo_points = _sample_geo_points(rng, sample, self.num_volume_samples, self.num_near_samples)

        return {
            "surface": surface,
            "geo_points": geo_points
        }


class SplitRandomSample(object):
    """Like :class:`RandomSample`, but surface subsampling is optional.

    Args:
        use_surface_sample: if False, the surface is passed through unchanged.
        num_surface_samples: surface rows to draw when ``use_surface_sample``.
        num_volume_samples: number of rows drawn from ``sample["vol_points"]``.
        num_near_samples: number of rows drawn from ``sample["near_points"]``.
    """

    def __init__(self,
                 use_surface_sample: bool = False,
                 num_surface_samples: int = 4096,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024):
        super().__init__()

        self.use_surface_sample = use_surface_sample
        self.num_surface_samples = num_surface_samples
        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples

    def __call__(self, sample):
        rng = np.random.default_rng()

        # 1. sample surface input
        surface = sample["surface"]
        if self.use_surface_sample:
            ind = _choice_indices(rng, surface.shape[0], self.num_surface_samples)
            surface = surface[ind]

        # 2. sample volume/near geometric points
        geo_points = _sample_geo_points(rng, sample, self.num_volume_samples, self.num_near_samples)

        return {
            "surface": surface,
            "geo_points": geo_points
        }


class FeatureSelection(object):
    """Keep only the surface columns associated with ``surface_feature_type``."""

    VALID_SURFACE_FEATURE_DIMS = {
        "none": [0, 1, 2],  # xyz
        "watertight_normal": [0, 1, 2, 3, 4, 5],  # xyz, normal
        "normal": [0, 1, 2, 6, 7, 8]
    }

    def __init__(self, surface_feature_type: str):
        self.surface_feature_type = surface_feature_type
        try:
            self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
        except KeyError:
            raise KeyError(
                f"Unknown surface_feature_type {surface_feature_type!r}; "
                f"expected one of {sorted(self.VALID_SURFACE_FEATURE_DIMS)}"
            ) from None

    def __call__(self, sample):
        sample["surface"] = sample["surface"][:, self.surface_dims]
        return sample


class AxisScaleTransform(object):
    """Randomly rescale each xyz axis, renormalize, and optionally jitter.

    Only the first three columns of ``sample["surface"]`` and
    ``sample["geo_points"]`` are modified (written back in place); any
    trailing columns are left untouched. NOTE(review): the in-place
    ``clamp_`` / ``randn_like`` calls assume torch tensors, so this presumably
    runs after :class:`ToTensor` — confirm pipeline order.

    Args:
        interval: ``(min, max)`` range of the per-axis scale factors.
        jitter: whether to add Gaussian noise to the surface xyz.
        jitter_scale: standard deviation of that noise.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        self.interval = interval
        self.min_val = interval[0]
        self.max_val = interval[1]
        self.inter_size = interval[1] - interval[0]

        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, sample):
        surface = sample["surface"][..., 0:3]
        geo_points = sample["geo_points"][..., 0:3]

        # One independent scale factor per axis, shared by surface and queries.
        scaling = torch.rand(1, 3) * self.inter_size + self.min_val
        surface = surface * scaling
        geo_points = geo_points * scaling

        # Renormalize so the scaled surface fits just inside [-1, 1].
        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        surface *= scale
        geo_points *= scale

        if self.jitter:
            surface += self.jitter_scale * torch.randn_like(surface)
            surface.clamp_(min=-1.015, max=1.015)

        sample["surface"][..., 0:3] = surface
        sample["geo_points"][..., 0:3] = geo_points

        return sample


class ToTensor(object):
    """Convert the listed sample entries to ``float32`` torch tensors.

    Keys missing from the sample are skipped. Values are always copied, so the
    resulting tensors never share memory with the inputs.
    """

    def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
        self.tensor_keys = tensor_keys

    def __call__(self, sample):
        for key in self.tensor_keys:
            if key not in sample:
                continue

            value = sample[key]
            if torch.is_tensor(value):
                # Same result as torch.tensor(value) without its copy warning.
                sample[key] = value.detach().clone().to(torch.float32)
            else:
                sample[key] = torch.tensor(value, dtype=torch.float32)

        return sample


class AxisScale(object):
    """Randomly rescale xyz axes of ``surface`` and any extra point tensors.

    ``surface`` and every extra positional tensor get the same per-axis scale
    and the same renormalization factor; only ``surface`` is jittered and
    clamped to ``[-1, 1]``.

    Args:
        interval: ``(min, max)`` range of the per-axis scale factors.
        jitter: whether to add Gaussian noise to ``surface``.
        jitter_scale: standard deviation of that noise.

    Returns:
        ``surface`` alone when called with one argument, otherwise the tuple
        ``(surface, *scaled_args)``.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        self.interval = interval
        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, surface, *args):
        # Honour the configured interval (previously hard-coded to 0.75-1.25,
        # which is what the default still produces).
        low, high = self.interval[0], self.interval[1]
        scaling = torch.rand(1, 3) * (high - low) + low
        surface = surface * scaling
        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        surface *= scale

        args_outputs = [_arg * scaling * scale for _arg in args]

        if self.jitter:
            surface += self.jitter_scale * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)

        if len(args) == 0:
            return surface
        return (surface, *args_outputs)


class RandomResize(torch.nn.Module):
    """Randomly degrade image resolution.

    The image is first resized to ``size`` multiplied by a random ratio drawn
    from ``resize_radio``, using an interpolation picked at random from
    ``allow_resize_interpolations``, and then resized back to ``size`` with
    ``interpolation``.
    """

    def __init__(
        self,
        size,
        resize_radio=(0.5, 1),
        allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
        interpolation=InterpolationMode.BICUBIC,
        max_size=None,
        antialias=None,
    ):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")

        self.size = size
        self.max_size = max_size
        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

        self.resize_radio = resize_radio
        self.allow_resize_interpolations = allow_resize_interpolations

    def random_resize_params(self):
        """Draw the intermediate ``(size, interpolation)`` for one call."""
        ratio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]

        if isinstance(self.size, int):
            size = int(self.size * ratio)
        elif isinstance(self.size, Sequence):
            if len(self.size) == 1:
                # A one-element size keeps torchvision's "match the smaller
                # edge" semantics instead of failing on self.size[1].
                size = [int(self.size[0] * ratio)]
            else:
                size = (int(self.size[0] * ratio), int(self.size[1] * ratio))
        else:
            raise RuntimeError()

        choice = int(torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)).item())
        interpolation = self.allow_resize_interpolations[choice]
        return size, interpolation

    def forward(self, img):
        size, interpolation = self.random_resize_params()
        img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
        img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
        return img

    def __repr__(self) -> str:
        detail = (
            f"(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias}, "
            f"resize_radio={self.resize_radio})"
        )
        return f"{self.__class__.__name__}{detail}"


class Compose(object):
    """Composes several transforms together, chaining their outputs.

    A transform that returns a tuple has its elements forwarded as positional
    arguments to the next transform; any other return value (e.g. a sample
    dict) is forwarded as the single argument. The last transform's return
    value is returned as-is; with no transforms the input ``args`` tuple is
    returned.

    This transform does not support torchscript.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> Compose([
        >>>     FeatureSelection("none"),
        >>>     ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        out = args
        for t in self.transforms:
            out = t(*args)
            # Star-unpacking a lone dict/tensor would pass its keys/rows, so
            # only tuples are spread into positional arguments.
            args = out if isinstance(out, tuple) else (out,)

        return out

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


def identity(*args, **kwargs):
    """Return the single positional argument, or the tuple of all of them."""
    if len(args) == 1:
        return args[0]
    else:
        return args


def build_transforms(cfg):
    """Instantiate every transform config in ``cfg`` and compose them in order.

    Returns :func:`identity` when ``cfg`` is None. Each transform is built with
    ``instantiate_from_config`` and printed as it is created.
    """
    if cfg is None:
        return identity

    transforms = []

    for _transform_name, cfg_instance in cfg.items():
        transform_instance = instantiate_from_config(cfg_instance)
        transforms.append(transform_instance)
        print(f"Build transform: {transform_instance}")

    transforms = Compose(transforms)

    return transforms