import logging
import math
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import craftsman
from craftsman.utils.typing import *

# Private module logger used by `randn_tensor`; kept under a leading
# underscore so it is not re-exported by `from ... import *`.
_logger = logging.getLogger(__name__)


def dot(x, y):
    """Return the dot product of `x` and `y` over the last dimension.

    The reduced dimension is kept with size 1 (`keepdim=True`) so the result
    broadcasts against the inputs.
    """
    return torch.sum(x * y, -1, keepdim=True)


def reflect(x, n):
    """Return `2 * dot(x, n) * n - x`, i.e. `x` mirrored about `n`.

    NOTE(review): this is a geometric reflection only if `n` is unit-length
    along the last dimension; the function does not normalize `n`.
    """
    return 2 * dot(x, n) * n - x


ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]


def scale_tensor(
    dat: Num[Tensor, "... D"],
    inp_scale: Optional[ValidScale],
    tgt_scale: Optional[ValidScale],
):
    """Linearly remap `dat` from the `inp_scale` range to the `tgt_scale` range.

    Args:
        dat: Values to remap.
        inp_scale: `(low, high)` range `dat` currently lives in; `None` means
            `(0, 1)`.
        tgt_scale: `(low, high)` range to map to; `None` means `(0, 1)`. When
            given as a tensor, its last dimension must match that of `dat`.

    Returns:
        The remapped values.
    """
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # Normalize to [0, 1] first, then stretch/shift into the target range.
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat


def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call `func` on slices of its tensor arguments and merge the results.

    Every `torch.Tensor` found in `args` / `kwargs` is sliced along dim 0 with
    the same `[i : i + chunk_size]` window; all other arguments are passed
    through unchanged. The batch size is taken from the first tensor argument.

    Args:
        func: Callable returning a tensor, a tuple/list, a dict, or `None`.
            Tuple/list/dict entries may be tensors or `None`.
        chunk_size: Number of leading-dimension entries per call. A value
            `<= 0` disables chunking and calls `func` once.
        *args: Positional arguments forwarded to `func`.
        **kwargs: Keyword arguments forwarded to `func`.

    Returns:
        The per-chunk outputs concatenated along dim 0, in the same container
        kind `func` returned (tensor, tuple, list or dict). An entry that is
        `None` in every chunk stays `None`. Returns `None` if every call
        returned `None`.

    Raises:
        AssertionError: If no tensor argument is present.
        TypeError: If `func` returns an unsupported type, or an entry mixes
            tensors with non-tensors across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    out = defaultdict(list)
    out_type = None
    chunk_length = 0
    # max(1, B) to support B == 0
    for start in range(0, max(1, B), chunk_size):
        stop = start + chunk_size
        out_chunk = func(
            *[
                arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif isinstance(out_chunk, dict):
            pass
        else:
            # Raise instead of print + exit(1): library code must not kill
            # the interpreter, and callers get a proper traceback.
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
        keep_graph = torch.is_grad_enabled()
        for k, v in out_chunk.items():
            # Only tensors can be detached; `None` entries are allowed and
            # must pass through untouched.
            if not keep_graph and isinstance(v, torch.Tensor):
                v = v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # allow None in return value
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    # Use issubclass so subclasses (e.g. nn.Parameter, OrderedDict) are merged
    # instead of silently falling through and returning None.
    if issubclass(out_type, torch.Tensor):
        return out_merged[0]
    if issubclass(out_type, dict):
        return out_merged
    merged_seq = [out_merged[i] for i in range(chunk_length)]
    return tuple(merged_seq) if issubclass(out_type, tuple) else merged_seq


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: Shape of the tensor to create; `shape[0]` is the batch size.
        generator: A single generator, or a list with one generator per batch
            element.
        device: Target device (a `torch.device` or a device string). Defaults
            to CPU.
        dtype: Passed to `torch.randn`.
        layout: Passed to `torch.randn`; defaults to `torch.strided`.

    Returns:
        A tensor of standard-normal samples located on `device`.

    Raises:
        ValueError: If a CUDA generator is given for a non-CUDA `device`.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    # Normalize so that string devices also expose `.type` below.
    device = torch.device(device or "cpu")

    if generator is not None:
        gen_device_type = (
            generator.device.type
            if not isinstance(generator, list)
            else generator[0].device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare device *types*: a torch.device never equals a str, so
            # `device != "mps"` was always true.
            if device.type != "mps":
                _logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # tuple(...) so a list-typed `shape` works as well as a tuple.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(
                shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(
            shape, generator=generator, device=rand_device, dtype=dtype, layout=layout
        ).to(device)

    return latents


def generate_dense_grid_points(
    bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
):
    """Build a dense, regular 3D grid of points spanning a bounding box.

    Each axis is sampled with `2 ** octree_depth + 1` evenly spaced points
    (both box faces included).

    Args:
        bbox_min: Minimum corner; indices 0, 1, 2 are used as x, y, z.
        bbox_max: Maximum corner; indices 0, 1, 2 are used as x, y, z.
        octree_depth: Grid resolution exponent (cells per axis is
            `2 ** octree_depth`).
        indexing: Forwarded to `np.meshgrid` ("ij" or "xy").

    Returns:
        A tuple `(xyz, grid_size, length)` where `xyz` is a float32 array of
        shape `(num_points, 3)`, `grid_size` is the list of points per axis,
        and `length` is `bbox_max - bbox_min`.
    """
    length = bbox_max - bbox_min
    num_cells = np.exp2(octree_depth)
    points_per_axis = int(num_cells) + 1

    x = np.linspace(bbox_min[0], bbox_max[0], points_per_axis, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], points_per_axis, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], points_per_axis, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    xyz = xyz.reshape(-1, 3)
    grid_size = [points_per_axis, points_per_axis, points_per_axis]

    return xyz, grid_size, length