import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nvdiffrast.torch as dr
import trimesh
import tqdm
import tyro
from pathlib import Path
from einops import rearrange
from mediapy import read_video, write_video
from kiui.op import inverse_sigmoid
from kiui.cam import get_perspective
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity

from mesh import Mesh
from datasets.v3d import get_uniform_poses


class Refiner(nn.Module):
    """Refines per-vertex mesh colors against a multi-view video via differentiable rasterization."""

    def __init__(self, mesh_filename, video, num_opt=4, lpips: float = 0.0) -> None:
        super().__init__()
        self.output_size = 512
        znear = 0.1
        zfar = 10

        self.mesh = Mesh.load_obj(mesh_filename)
        self.device = torch.device("cuda")
        self.glctx = dr.RasterizeGLContext()

        self.lpips_meter = LearnedPerceptualImagePatchSimilarity(
            net_type="vgg", normalize=True
        ).to(self.device)
        self.lpips = lpips

        fov = 60
        self.name = Path(video).stem

        # Load the guidance video and normalize to [0, 1] once.
        frames = read_video(video).astype(np.float32) / 255.0
        num_frames, h, w, c = frames.shape
        frames = np.moveaxis(frames, -1, 1)  # (N, H, W, C) -> (N, C, H, W)
        self.image_gt = torch.from_numpy(frames).to(self.device)

        # One camera pose per frame on a uniform orbit (radius 2.0, elevation 0.0).
        self.poses = get_uniform_poses(num_frames, 2.0, 0.0, opengl=True)
        self.n_frames = len(self.poses)

        # Evenly spaced subset of frames used for optimization.
        self.opt_frames = np.linspace(0, self.n_frames, num_opt + 1)[:num_opt].astype(
            int
        )
        print(self.opt_frames)

        # Projection matrix kept from the Gaussian-splatting renderer (unused by the nvdiffrast path).
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(fov))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (zfar + znear) / (zfar - znear)
        self.proj_matrix[3, 2] = -(zfar * znear) / (zfar - znear)
        self.proj_matrix[2, 3] = 1

        self.proj = torch.from_numpy(get_perspective(fov)).float().to(self.device)
        self.v = self.mesh.v.contiguous().float().to(self.device)
        self.f = self.mesh.f.contiguous().int().to(self.device)
        self.vc = self.mesh.vc.contiguous().float().to(self.device)

    def render_normal(self, pose):
        h = w = self.output_size
        v = self.v
        f = self.f

        if not hasattr(self.mesh, "vn") or self.mesh.vn is None:
            self.mesh.auto_normal()
        # Rotate vertex normals into camera space and rasterize them as colors.
        vc = self.mesh.vn.to(self.device)
        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        vc = torch.einsum("ij, kj -> ki", pose[:3, :3].T, vc).contiguous()

        # Transform vertices to clip space and rasterize.
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T
        rast, _ = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W]; antialiasing is what makes the silhouette differentiable
        color, _ = dr.interpolate(vc.unsqueeze(0), rast, f)
        color = dr.antialias(color, rast, v_clip, f)
        image = color.view(1, h, w, 3)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = (image + 1) / 2.0  # map normals from [-1, 1] to [0, 1]
        image = alpha * image + (1 - alpha)  # composite onto white
        return image, alpha

    def render_mesh(self, pose, use_sigmoid=True):
        h = w = self.output_size
        v = self.v
        f = self.f
        # During optimization self.vc stores logits, so squash through a sigmoid.
        vc = torch.sigmoid(self.vc) if use_sigmoid else self.vc

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # Transform vertices to clip space and rasterize.
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T
        rast, _ = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W]; antialiasing is what makes the silhouette differentiable

        color, _ = dr.interpolate(vc.unsqueeze(0), rast, f)
        color = dr.antialias(color, rast, v_clip, f)
        image = color.view(1, h, w, 3)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)  # composite onto white
        return image, alpha

    def refine_texture(self, texture_resolution: int = 512, iters: int = 5000):
        # NOTE: texture_resolution is unused; colors are optimized per vertex, not in a UV texture.
        # Optimize in logit space so the sigmoid keeps colors inside [0, 1];
        # clamp first to avoid infinities at exactly 0 or 1.
        vc_original = self.vc.clone().clamp(1e-4, 1 - 1e-4)
        self.vc = nn.Parameter(inverse_sigmoid(vc_original).to(self.device))

        optimizer = torch.optim.Adam([{"params": self.vc, "lr": 1e-3}])

        pbar = tqdm.trange(iters)
        for _ in pbar:
            # Sample one of the designated optimization frames per step.
            index = np.random.choice(self.opt_frames)
            pose = self.poses[index]
            image_gt = self.image_gt[index]
            image_pred, _ = self.render_mesh(pose)

            loss = F.mse_loss(image_pred, image_gt)
            if self.lpips > 0.0:
                loss += (
                    self.lpips_meter(
                        image_gt.clamp(0, 1)[None], image_pred.clamp(0, 1)[None]
                    )
                    * self.lpips
                )

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(f"MSE = {loss.item():.6f}")

    @torch.no_grad()
    def render_spiral(self):
        images = []
        for pose in self.poses:
            # export() bakes the sigmoid back into self.vc, so skip it here.
            image, _ = self.render_mesh(pose, use_sigmoid=False)
            images.append(image)
        images = torch.stack(images).cpu().numpy()
        images = rearrange(images, "b c h w -> b h w c")
        Path("renders").mkdir(parents=True, exist_ok=True)
        write_video(f"renders/{self.name}.mp4", images, fps=3)

    @torch.no_grad()
    def render_normal_spiral(self):
        images = []
        for pose in self.poses:
            image, _ = self.render_normal(pose)
            images.append(image)
        images = torch.stack(images).cpu().numpy()
        images = rearrange(images, "b c h w -> b h w c")
        Path("renders").mkdir(parents=True, exist_ok=True)
        write_video(f"renders/{self.name}_normal.mp4", images, fps=3)

    def export(self, filename):
        mesh = trimesh.Trimesh(
            vertices=self.mesh.v.cpu().numpy(),
            faces=self.mesh.f.cpu().numpy(),
            vertex_colors=torch.sigmoid(self.vc.detach()).cpu().numpy(),
        )
        # Bake the sigmoid into self.vc so subsequent renders can use the colors directly.
        self.vc.data = torch.sigmoid(self.vc.detach())
        trimesh.repair.fix_inversion(mesh)
        mesh.export(filename)


def do_refine(
    mesh: str,
    scene: str,
    num_opt: int = 4,
    iters: int = 2000,
    skip_refine: bool = False,
    render_normal: bool = True,
    lpips: float = 1.0,
):
    refiner = Refiner(mesh, scene, num_opt=num_opt, lpips=lpips)
    if not skip_refine:
        refiner.refine_texture(512, iters)
    save_path = Path("refined") / f"{Path(scene).stem}.obj"
    save_path.parent.mkdir(exist_ok=True, parents=True)
    refiner.export(str(save_path))
    refiner.render_spiral()
    if render_normal:
        refiner.render_normal_spiral()


if __name__ == "__main__":
    tyro.cli(do_refine)
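
# Usage sketch (the script name and the video path are hypothetical; the mesh path
# is taken from a comment in the original source, and tyro maps do_refine's
# parameters to kebab-case CLI flags):
#   python refine.py --mesh tmp/corgi_size_1.obj --scene videos/corgi.mp4 --num-opt 4 --iters 2000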