Spaces:
Build error
Build error
import PIL | |
import torch | |
import numpy as np | |
import gsplat as gs | |
import torch.nn as nn | |
from copy import deepcopy | |
import torch.nn.functional as F | |
from dataclasses import dataclass | |
from ops.utils import ( | |
dpt2xyz, | |
alpha_inpaint_mask, | |
transform_points, | |
numpy_normalize, | |
numpy_quaternion_from_matrix | |
) | |
class Frame(): | |
''' | |
rgb: in shape of H*W*3, in range of 0-1 | |
dpt: in shape of H*W, real depth | |
inpaint: bool mask in shape of H*W for inpainting | |
intrinsic: 3*3 | |
extrinsic: array in shape of 4*4 | |
As a class for: | |
initialize camera | |
accept rendering result | |
accept inpainting result | |
All at 2D-domain | |
''' | |
def __init__(self, | |
H: int = None, | |
W: int = None, | |
rgb: np.array = None, | |
dpt: np.array = None, | |
sky: np.array = None, | |
inpaint: np.array = None, | |
intrinsic: np.array = None, | |
extrinsic: np.array = None, | |
# detailed target | |
ideal_dpt: np.array = None, | |
ideal_nml: np.array = None, | |
prompt: str = None) -> None: | |
self.H = H | |
self.W = W | |
self.rgb = rgb | |
self.dpt = dpt | |
self.sky = sky | |
self.prompt = prompt | |
self.intrinsic = intrinsic | |
self.extrinsic = extrinsic | |
self._rgb_rect() | |
self._extr_rect() | |
# for inpainting | |
self.inpaint = inpaint | |
self.inpaint_wo_edge = inpaint | |
# for supervision | |
self.ideal_dpt = ideal_dpt | |
self.ideal_nml = ideal_nml | |
def _rgb_rect(self): | |
if self.rgb is not None: | |
if isinstance(self.rgb, PIL.PngImagePlugin.PngImageFile): | |
self.rgb = np.array(self.rgb) | |
if isinstance(self.rgb, PIL.JpegImagePlugin.JpegImageFile): | |
self.rgb = np.array(self.rgb) | |
if np.amax(self.rgb) > 1.1: | |
self.rgb = self.rgb / 255 | |
def _extr_rect(self): | |
if self.extrinsic is None: self.extrinsic = np.eye(4) | |
self.inv_extrinsic = np.linalg.inv(self.extrinsic) | |
class Gaussian_Frame(): | |
''' | |
In-frame-frustrum | |
Gaussians from a single RGBD frame | |
As a class for: | |
accept information from initialized/inpainting+geo-estimated frame | |
saving pixelsplat properties including rgb, xyz, scale, rotation, opacity; note here, we made a modification to xyz; | |
we first project depth to xyz | |
then we tune a scale map(initialized to ones) and a shift map(initialized to zeros), they are optimized and add to the original xyz when rendering | |
''' | |
# as pixelsplat guassian | |
rgb: torch.Tensor = None, | |
scale: torch.Tensor = None, | |
opacity: torch.Tensor = None, | |
rotation: torch.Tensor = None, | |
# gaussian center | |
dpt: torch.Tensor = None, | |
xyz: torch.Tensor = None, | |
# as a frame | |
H: int = 480, | |
W: int = 640, | |
def __init__(self, frame: Frame, device = 'cuda'): | |
'''after inpainting''' | |
# de-active functions | |
self.rgbs_deact = torch.logit | |
self.scales_deact = torch.log | |
self.opacity_deact = torch.logit | |
self.device = device | |
# for gaussian initialization | |
self._set_property_from_frame(frame) | |
def _to_3d(self): | |
# inv intrinsic | |
xyz = dpt2xyz(self.dpt,self.intrinsic) | |
inv_extrinsic = np.linalg.inv(self.extrinsic) | |
xyz = transform_points(xyz,inv_extrinsic) | |
return xyz | |
def _paint_filter(self,paint_mask): | |
if np.sum(paint_mask)<3: | |
paint_mask = np.zeros((self.H,self.W)) | |
paint_mask[0:1] = 1 | |
paint_mask = paint_mask>.5 | |
self.rgb = self.rgb[paint_mask] | |
self.xyz = self.xyz[paint_mask] | |
self.scale = self.scale[paint_mask] | |
self.opacity = self.opacity[paint_mask] | |
self.rotation = self.rotation[paint_mask] | |
def _to_cuda(self): | |
self.rgb = torch.from_numpy(self.rgb.astype(np.float32)).to(self.device) | |
self.xyz = torch.from_numpy(self.xyz.astype(np.float32)).to(self.device) | |
self.scale = torch.from_numpy(self.scale.astype(np.float32)).to(self.device) | |
self.opacity = torch.from_numpy(self.opacity.astype(np.float32)).to(self.device) | |
self.rotation = torch.from_numpy(self.rotation.astype(np.float32)).to(self.device) | |
def _fine_init_scale_rotations(self): | |
# from https://arxiv.org/pdf/2406.09394 | |
""" Compute rotation matrices that align z-axis with given normal vectors using matrix operations. """ | |
up_axis = np.array([0,1,0]) | |
nml = self.nml @ self.extrinsic[0:3,0:3] | |
qz = numpy_normalize(nml) | |
qx = np.cross(up_axis,qz) | |
qx = numpy_normalize(qx) | |
qy = np.cross(qz,qx) | |
qy = numpy_normalize(qy) | |
rot = np.concatenate([qx[...,None],qy[...,None],qz[...,None]],axis=-1) | |
self.rotation = numpy_quaternion_from_matrix(rot) | |
# scale | |
safe_nml = deepcopy(self.nml) | |
safe_nml[safe_nml[:,:,-1]<0.2,-1] = .2 | |
normal_xoz = deepcopy(safe_nml) | |
normal_yoz = deepcopy(safe_nml) | |
normal_xoz[...,1] = 0. | |
normal_yoz[...,0] = 0. | |
normal_xoz = numpy_normalize(normal_xoz) | |
normal_yoz = numpy_normalize(normal_yoz) | |
cos_theta_x = np.abs(normal_xoz[...,2]) | |
cos_theta_y = np.abs(normal_yoz[...,2]) | |
scale_basic = self.dpt / self.intrinsic[0,0] / np.sqrt(2) | |
scale_x = scale_basic / cos_theta_x | |
scale_y = scale_basic / cos_theta_y | |
scale_z = (scale_x + scale_y) / 10. | |
self.scale = np.concatenate([scale_x[...,None], | |
scale_y[...,None], | |
scale_z[...,None]],axis=-1) | |
def _coarse_init_scale_rotations(self): | |
# gaussian property -- HW3 scale | |
self.scale = self.dpt / self.intrinsic[0,0] / np.sqrt(2) | |
self.scale = self.scale[:,:,None].repeat(3,-1) | |
# gaussian property -- HW4 rotation | |
self.rotation = np.zeros((self.H,self.W,4)) | |
self.rotation[:,:,0] = 1. | |
def _set_property_from_frame(self,frame: Frame): | |
'''frame here is a complete init/inpainted frame''' | |
# basic frame-level property | |
self.H = frame.H | |
self.W = frame.W | |
self.dpt = frame.dpt | |
self.intrinsic = frame.intrinsic | |
self.extrinsic = frame.extrinsic | |
# gaussian property -- xyz with train-able pixel-aligned scale and shift | |
self.xyz = self._to_3d() | |
# gaussian property -- HW3 rgb | |
self.rgb = frame.rgb | |
# gaussian property -- HW4 rotation HW3 scale | |
self._coarse_init_scale_rotations() | |
# gaussian property -- HW opacity | |
self.opacity = np.ones((self.H,self.W,1)) * 0.8 | |
# to cuda | |
self._paint_filter(frame.inpaint_wo_edge) | |
self._to_cuda() | |
# de-activate | |
self.rgb = self.rgbs_deact(self.rgb) | |
self.scale = self.scales_deact(self.scale) | |
self.opacity = self.opacity_deact(self.opacity) | |
# to torch parameters | |
self.rgb = nn.Parameter(self.rgb,requires_grad=False) | |
self.xyz = nn.Parameter(self.xyz,requires_grad=False) | |
self.scale = nn.Parameter(self.scale,requires_grad=False) | |
self.opacity = nn.Parameter(self.opacity,requires_grad=False) | |
self.rotation = nn.Parameter(self.rotation,requires_grad=False) | |
def _require_grad(self,sign=True): | |
self.rgb = self.rgb.requires_grad_(sign) | |
self.xyz = self.xyz.requires_grad_(sign) | |
self.scale = self.scale.requires_grad_(sign) | |
self.opacity = self.opacity.requires_grad_(sign) | |
self.rotation = self.rotation.requires_grad_(sign) | |
class Gaussian_Scene(): | |
def __init__(self,cfg=None): | |
# frames initialing the frame | |
self.frames = [] | |
self.gaussian_frames: list[Gaussian_Frame] = [] # gaussian frame require training at this optimization | |
# activate fuctions | |
self.rgbs_act = torch.sigmoid | |
self.scales_act = torch.exp | |
self.opacity_act = torch.sigmoid | |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
# for traj generation | |
self.traj_type = 'spiral' | |
if cfg is not None: | |
self.traj_min_percentage = cfg.scene.traj.near_percentage | |
self.traj_max_percentage = cfg.scene.traj.far_percentage | |
self.traj_forward_ratio = cfg.scene.traj.traj_forward_ratio | |
self.traj_backward_ratio = cfg.scene.traj.traj_backward_ratio | |
else: | |
self.traj_min_percentage,self.traj_max_percentage,self.traj_forward_ratio,self.traj_backward_ratio = 5, 50, 0.3, 0.4 | |
# basic operations | |
def _render_RGBD(self,frame,background_color='black'): | |
''' | |
:intinsic: tensor of [fu,fv,cu,cv] 4-dimension | |
:extinsic: tensor 4*4-dimension | |
:out: tensor H*W*3-dimension | |
''' | |
background = None | |
if background_color =='white': | |
background = torch.ones(1,4,device=self.device)*0.1 | |
background[:,-1] = 0. # for depth | |
# aligned untrainable xyz and unaligned trainable xyz | |
# others | |
xyz = torch.cat([gf.xyz.reshape(-1,3) for gf in self.gaussian_frames],dim=0) | |
rgb = torch.cat([gf.rgb.reshape(-1,3) for gf in self.gaussian_frames],dim=0) | |
scale = torch.cat([gf.scale.reshape(-1,3) for gf in self.gaussian_frames],dim=0) | |
opacity = torch.cat([gf.opacity.reshape(-1) for gf in self.gaussian_frames],dim=0) | |
rotation = torch.cat([gf.rotation.reshape(-1,4) for gf in self.gaussian_frames],dim=0) | |
# activate | |
rgb = self.rgbs_act(rgb) | |
scale = self.scales_act(scale) | |
rotation = F.normalize(rotation,dim=1) | |
opacity = self.opacity_act(opacity) | |
# property | |
H,W = frame.H, frame.W | |
intrinsic = torch.from_numpy(frame.intrinsic.astype(np.float32)).to(self.device) | |
extrinsic = torch.from_numpy(frame.extrinsic.astype(np.float32)).to(self.device) | |
# render | |
render_out,render_alpha,_ = gs.rendering.rasterization(means = xyz, | |
scales = scale, | |
quats = rotation, | |
opacities = opacity, | |
colors = rgb, | |
Ks = intrinsic[None], | |
viewmats = extrinsic[None], | |
width = W, | |
height = H, | |
packed = False, | |
near_plane= 0.01, | |
render_mode="RGB+ED", | |
backgrounds=background) # render: 1*H*W*(3+1) | |
render_out = render_out.squeeze() # result: H*W*(3+1) | |
render_rgb = render_out[:,:,0:3] | |
render_dpt = render_out[:,:,-1] | |
return render_rgb, render_dpt, render_alpha | |
def _render_for_inpaint(self,frame): | |
# first render | |
render_rgb, render_dpt, render_alpha = self._render_RGBD(frame) | |
render_msk = alpha_inpaint_mask(render_alpha) | |
# to numpy | |
render_rgb = render_rgb.detach().cpu().numpy() | |
render_dpt = render_dpt.detach().cpu().numpy() | |
render_alpha = render_alpha.detach().cpu().numpy() | |
# assign back | |
frame.rgb = render_rgb | |
frame.dpt = render_dpt | |
frame.inpaint = render_msk | |
return frame | |
def _add_trainable_frame(self,frame:Frame,require_grad=True): | |
# for the init frame, we keep all pixels for finetuning | |
self.frames.append(frame) | |
gf = Gaussian_Frame(frame, self.device) | |
gf._require_grad(require_grad) | |
self.gaussian_frames.append(gf) | |