Spaces:
Configuration error
Configuration error
import torch | |
from lib.config import cfg | |
from .nerf_net_utils import * | |
from .. import embedder | |
from . import if_clight_renderer | |
class Renderer(if_clight_renderer.Renderer):
    """Renderer that only evaluates sample points projecting inside every view's mask.

    Extends ``if_clight_renderer.Renderer``: before querying the network, each
    sample point is projected into all calibrated views and discarded unless it
    lands on a non-zero pixel of every view's mask. Discarded points get an
    all-zero raw output.
    """

    def __init__(self, net):
        super().__init__(net)

    def prepare_inside_pts(self, pts, batch):
        """Return a boolean mask of sample points that fall inside all view masks.

        Each point is transformed by the view's extrinsics (``batch['RT']``) and
        projected with its intrinsics (``batch['Ks']``). The resulting pixel
        coordinates are rounded and clamped to the image border, then looked up
        in that view's mask (``batch['msks']``). A point is kept only if it hits
        a non-zero mask pixel in every view.

        Args:
            pts: sample points; 4-D tensor, flattened here to
                (n_batch, pts.shape[1] * pts.shape[2], pts.shape[3]).
            batch: dict providing 'Ks' indexed as [:, view] -> (n_batch, 3, 3),
                'RT' indexed as [:, view, :3, :4], and 'msks' indexed as
                [:, view, y, x].

        Returns:
            Bool tensor of shape (n_batch, pts.shape[1] * pts.shape[2]). If
            ``batch`` has no 'Ks' entry (or there are zero views), every point
            is reported as inside.
        """
        n_batch = pts.shape[0]
        # reshape (not view) so non-contiguous inputs are accepted too
        pts = pts.reshape(n_batch, -1, pts.shape[-1])
        # Start from "everything inside" and AND in each view's verdict; this
        # also makes the zero-view case well defined.
        inside = torch.ones(pts.shape[:2], dtype=torch.bool, device=pts.device)
        if 'Ks' not in batch:
            # No camera calibration available, so there is nothing to cull
            # against; keep every point. (Previously this branch dropped into a
            # debugger and then returned an undefined name.)
            return inside

        # Image size used for clamping; the masks are assumed to be at least
        # this large. NOTE(review): confirm msks are stored at cfg.ratio scale.
        H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio)
        # (n_batch, 1) index so that mask lookups gather per batch element.
        batch_idx = torch.arange(n_batch, device=pts.device)[:, None]
        for nv in range(batch['Ks'].size(1)):
            # project pts to image space
            R = batch['RT'][:, nv, :3, :3]
            T = batch['RT'][:, nv, :3, 3]
            pts_cam = torch.matmul(pts, R.transpose(2, 1)) + T[:, None]
            pts_img = torch.matmul(pts_cam, batch['Ks'][:, nv].transpose(2, 1))
            # NOTE(review): points with depth <= 0 are not rejected here; their
            # projection is meaningless (or inf/nan at depth 0).
            pts2d = (pts_img[..., :2] / pts_img[..., 2:]).round().long()
            # Off-image points are snapped to the nearest border pixel.
            x = pts2d[..., 0].clamp(0, W - 1)
            y = pts2d[..., 1].clamp(0, H - 1)
            # remove the points outside the mask
            msk = batch['msks'][:, nv]
            inside &= msk[batch_idx, y, x].bool()
        return inside

    def get_density_color(self, wpts, viewdir, inside, raw_decoder):
        """Query ``raw_decoder`` only for points flagged in ``inside``.

        Args:
            wpts: sample points, (n_batch, n_pixel, n_sample, C).
            viewdir: per-ray view directions, (n_batch, n_pixel, D); broadcast
                to every sample along the ray.
            inside: bool mask of shape (n_batch, n_pixel * n_sample).
            raw_decoder: callable ``(points, viewdirs) -> raw``. It is called with
                the selected points stacked as a single batch of shape (1, M, ·)
                and must return raw values of shape (1, M, 4).

        Returns:
            Tensor (n_batch, n_pixel * n_sample, 4) with the same dtype and
            device as ``wpts``. Entries for points outside the mask are zero.
        """
        n_batch, n_pixel, n_sample = wpts.shape[:3]
        wpts = wpts.reshape(n_batch, n_pixel * n_sample, -1)
        # expand is a view; reshape then materializes one copy per sample
        viewdir = viewdir[:, :, None].expand(-1, -1, n_sample, -1)
        viewdir = viewdir.reshape(n_batch, n_pixel * n_sample, -1)

        full_raw = wpts.new_zeros(n_batch, n_pixel * n_sample, 4)
        if not inside.any():
            # Nothing to evaluate; avoid calling the decoder on empty input.
            return full_raw

        raw = raw_decoder(wpts[inside][None], viewdir[inside][None])
        full_raw[inside] = raw[0]
        return full_raw

    def get_pixel_value(self, ray_o, ray_d, near, far, feature_volume,
                        sp_input, batch):
        """Volume-render the given rays, evaluating only mask-inside samples.

        Samples points along each ray via ``self.get_sampling_points``
        (presumably provided by the parent renderer). It culls them with
        ``prepare_inside_pts`` and evaluates the rest through
        ``self.net.calculate_density_color`` (given ``feature_volume`` and
        ``sp_input``). The raw outputs are then composited with ``raw2outputs``
        using ``cfg.raw_noise_std`` and ``cfg.white_bkgd``.

        Args:
            ray_o, ray_d: ray origins and directions; ``ray_d`` is 3-D with
                components on dim 2.
            near, far: sampling bounds forwarded to ``get_sampling_points``.
            feature_volume, sp_input: forwarded to the network.
            batch: forwarded to ``prepare_inside_pts``.

        Returns:
            dict with 'rgb_map' (n_batch, n_pixel, -1), 'disp_map'
            (n_batch, n_pixel), 'acc_map' (n_batch, n_pixel), 'weights'
            (n_batch, n_pixel, -1) and 'depth_map' (n_batch, n_pixel).
        """
        # sampling points along camera rays
        wpts, z_vals = self.get_sampling_points(ray_o, ray_d, near, far)
        inside = self.prepare_inside_pts(wpts, batch)

        # viewing direction
        viewdir = ray_d / torch.norm(ray_d, dim=2, keepdim=True)

        def raw_decoder(x_point, viewdir_val):
            return self.net.calculate_density_color(
                x_point, viewdir_val, feature_volume, sp_input)

        # compute the color and density
        wpts_raw = self.get_density_color(wpts, viewdir, inside, raw_decoder)

        # volume rendering for wpts
        n_batch, n_pixel, n_sample = wpts.shape[:3]
        raw = wpts_raw.reshape(-1, n_sample, 4)
        z_vals = z_vals.view(-1, n_sample)
        ray_d = ray_d.view(-1, 3)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, ray_d, cfg.raw_noise_std, cfg.white_bkgd)

        return {
            'rgb_map': rgb_map.view(n_batch, n_pixel, -1),
            'disp_map': disp_map.view(n_batch, n_pixel),
            'acc_map': acc_map.view(n_batch, n_pixel),
            'weights': weights.view(n_batch, n_pixel, -1),
            'depth_map': depth_map.view(n_batch, n_pixel),
        }