Spaces:
Running
on
Zero
Running
on
Zero
"""
NeRF differentiable renderer.

References:
    https://github.com/bmild/nerf
    https://github.com/kwea123/nerf_pl
"""
import torch | |
import torch.autograd.profiler as profiler | |
from dotmap import DotMap | |
class _RenderWrapper(torch.nn.Module): | |
def __init__(self, net, renderer, simple_output): | |
super().__init__() | |
self.net = net | |
self.renderer = renderer | |
self.simple_output = simple_output | |
def forward( | |
self, | |
rays, | |
want_weights=False, | |
want_alphas=False, | |
want_z_samps=False, | |
want_rgb_samps=False, | |
sample_from_dist=None, | |
): | |
if rays.shape[0] == 0: | |
return ( | |
torch.zeros(0, 3, device=rays.device), | |
torch.zeros(0, device=rays.device), | |
) | |
outputs = self.renderer( | |
self.net, | |
rays, | |
want_weights=want_weights and not self.simple_output, | |
want_alphas=want_alphas and not self.simple_output, | |
want_z_samps=want_z_samps and not self.simple_output, | |
want_rgb_samps=want_rgb_samps and not self.simple_output, | |
sample_from_dist=sample_from_dist, | |
) | |
if self.simple_output: | |
if self.renderer.using_fine: | |
rgb = outputs.fine.rgb | |
depth = outputs.fine.depth | |
else: | |
rgb = outputs.coarse.rgb | |
depth = outputs.coarse.depth | |
return rgb, depth | |
else: | |
# Make DotMap to dict to support DataParallel | |
return outputs.toDict() | |
class NeRFRenderer(torch.nn.Module):
    """
    NeRF differentiable renderer
    :param n_coarse number of coarse (binned uniform) samples
    :param n_fine number of fine (importance) samples
    :param n_fine_depth number of expected depth samples
    :param noise_std std of Gaussian noise added to sigma during training
        (0 disables it)
    :param depth_std noise for depth samples
    :param eval_batch_size ray batch size for evaluation
    :param white_bkgd if true, background color is white; else black
    :param lindisp if to use samples linear in disparity instead of distance
    :param sched ray sampling schedule. list containing 3 lists of equal length.
    sched[0] is list of iteration numbers,
    sched[1] is list of coarse sample numbers,
    sched[2] is list of fine sample numbers
    :param hard_alpha_cap if true, force the alpha of the last sample on every
        ray to 1
    :param render_mode one of "volumetric", "surface", "neus".
        "surface": the model is queried for density only, and colors are then
        sampled at the expected-depth point via ``model.sample_colors``.
        "neus": after depth (and dino features) are computed, the compositing
        weights used for rgb/extras are replaced by a normalized Gaussian
        centered on the expected sample index.
    :param surface_sigmoid_scale scale applied to the index distance inside the
        "neus" Gaussian
    :param render_flow forwarded to the model as its ``render_flow`` argument
    :param normalize_dino stored on the instance; not read by this class
    """

    def __init__(
        self,
        n_coarse=128,
        n_fine=0,
        n_fine_depth=0,
        noise_std=0.0,
        depth_std=0.01,
        eval_batch_size=100000,
        white_bkgd=False,
        lindisp=False,
        sched=None,  # ray sampling schedule for coarse and fine rays
        hard_alpha_cap=False,
        render_mode="volumetric",
        surface_sigmoid_scale=.1,
        render_flow=False,
        normalize_dino=False,
    ):
        super().__init__()
        self.n_coarse, self.n_fine = n_coarse, n_fine
        self.n_fine_depth = n_fine_depth

        self.noise_std = noise_std
        self.depth_std = depth_std

        self.eval_batch_size = eval_batch_size
        self.white_bkgd = white_bkgd
        self.lindisp = lindisp
        if lindisp:
            print("Using linear displacement rays")
        self.using_fine = n_fine > 0
        self.sched = sched
        if sched is not None and len(sched) == 0:
            self.sched = None
        # Persistent buffers so the schedule position survives checkpointing.
        self.register_buffer(
            "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.register_buffer(
            "last_sched", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.hard_alpha_cap = hard_alpha_cap

        assert render_mode in ("volumetric", "surface", "neus")
        self.render_mode = render_mode
        self.only_surface_color = (self.render_mode == "surface")
        self.surface_sigmoid_scale = surface_sigmoid_scale
        self.render_flow = render_flow
        # NOTE(review): not read anywhere in this class -- confirm whether
        # external code relies on it.
        self.normalize_dino = normalize_dino

    def sample_coarse(self, rays):
        """
        Stratified sampling. Note this is different from original NeRF slightly.
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :return (B, Kc)
        """
        device = rays.device
        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)

        step = 1.0 / self.n_coarse
        B = rays.shape[0]
        z_steps = torch.linspace(0, 1 - step, self.n_coarse, device=device)  # (Kc)
        z_steps = z_steps.unsqueeze(0).repeat(B, 1)  # (B, Kc)
        # Jitter each sample uniformly within its own bin of width `step`.
        z_steps += torch.rand_like(z_steps) * step
        if not self.lindisp:  # Use linear sampling in depth space
            return near * (1 - z_steps) + far * z_steps  # (B, Kc)
        # Use linear sampling in disparity space
        return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kc)

    def sample_coarse_from_dist(self, rays, weights, z_samp):
        """
        Draw n_coarse samples per ray from the piecewise-constant distribution
        defined by per-sample ``weights`` (e.g. from a previous pass).
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param weights (B, N) non-negative weights of the given samples
        :param z_samp (B, N) depths of the given samples; each is treated as the
            center of a histogram bin
        :return (B, Kc) new sample depths (unsorted)
        """
        device = rays.device
        B = rays.shape[0]
        num_bins = weights.shape[-1]
        num_samples = self.n_coarse

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, N)
        cdf = torch.cumsum(pdf, -1)  # (B, N)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, N+1)

        u = torch.rand(B, num_samples, dtype=torch.float32, device=device)  # (B, Kc)
        interval_ids = torch.searchsorted(cdf, u, right=True) - 1  # (B, Kc)
        # Valid bin ids are 0..num_bins-1 (id + 1 indexes the right border
        # below). The bound is the number of bins, not the number of samples.
        interval_ids = torch.clamp(interval_ids, 0, num_bins - 1)
        interval_interp = torch.rand_like(interval_ids, dtype=torch.float32)

        # z_samps describe the centers of the respective histogram bins.
        # Therefore, we have to extend them to the left and right
        if self.lindisp:
            z_samp = 1 / z_samp

        centers = 0.5 * (z_samp[:, 1:] + z_samp[:, :-1])
        interval_borders = torch.cat((z_samp[:, :1], centers, z_samp[:, -1:]), dim=-1)  # (B, N+1)

        left_border = torch.gather(interval_borders, dim=-1, index=interval_ids)
        right_border = torch.gather(interval_borders, dim=-1, index=interval_ids + 1)

        z_samp_new = (
            left_border * (1 - interval_interp) + right_border * interval_interp
        )

        if self.lindisp:
            z_samp_new = 1 / z_samp_new

        assert not torch.any(torch.isnan(z_samp_new))

        return z_samp_new

    def sample_fine(self, rays, weights):
        """
        Weighted stratified (importance) sample
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param weights (B, Kc)
        :return (B, Kf-Kfd)
        """
        device = rays.device
        B = rays.shape[0]

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cumsum(pdf, -1)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

        u = torch.rand(
            B, self.n_fine - self.n_fine_depth, dtype=torch.float32, device=device
        )  # (B, Kf-Kfd)
        inds = torch.searchsorted(cdf, u, right=True).float() - 1.0  # (B, Kf-Kfd)
        # Round-off can leave cdf[:, -1] slightly below u, which yields index Kc;
        # clamp to the last coarse bin so samples stay inside [near, far].
        inds = torch.clamp(inds, 0.0, float(self.n_coarse - 1))

        z_steps = (inds + torch.rand_like(inds)) / self.n_coarse  # (B, Kf-Kfd)

        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)
        if not self.lindisp:  # Use linear sampling in depth space
            z_samp = near * (1 - z_steps) + far * z_steps
        else:  # Use linear sampling in disparity space
            z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)

        assert not torch.any(torch.isnan(z_samp))

        return z_samp

    def sample_fine_depth(self, rays, depth):
        """
        Sample around specified depth
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param depth (B)
        :return (B, Kfd)
        """
        z_samp = depth.unsqueeze(1).repeat((1, self.n_fine_depth))
        z_samp += torch.randn_like(z_samp) * self.depth_std
        # Clamp does not support tensor bounds
        z_samp = torch.max(torch.min(z_samp, rays[:, 7:8]), rays[:, 6:7])

        assert not torch.any(torch.isnan(z_samp))

        return z_samp

    def composite(self, model, rays, z_samp, coarse=True, sb=0):
        """
        Render RGB and depth for each ray using NeRF alpha-compositing formula,
        given sampled positions along each ray (see sample_*)
        :param model called as model(points, coarse=..., only_density=...,
            ray_info=..., render_flow=...[, viewdirs=...]) and expected to return
            (rgbs, invalid, sigmas, extras, state_dict); extras and state_dict
            may be None
        :param rays ray [origins (3), directions (3), near (1), far (1), extra ...] (B, 8+)
        :param z_samp z positions sampled for each ray (B, K)
        :param coarse whether to evaluate using coarse NeRF
        :param sb super-batch dimension; 0 = disable
        :return tuple (weights (B, K), rgb_final, depth_final (B), alphas (B, K),
            invalid, z_samp, rgbs, ray_info, extras_final, state_dicts)
        :raises RuntimeError if a NaN appears in any rendered quantity
        """
        with profiler.record_function("renderer_composite"):
            B, K = z_samp.shape
            r_dim = rays.shape[-1]

            deltas = z_samp[:, 1:] - z_samp[:, :-1]  # (B, K-1)
            delta_inf = 1e10 * torch.ones_like(deltas[:, :1])  # infty (B, 1)
            deltas = torch.cat([deltas, delta_inf], -1)  # (B, K)

            # (B, K, 3)
            points = rays[:, None, :3] + z_samp.unsqueeze(2) * rays[:, None, 3:6]
            points = points.reshape(-1, 3)  # (B*K, 3)

            # Columns beyond the 8 standard ray columns are per-ray extra info,
            # broadcast to every sample along the ray.
            if r_dim > 8:
                ray_info = rays[:, None, 8:].expand(-1, K, -1)
            else:
                ray_info = None

            use_viewdirs = getattr(model, "use_viewdirs", None)

            rgbs_all, invalid_all, sigmas_all, extras_all, state_dicts_all = [], [], [], [], []

            if sb > 0:
                points = points.reshape(sb, -1, 3)  # (SB, B'*K, 3) B' is real ray batch size
                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, ray_info.shape[-1])
                eval_batch_dim = 1
                # Split the evaluation budget across the super-batch.
                eval_batch_size = (self.eval_batch_size - 1) // sb + 1
            else:
                eval_batch_size = self.eval_batch_size
                eval_batch_dim = 0

            split_points = torch.split(points, eval_batch_size, dim=eval_batch_dim)
            if ray_info is not None:
                split_ray_infos = torch.split(ray_info, eval_batch_size, dim=eval_batch_dim)
            else:
                split_ray_infos = [None for _ in split_points]

            if use_viewdirs:
                viewdirs = rays[:, None, 3:6].expand(-1, K, -1)
                if sb > 0:
                    viewdirs = viewdirs.reshape(sb, -1, 3)  # (SB, B'*K, 3)
                else:
                    viewdirs = viewdirs.reshape(-1, 3)  # (B*K, 3)
                split_viewdirs = torch.split(viewdirs, eval_batch_size, dim=eval_batch_dim)

            for i, (pnts, infos) in enumerate(zip(split_points, split_ray_infos)):
                # Each chunk gets its own slice of ray_info, aligned with pnts.
                model_kwargs = dict(
                    coarse=coarse,
                    only_density=self.only_surface_color,
                    ray_info=infos,
                    render_flow=self.render_flow,
                )
                if use_viewdirs:
                    model_kwargs["viewdirs"] = split_viewdirs[i]
                rgbs, invalid, sigmas, extras, state_dict = model(pnts, **model_kwargs)
                rgbs_all.append(rgbs)
                invalid_all.append(invalid)
                sigmas_all.append(sigmas)
                extras_all.append(extras)
                if state_dict is not None:
                    state_dicts_all.append(state_dict)
            points, viewdirs = None, None  # drop references to large intermediates

            # (B*K, 4) OR (SB, B'*K, 4)
            if not self.only_surface_color:
                rgbs = torch.cat(rgbs_all, dim=eval_batch_dim)
            else:
                rgbs = None
            invalid = torch.cat(invalid_all, dim=eval_batch_dim)
            sigmas = torch.cat(sigmas_all, dim=eval_batch_dim)
            if extras_all[0] is not None:
                extras = torch.cat(extras_all, dim=eval_batch_dim)
            else:
                extras = None

            deltas = deltas.float()
            sigmas = sigmas.float()

            if state_dicts_all:
                state_dicts = {
                    key: torch.cat(
                        [chunk[key] for chunk in state_dicts_all],
                        dim=eval_batch_dim,
                    )
                    for key in state_dicts_all[0].keys()
                }
            else:
                state_dicts = None

            if rgbs is not None:
                rgbs = rgbs.reshape(B, K, -1)  # (B, K, 4 or 5)
            invalid = invalid.reshape(B, K, -1)
            sigmas = sigmas.reshape(B, K)
            if extras is not None:
                extras = extras.reshape(B, K, -1)
            if state_dicts is not None:
                # Trailing dims begin right after the concatenated ray*sample axis.
                state_dicts = {
                    key: value.reshape(B, K, *value.shape[eval_batch_dim + 1:])
                    for key, value in state_dicts.items()
                }  # BxKx... (BxKxn_viewsx...)

            if self.training and self.noise_std > 0.0:
                sigmas = sigmas + torch.randn_like(sigmas) * self.noise_std

            alphas = 1 - torch.exp(
                -deltas.abs() * torch.relu(sigmas)
            )  # (B, K) (delta should be positive anyways)

            if self.hard_alpha_cap:
                alphas[:, -1] = 1

            deltas, sigmas = None, None
            alphas_shifted = torch.cat(
                [torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1
            )  # (B, K+1) = [1, 1-a1, 1-a2, ...]
            T = torch.cumprod(alphas_shifted, -1)  # (B, K+1) accumulated transmittance
            weights = alphas * T[:, :-1]  # (B, K)
            alphas_shifted = None

            depth_final = torch.sum(weights * z_samp, -1)  # (B)

            # The model may return no state dict, or one without dino features.
            if state_dicts is not None and "dino_features" in state_dicts:
                state_dicts["dino_features"] = torch.sum(
                    state_dicts["dino_features"] * weights.unsqueeze(-1), -2
                )

            if self.render_mode == "neus":
                # Replace the weights by a normalized Gaussian over sample indices,
                # centered on the weight-averaged (expected) sample index.
                indices = torch.arange(0, weights.shape[-1], device=weights.device, dtype=weights.dtype).unsqueeze(0)
                surface_index = torch.sum(weights * indices, dim=-1, keepdim=True)
                dist_from_surf = surface_index - indices
                weights = torch.exp(-.5 * (dist_from_surf * self.surface_sigmoid_scale) ** 2)
                weights = weights / torch.sum(weights, dim=-1, keepdim=True)

            if not self.only_surface_color:
                rgb_final = torch.sum(weights.unsqueeze(-1) * rgbs, -2)  # (B, 3)
            else:
                # Query colors once per ray at the expected-depth point.
                surface_points = rays[:, None, :3] + depth_final[:, None, None] * rays[:, None, 3:6]
                surface_points = surface_points.reshape(sb, -1, 3)

                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, K, ray_info.shape[-1])[:, :, 0, :]

                rgb_final, invalid_colors = model.sample_colors(surface_points, ray_info=ray_info, render_flow=self.render_flow)

                rgb_final = rgb_final.permute(0, 2, 1, 3).reshape(B, -1)
                invalid_colors = invalid_colors.permute(0, 2, 1, 3).reshape(B, 1, -1)

                invalid = ((invalid > .5) | invalid_colors).float()

            if self.white_bkgd:
                # White background
                pix_alpha = weights.sum(dim=1)  # (B), pixel alpha
                rgb_final = rgb_final + 1 - pix_alpha.unsqueeze(-1)  # (B, 3)

            if extras is not None:
                extras_final = torch.sum(weights.unsqueeze(-1) * extras, -2)  # (B, extras)
            else:
                extras_final = None

            for name, x in [("weights", weights), ("rgb_final", rgb_final), ("depth_final", depth_final), ("alphas", alphas), ("invalid", invalid), ("z_samp", z_samp)]:
                if torch.any(torch.isnan(x)):
                    # Raise instead of exit() so the caller decides how to react.
                    raise RuntimeError(f"Detected NaN in {name} ({x.dtype})")

            if ray_info is not None:
                ray_info = rays[:, None, 8:]  # (B, 1, r_dim - 8)

            return (
                weights,
                rgb_final,
                depth_final,
                alphas,
                invalid,
                z_samp,
                rgbs,
                ray_info,
                extras_final,
                state_dicts,
            )

    def forward(
        self,
        model,
        rays,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
        sample_from_dist=None,
    ):
        """
        :model nerf model, should return (SB, B, (r, g, b, sigma))
        when called with (SB, B, (x, y, z)), for multi-object:
        SB = 'super-batch' = size of object batch,
        B = size of per-object ray batch.
        Should also support 'coarse' boolean argument for coarse NeRF.
        :param rays ray spec [origins (3), directions (3), near (1), far (1)] (SB, B, 8)
        :param want_weights if true, returns compositing weights (SB, B, K)
        :param sample_from_dist optional (weights, z_samp) pair; if given, coarse
            samples are drawn from it instead of stratified sampling
        :return render dict
        """
        with profiler.record_function("renderer_forward"):
            if self.sched is not None and self.last_sched.item() > 0:
                self.n_coarse = self.sched[1][self.last_sched.item() - 1]
                self.n_fine = self.sched[2][self.last_sched.item() - 1]

            assert len(rays.shape) == 3
            superbatch_size = rays.shape[0]
            r_dim = rays.shape[-1]

            rays = rays.reshape(-1, r_dim)  # (SB * B, 8)

            if sample_from_dist is None:
                z_coarse = self.sample_coarse(rays)  # (B, Kc)
            else:
                prop_weights, prop_z_samp = sample_from_dist
                n_samples = prop_weights.shape[-1]
                prop_weights = prop_weights.reshape(-1, n_samples)
                prop_z_samp = prop_z_samp.reshape(-1, n_samples)
                z_coarse = self.sample_coarse_from_dist(rays, prop_weights, prop_z_samp)
                z_coarse, _ = torch.sort(z_coarse, dim=-1)

            coarse_composite = self.composite(
                model,
                rays,
                z_coarse,
                coarse=True,
                sb=superbatch_size,
            )

            outputs = DotMap(
                coarse=self._format_outputs(
                    coarse_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                ),
            )
            outputs.state_dict = coarse_composite[-1]

            if self.using_fine:
                all_samps = [z_coarse]
                if self.n_fine - self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine(rays, coarse_composite[0].detach())
                    )  # (B, Kf - Kfd)
                if self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine_depth(rays, coarse_composite[2])
                    )  # (B, Kfd)
                z_combine = torch.cat(all_samps, dim=-1)  # (B, Kc + Kf)
                z_combine_sorted, _ = torch.sort(z_combine, dim=-1)
                fine_composite = self.composite(
                    model,
                    rays,
                    z_combine_sorted,
                    coarse=False,
                    sb=superbatch_size,
                )
                outputs.fine = self._format_outputs(
                    fine_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                )

            return outputs

    def _format_outputs(
        self,
        rendered_outputs,
        superbatch_size,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
    ):
        """
        Reshape one composite() result into a DotMap, restoring the leading
        super-batch dimension. Optional entries are only included when requested
        or present.
        """
        (
            weights,
            rgb_final,
            depth,
            alphas,
            invalid,
            z_samps,
            rgb_samps,
            ray_info,
            extras,
            state_dict,
        ) = rendered_outputs

        n_smps = weights.shape[-1]
        out_d_rgb = rgb_final.shape[-1]
        out_d_i = invalid.shape[-1]

        if superbatch_size > 0:
            rgb_final = rgb_final.reshape(superbatch_size, -1, out_d_rgb)
            depth = depth.reshape(superbatch_size, -1)
            invalid = invalid.reshape(superbatch_size, -1, n_smps, out_d_i)
        ret_dict = DotMap(rgb=rgb_final, depth=depth, invalid=invalid)
        if ray_info is not None:
            ri_shape = ray_info.shape[-1]
            ray_info = ray_info.reshape(superbatch_size, -1, ri_shape)
            ret_dict.ray_info = ray_info
        if extras is not None:
            extras_shape = extras.shape[-1]
            extras = extras.reshape(superbatch_size, -1, extras_shape)
            ret_dict.extras = extras
        if want_weights:
            weights = weights.reshape(superbatch_size, -1, n_smps)
            ret_dict.weights = weights
        if want_alphas:
            alphas = alphas.reshape(superbatch_size, -1, n_smps)
            ret_dict.alphas = alphas
        if want_z_samps:
            z_samps = z_samps.reshape(superbatch_size, -1, n_smps)
            ret_dict.z_samps = z_samps
        if want_rgb_samps:
            rgb_samps = rgb_samps.reshape(superbatch_size, -1, n_smps, out_d_rgb)
            ret_dict.rgb_samps = rgb_samps

        # composite() yields state_dict=None when the model returned none.
        if state_dict is not None:
            if "dino_features" in state_dict:
                dino_features = state_dict["dino_features"]
                ret_dict.dino_features = dino_features.reshape(
                    superbatch_size, -1, dino_features.shape[-1]
                )
            if "invalid_features" in state_dict:
                invalid_features = state_dict["invalid_features"].reshape(superbatch_size, -1, n_smps, out_d_i)
                ret_dict.invalid_features = invalid_features

        return ret_dict

    def sched_step(self, steps=1):
        """
        Called each training iteration to update sample numbers
        according to schedule
        """
        if self.sched is None:
            return
        self.iter_idx += steps
        while (
            self.last_sched.item() < len(self.sched[0])
            and self.iter_idx.item() >= self.sched[0][self.last_sched.item()]
        ):
            self.n_coarse = self.sched[1][self.last_sched.item()]
            self.n_fine = self.sched[2][self.last_sched.item()]
            print(
                "INFO: NeRF sampling resolution changed on schedule ==> c",
                self.n_coarse,
                "f",
                self.n_fine,
            )
            self.last_sched += 1

    @classmethod
    def from_conf(cls, conf, white_bkgd=False, eval_batch_size=100000):
        """
        Alternate constructor reading options from a mapping-like ``conf``
        (anything with ``.get(key, default)``).
        NOTE: the defaults here differ from __init__ for lindisp (True) and
        surface_sigmoid_scale (1).
        """
        return cls(
            conf.get("n_coarse", 128),
            conf.get("n_fine", 0),
            n_fine_depth=conf.get("n_fine_depth", 0),
            noise_std=conf.get("noise_std", 0.0),
            depth_std=conf.get("depth_std", 0.01),
            white_bkgd=conf.get("white_bkgd", white_bkgd),
            lindisp=conf.get("lindisp", True),
            eval_batch_size=conf.get("eval_batch_size", eval_batch_size),
            sched=conf.get("sched", None),
            hard_alpha_cap=conf.get("hard_alpha_cap", False),
            render_mode=conf.get("render_mode", "volumetric"),
            surface_sigmoid_scale=conf.get("surface_sigmoid_scale", 1),
            render_flow=conf.get("render_flow", False),
            normalize_dino=conf.get("normalize_dino", False),
        )

    def bind_parallel(self, net, gpus=None, simple_output=False):
        """
        Returns a wrapper module compatible with DataParallel.
        Specifically, it renders rays with this renderer
        but always using the given network instance.
        Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
        :param net A PixelNeRF network
        :param gpus list of GPU ids to parallelize to. If length is 1,
        does not parallelize
        :param simple_output only returns rendered (rgb, depth) instead of the
        full render output map. Saves data transfer cost.
        :return torch module
        """
        wrapped = _RenderWrapper(net, self, simple_output=simple_output)
        if gpus is not None and len(gpus) > 1:
            print("Using multi-GPU", gpus)
            # dim=1: rays are (SB, B, 8), so scatter along the per-object ray axis.
            wrapped = torch.nn.DataParallel(wrapped, gpus, dim=1)
        return wrapped