V3D / mesh_recon /models /utils.py
heheyas
init
cfb7702
raw
history blame
No virus
3.76 kB
import gc
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
import tinycudann as tcnn
def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs):
B = None
for arg in args:
if isinstance(arg, torch.Tensor):
B = arg.shape[0]
break
out = defaultdict(list)
out_type = None
for i in range(0, B, chunk_size):
out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs)
if out_chunk is None:
continue
out_type = type(out_chunk)
if isinstance(out_chunk, torch.Tensor):
out_chunk = {0: out_chunk}
elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
chunk_length = len(out_chunk)
out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
elif isinstance(out_chunk, dict):
pass
else:
print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.')
exit(1)
for k, v in out_chunk.items():
v = v if torch.is_grad_enabled() else v.detach()
v = v.cpu() if move_to_cpu else v
out[k].append(v)
if out_type is None:
return
out = {k: torch.cat(v, dim=0) for k, v in out.items()}
if out_type is torch.Tensor:
return out[0]
elif out_type in [tuple, list]:
return out_type([out[i] for i in range(chunk_length)])
elif out_type is dict:
return out
class _TruncExp(Function): # pylint: disable=abstract-method
# Implementation from torch-ngp:
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x): # pylint: disable=arguments-differ
ctx.save_for_backward(x)
return torch.exp(x)
@staticmethod
@custom_bwd
def backward(ctx, g): # pylint: disable=arguments-differ
x = ctx.saved_tensors[0]
return g * torch.exp(torch.clamp(x, max=15))
trunc_exp = _TruncExp.apply
def get_activation(name):
if name is None:
return lambda x: x
name = name.lower()
if name == 'none':
return lambda x: x
elif name.startswith('scale'):
scale_factor = float(name[5:])
return lambda x: x.clamp(0., scale_factor) / scale_factor
elif name.startswith('clamp'):
clamp_max = float(name[5:])
return lambda x: x.clamp(0., clamp_max)
elif name.startswith('mul'):
mul_factor = float(name[3:])
return lambda x: x * mul_factor
elif name == 'lin2srgb':
return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.)
elif name == 'trunc_exp':
return trunc_exp
elif name.startswith('+') or name.startswith('-'):
return lambda x: x + float(name)
elif name == 'sigmoid':
return lambda x: torch.sigmoid(x)
elif name == 'tanh':
return lambda x: torch.tanh(x)
else:
return getattr(F, name)
def dot(x, y):
return torch.sum(x*y, -1, keepdim=True)
def reflect(x, n):
return 2 * dot(x, n) * n - x
def scale_anything(dat, inp_scale, tgt_scale):
if inp_scale is None:
inp_scale = [dat.min(), dat.max()]
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
return dat
def cleanup():
gc.collect()
torch.cuda.empty_cache()
tcnn.free_temporary_memory()