# lgm / convert.py
# (scraped page header: uploaded by brarnovidra, commit e206fc8, 17.6 kB)
import os
import tyro
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.options import AllConfigs, Options
from core.gs import GaussianRenderer
import mcubes
import nerfacc
import nvdiffrast.torch as dr
import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
from kiui.cam import orbit_camera, get_perspective
from kiui.nn import MLP, trunc_exp
from kiui.gridencoder import GridEncoder
def get_rays(pose, h, w, fovy, opengl=True):
    """Build one ray per pixel for a pinhole camera.

    Args:
        pose: 4x4 torch matrix mapping camera space to world space; its
            top-left 3x3 rotates directions and ``pose[:3, 3]`` is the origin.
        h, w: image height and width in pixels.
        fovy: vertical field of view in degrees (the focal length is derived
            from ``h``).
        opengl: if True the camera-space y axis is flipped and the depth
            component is -1 (camera looks down -z); otherwise both are +1.

    Returns:
        ``(rays_o, rays_d)``, each of shape ``[h * w, 3]``; ``rays_d`` is
        normalized with ``safe_normalize``.
    """
    sign = -1.0 if opengl else 1.0
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # pixel coordinates, flattened row by row (index = y * w + x)
    px, py = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    px, py = px.reshape(-1), py.reshape(-1)

    # camera-space direction through each pixel center
    dir_x = (px - w * 0.5 + 0.5) / focal
    dir_y = (py - h * 0.5 + 0.5) / focal * sign
    dir_z = torch.full_like(dir_x, sign)
    dirs_cam = torch.stack((dir_x, dir_y, dir_z), dim=-1)  # [hw, 3]

    rotation = pose[:3, :3]
    origin = pose[:3, 3]
    rays_d = safe_normalize(dirs_cam @ rotation.T)  # [hw, 3]
    rays_o = origin.unsqueeze(0).expand(rays_d.shape[0], -1)  # [hw, 3]
    return rays_o, rays_d
# Triple renderer: Gaussian splats (used as supervision), a NeRF, and a mesh.
# Conversion pipeline: gaussian --> nerf --> mesh
class Converter(nn.Module):
    """Convert a Gaussian-splat ``.ply`` into a textured mesh by distillation.

    The Gaussians loaded from ``opt.test_path`` are only ever rendered as
    supervision. They are first distilled into a NeRF made of grid encoders
    and MLPs (``fit_nerf``). The NeRF density is then turned into a mesh whose
    vertex offsets and color field are refined (``fit_mesh``). Finally a UV
    texture is baked and refined (``fit_mesh_uv``) before ``export_mesh``
    writes the result.

    NOTE(review): the device is hard-coded to CUDA.
    """

    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt
        self.device = torch.device("cuda")

        # gs renderer: projection matrix consumed by render_gs
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

        self.gs_renderer = GaussianRenderer(opt)
        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # nvdiffrast context, used by render_mesh and the UV bake in fit_mesh_uv
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # nerf renderer
        self.step = 0
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)

        self.encoder_density = GridEncoder(num_levels=12)  # VMEncoder(output_dim=16, mode='sum')
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # mesh renderer (populated by fit_mesh / fit_mesh_uv)
        self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        self.v = self.f = None
        self.vt = self.ft = None
        self.deform = None
        self.albedo = None  # texture logits; sigmoid is applied when rendering/exporting

    @torch.no_grad()
    def render_gs(self, pose):
        """Render the loaded Gaussians from a numpy 4x4 camera pose.

        The pose's up and forward columns are negated before inversion to
        build the view matrix expected by the Gaussian rasterizer.

        Returns:
            ``(image, alpha)`` with shapes ``[C, H, W]`` and ``[H, W]``.
        """
        cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

        out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0))
        image = out['image'].squeeze(1).squeeze(0)  # [C, H, W]
        alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0)  # [H, W]

        return image, alpha

    def get_density(self, xs):
        """Evaluate the NeRF density at points ``xs`` of shape ``[..., 3]``.

        Returns a tensor of shape ``[..., 1]`` (``trunc_exp`` of the density MLP).
        """
        prefix = xs.shape[:-1]
        # reshape (not view) so non-contiguous point tensors are accepted too
        xs = xs.reshape(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density

    def render_nerf(self, pose):
        """Volume-render the NeRF from a numpy 4x4 camera pose.

        Renders at ``opt.output_size`` squared. In training mode this also
        updates the occupancy grid and advances ``self.step``.

        Returns:
            ``(color, alpha)``: ``[3, H, W]`` composited over white and
            ``[H, W]``, both clamped to [0, 1].
        """
        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid
        if self.training:
            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        # interval sampling needs no gradients; densities are re-evaluated below with grad
        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
        color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
        alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)

        color = color + 1 * (1.0 - alpha)  # white background

        color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha

    def fit_nerf(self, iters=512, resolution=128):
        """Distill the Gaussians into the NeRF.

        Each iteration renders both representations from a random orbit camera
        (elevation in [-45, 45), azimuth in [-180, 180), radius in [1.5, 3.0))
        and minimizes color MSE plus 0.1 x alpha MSE. Overwrites
        ``opt.output_size`` with ``resolution``.
        """
        self.opt.output_size = resolution

        optimizer = torch.optim.Adam([
            {'params': self.encoder_density.parameters(), 'lr': 1e-2},
            {'params': self.encoder.parameters(), 'lr': 1e-2},
            {'params': self.mlp_density.parameters(), 'lr': 1e-3},
            {'params': self.mlp.parameters(), 'lr': 1e-3},
        ])

        print("[INFO] fitting nerf...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            rad = np.random.uniform(1.5, 3.0)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_nerf(pose)

            # if i % 200 == 0:
            #     kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
            loss = loss_mse  # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()

            loss.backward()
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting nerf!")

    def render_mesh(self, pose):
        """Rasterize the current mesh (``v + deform``) from a numpy 4x4 camera pose.

        Before ``fit_mesh_uv`` has run, colors come from the NeRF color field
        queried at interpolated surface positions (detached). Afterwards they
        come from the UV texture ``self.albedo``.

        Returns:
            ``(image, alpha)``: ``[3, H, W]`` composited over white and
            ``[H, W]``.
        """
        h = w = self.opt.output_size

        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)  # [H, W] important to enable gradients!

        if self.albedo is None:
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1)))
                image[mask] = masked_albedo.float()
        else:
            texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all')
            image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db))  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)

        return image, alpha

    def _query_density_grid(self, grid_size, chunk=128):
        """Sample ``get_density`` on a ``grid_size``^3 lattice spanning [-1, 1]^3.

        Points are evaluated in sub-cubes of at most ``chunk`` samples per axis
        to bound memory. Returns a float32 numpy array indexed ``[x, y, z]``.
        """
        sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
        axis = torch.linspace(-1, 1, grid_size).split(chunk)

        # read-only query: no autograd graph needed
        with torch.no_grad():
            for xi, xs in enumerate(axis):
                for yi, ys in enumerate(axis):
                    for zi, zs in enumerate(axis):
                        xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
                        pts = torch.stack([xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)], dim=-1)  # [N, 3]
                        val = self.get_density(pts.to(self.device))  # [N, 1]
                        sigmas[xi * chunk: xi * chunk + len(xs),
                               yi * chunk: yi * chunk + len(ys),
                               zi * chunk: zi * chunk + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()

        return sigmas

    def _set_mesh(self, vertices, triangles):
        """Upload numpy vertices/faces to the device and reset ``deform`` to zeros."""
        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        # move first, then wrap: nn.Parameter(x).to(...) could return a non-leaf,
        # unregistered tensor if the move actually copied data
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

    def _remesh(self, vertices, triangles, decimate_target):
        """Run ``clean_mesh`` with remeshing, decimate above ``decimate_target`` faces, and install the result."""
        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
        self._set_mesh(vertices, triangles)

    def _mesh_optimizer(self, lr_factor):
        """Build the Adam optimizer for the color field (scaled lr) and vertex offsets."""
        return torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
            {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
            {'params': self.deform, 'lr': 1e-4},
        ])

    def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4, grid_size=256, density_thresh=10):
        """Extract a mesh from the NeRF and refine it against Gaussian renders.

        The NeRF density is sampled on a ``grid_size``^3 grid and
        marching-cubed at ``density_thresh``. The mesh is then cleaned,
        remeshed and decimated to at most ``decimate_target`` faces. Per-vertex
        offsets and the NeRF color field are optimized from random orbit
        cameras at ``opt.cam_radius``. Every 512 iterations the mesh is
        remeshed and the color learning rate halved. Overwrites
        ``opt.output_size`` with ``resolution``.

        Raises:
            RuntimeError: if marching cubes produces no triangles.
        """
        self.opt.output_size = resolution

        # init mesh from nerf
        sigmas = self._query_density_grid(grid_size)

        print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
        if triangles.shape[0] == 0:
            raise RuntimeError(
                f"marching cubes produced no triangles at density threshold {density_thresh} "
                f"(density range {sigmas.min()} ~ {sigmas.max()})"
            )
        vertices = vertices / (grid_size - 1.0) * 2 - 1  # grid indices -> [-1, 1]

        # clean
        self._remesh(vertices.astype(np.float32), triangles.astype(np.int32), decimate_target)

        # fit mesh from gs
        lr_factor = 1
        optimizer = self._mesh_optimizer(lr_factor)

        print("[INFO] fitting mesh...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-10, 10)
            hor = np.random.randint(-180, 180)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
            # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
            loss_normal = normal_consistency(self.v + self.deform, self.f)
            loss_offsets = (self.deform ** 2).sum(-1).mean()
            loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # remesh periodically; the new deform parameter needs a fresh optimizer
            if i > 0 and i % 512 == 0:
                self._remesh(
                    (self.v + self.deform).detach().cpu().numpy(),
                    self.f.detach().cpu().numpy(),
                    decimate_target,
                )
                lr_factor *= 0.5
                optimizer = self._mesh_optimizer(lr_factor)

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        # last clean (bake offsets into vertices, no remesh)
        vertices = (self.v + self.deform).detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
        self._set_mesh(vertices, triangles)

        print("[INFO] finished fitting mesh!")

    # uv mesh refine
    def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2):
        """UV-unwrap the mesh, bake a texture from the NeRF color field, then refine it.

        The texture (``texture_resolution``^2) is initialized by querying the
        color MLP at the surface point of each covered texel and padded with
        ``uv_padding`` by ``padding``. It is then optimized against
        near-frontal Gaussian renders (elevation in [-5, 5), azimuth in
        [-15, 15)). Overwrites ``opt.output_size`` with ``resolution``.
        """
        self.opt.output_size = resolution

        # unwrap uv
        print("[INFO] uv unwrapping...")
        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        mesh.auto_normal()
        mesh.auto_uv()

        self.vt = mesh.vt
        self.ft = mesh.ft

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w))  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f)  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print("[INFO] querying texture...")

            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM; no graph is needed for initialization
            chunk = 640000
            with torch.no_grad():
                albedo[mask] = torch.cat([
                    torch.sigmoid(self.mlp(self.encoder(xyzs[head:head + chunk]))).float()
                    for head in range(0, xyzs.shape[0], chunk)
                ], dim=0)

        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)
        albedo = uv_padding(albedo, mask, padding)

        # optimize texture (stored as logits); move before wrapping so it stays a leaf Parameter
        self.albedo = nn.Parameter(inverse_sigmoid(albedo).to(self.device))

        optimizer = torch.optim.Adam([
            {'params': self.albedo, 'lr': 1e-3},
        ])

        print("[INFO] fitting mesh texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # shrink to front view as we care more about it...
            ver = np.random.randint(-5, 5)
            hor = np.random.randint(-15, 15)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt)
            loss = loss_mse

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting mesh texture!")

    @torch.no_grad()
    def export_mesh(self, path):
        """Write the textured mesh to ``path`` via ``kiui.mesh.Mesh.write``.

        Raises:
            RuntimeError: if no texture exists yet (``fit_mesh_uv`` not run).
        """
        if self.albedo is None:
            raise RuntimeError("no texture to export: call fit_mesh_uv() before export_mesh()")
        mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device)
        mesh.auto_normal()
        mesh.write(path)
opt = tyro.cli(AllConfigs)

# load a saved ply and convert to mesh
# explicit check rather than `assert`, which is stripped under `python -O`
if not opt.test_path.endswith('.ply'):
    raise ValueError('--test_path must be a .ply file saved by infer.py')

converter = Converter(opt).cuda()

converter.fit_nerf()
converter.fit_mesh()
converter.fit_mesh_uv()

# swap only the trailing extension; str.replace would also rewrite any '.ply'
# appearing earlier in the path (e.g. in a directory name)
converter.export_mesh(opt.test_path[:-len('.ply')] + '.glb')