Spaces:
Configuration error
Configuration error
import torch | |
from lib.config import cfg | |
from .nerf_net_utils import * | |
from .. import embedder | |
from . import if_clight_renderer_mmsk | |
class Renderer(if_clight_renderer_mmsk.Renderer):
    """Renderer that flags which sample points project inside the image mask.

    Extends ``if_clight_renderer_mmsk.Renderer`` with ``prepare_inside_pts``.
    That method moves world-space points into SMPL space with the current
    frame's ``Th``/``R``. It then maps them back to world space with the
    reference ("snap") pose ``R0_snap``/``Th0_snap``, projects them with the
    camera ``RT``/``K``, and looks them up in ``batch['msk']``.
    """

    def __init__(self, net):
        super().__init__(net)

    def prepare_inside_pts(self, wpts, batch):
        """Return a boolean mask of the points whose projection lies in ``batch['msk']``.

        Args:
            wpts: world-space points, indexed as a 4-D tensor whose last axis
                holds coordinates (see the ``Th[:, None, None]`` broadcast and
                ``sh[3]`` below). Presumably (batch, rays, samples, 3) -- TODO
                confirm.
            batch: dict that must contain ``'Th'``, ``'R'``, ``'R0_snap'``,
                ``'Th0_snap'``, ``'RT'``, ``'K'`` and ``'msk'``.

        Returns:
            Bool tensor of shape ``(1, P)``, where ``P`` is the number of
            points of the first batch element flattened across axes 1 and 2
            of ``wpts``. Only batch element 0 is evaluated.

        Raises:
            KeyError: if ``batch`` has no ``'R0_snap'`` entry, or any other
                required key is missing.
        """
        if 'R0_snap' not in batch:
            # Without the reference pose the points cannot be projected, so
            # fail loudly instead of continuing with undefined values.
            raise KeyError(
                "prepare_inside_pts requires batch['R0_snap'] and "
                "batch['Th0_snap'] to map points into the reference pose")

        # Transform points from world space to the current frame's SMPL space.
        smpl_th = batch['Th']
        can_pts = torch.matmul(wpts - smpl_th[:, None, None], batch['R'])

        # Flatten the point axes and map SMPL-space points back to world space
        # using the reference ("snap") pose.
        sh = can_pts.shape
        can_pts = can_pts.reshape(sh[0], -1, sh[3])
        snap_r = batch['R0_snap']
        snap_th = batch['Th0_snap']
        pts = torch.matmul(can_pts, snap_r.transpose(2, 1)) + snap_th[:, None]

        # Project world-space points into the image: extrinsics, then K.
        cam_rt = batch['RT']
        cam_r = cam_rt[..., :3]
        cam_t = cam_rt[..., 3]
        pts = torch.matmul(pts, cam_r.transpose(2, 1)) + cam_t[:, None]
        pts = torch.matmul(pts, batch['K'].transpose(2, 1))
        # NOTE(review): points with z == 0 divide by zero and points with
        # z < 0 project to mirrored pixels; both are only clamped below,
        # never rejected.
        pts2d = pts[..., :2] / pts[..., 2:]

        # Snap to integer pixels and clamp into the (ratio-scaled) image.
        H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio)
        pts2d = pts2d.round().long()
        pts2d[..., 0] = torch.clamp(pts2d[..., 0], 0, W - 1)
        pts2d[..., 1] = torch.clamp(pts2d[..., 1], 0, H - 1)

        # Look up the mask at each projected pixel (column = x, row = y),
        # using the first batch element only.
        pts2d = pts2d[0]
        msk = batch['msk'][0]
        inside = msk[pts2d[:, 1], pts2d[:, 0]][None].bool()
        return inside