import gc from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd import tinycudann as tcnn def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): B = None for arg in args: if isinstance(arg, torch.Tensor): B = arg.shape[0] break out = defaultdict(list) out_type = None for i in range(0, B, chunk_size): out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) if out_chunk is None: continue out_type = type(out_chunk) if isinstance(out_chunk, torch.Tensor): out_chunk = {0: out_chunk} elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): chunk_length = len(out_chunk) out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} elif isinstance(out_chunk, dict): pass else: print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') exit(1) for k, v in out_chunk.items(): v = v if torch.is_grad_enabled() else v.detach() v = v.cpu() if move_to_cpu else v out[k].append(v) if out_type is None: return out = {k: torch.cat(v, dim=0) for k, v in out.items()} if out_type is torch.Tensor: return out[0] elif out_type in [tuple, list]: return out_type([out[i] for i in range(chunk_length)]) elif out_type is dict: return out class _TruncExp(Function): # pylint: disable=abstract-method # Implementation from torch-ngp: # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, x): # pylint: disable=arguments-differ ctx.save_for_backward(x) return torch.exp(x) @staticmethod @custom_bwd def backward(ctx, g): # pylint: disable=arguments-differ x = ctx.saved_tensors[0] return g * torch.exp(torch.clamp(x, max=15)) trunc_exp = _TruncExp.apply def get_activation(name): if name is None: return lambda x: x name = name.lower() if name == 'none': return lambda x: x elif name.startswith('scale'): scale_factor = float(name[5:]) return lambda x: x.clamp(0., scale_factor) / scale_factor elif name.startswith('clamp'): clamp_max = float(name[5:]) return lambda x: x.clamp(0., clamp_max) elif name.startswith('mul'): mul_factor = float(name[3:]) return lambda x: x * mul_factor elif name == 'lin2srgb': return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) elif name == 'trunc_exp': return trunc_exp elif name.startswith('+') or name.startswith('-'): return lambda x: x + float(name) elif name == 'sigmoid': return lambda x: torch.sigmoid(x) elif name == 'tanh': return lambda x: torch.tanh(x) else: return getattr(F, name) def dot(x, y): return torch.sum(x*y, -1, keepdim=True) def reflect(x, n): return 2 * dot(x, n) * n - x def scale_anything(dat, inp_scale, tgt_scale): if inp_scale is None: inp_scale = [dat.min(), dat.max()] dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] return dat def cleanup(): gc.collect() torch.cuda.empty_cache() tcnn.free_temporary_memory()