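"""Per-vertex color refinement of a reconstructed mesh against a multi-view video.

The script loads a vertex-colored mesh and an orbit video (frames are assumed to
follow the uniform camera poses from datasets.v3d.get_uniform_poses), renders the
mesh with nvdiffrast, and optimizes the vertex colors with an MSE loss plus an
optional LPIPS term. The refined mesh is exported to ./refined and turntable
videos of the color and normal renderings are written to ./renders.
"""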
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nvdiffrast.torch as dr
import trimesh
import tqdm
import tyro
from pathlib import Path
from einops import rearrange
from mediapy import read_video, write_video
from torchvision.utils import save_image
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity
from kiui.op import inverse_sigmoid
from kiui.cam import get_perspective

from mesh import Mesh
from datasets.v3d import get_uniform_poses
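# Note: dr.RasterizeGLContext() below requires an OpenGL-capable environment
# (e.g. EGL on a headless server); the `mesh` and `datasets.v3d` modules are
# assumed to be local to this repository.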
class Refiner(nn.Module):
    def __init__(self, mesh_filename, video, num_opt=4, lpips: float = 0.0) -> None:
        super().__init__()
        self.output_size = 512
        znear = 0.1
        zfar = 10
        self.device = torch.device("cuda")

        # load the vertex-colored mesh and create the rasterization context
        self.mesh = Mesh.load_obj(mesh_filename)
        self.glctx = dr.RasterizeGLContext()

        self.lpips_meter = LearnedPerceptualImagePatchSimilarity(
            net_type="vgg", normalize=True
        ).to(self.device)
        self.lpips = lpips

        fov = 60
        # read the target video: frames come back as [N, H, W, 3] uint8
        frames = read_video(video)
        self.name = Path(video).stem
        frames = frames.astype(np.float32) / 255.0  # normalize once to [0, 1]
        frames = np.moveaxis(frames, -1, 1)  # [N, 3, H, W]
        num_frames = frames.shape[0]
        self.poses = get_uniform_poses(num_frames, 2.0, 0.0, opengl=True)
        self.image_gt = torch.from_numpy(frames).to(self.device)
        self.n_frames = len(self.poses)
        # pick num_opt frame indices evenly spaced around the orbit for optimization
        self.opt_frames = np.linspace(0, self.n_frames, num_opt + 1)[:num_opt].astype(
            int
        )
        print(self.opt_frames)
        # gs renderer projection matrix (not used by the nvdiffrast path below,
        # which relies on self.proj instead)
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(fov))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (zfar + znear) / (zfar - znear)
        self.proj_matrix[3, 2] = -(zfar * znear) / (zfar - znear)
        self.proj_matrix[2, 3] = 1

        # OpenGL-style projection used for rasterization
        self.proj = torch.from_numpy(get_perspective(fov)).float().to(self.device)

        # vertex positions, faces, and per-vertex colors on the GPU
        self.v = self.mesh.v.contiguous().float().to(self.device)
        self.f = self.mesh.f.contiguous().int().to(self.device)
        self.vc = self.mesh.vc.contiguous().float().to(self.device)
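    # Both render methods below follow the same nvdiffrast chain:
    #   world -> camera:  v_cam  = [v | 1] @ inverse(pose)^T   (pose is camera-to-world)
    #   camera -> clip:   v_clip = v_cam @ proj^T
    # followed by rasterization, attribute interpolation, and antialiasing.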
    def render_normal(self, pose):
        """Render camera-space vertex normals (mapped to [0, 1]) under `pose`."""
        h = w = self.output_size
        v = self.v
        f = self.f
        if not hasattr(self.mesh, "vn") or self.mesh.vn is None:
            self.mesh.auto_normal()
        vn = self.mesh.vn.to(self.device)
        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        # rotate the normals into camera space
        vn = torch.einsum("ij, kj -> ki", pose[:3, :3].T, vn).contiguous()
        # transform vertices to clip space and rasterize
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T
        rast, _ = dr.rasterize(self.glctx, v_clip, f, (h, w))
        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W]; antialiasing keeps the silhouette differentiable
        # interpolate per-vertex normals and antialias the result
        color, _ = dr.interpolate(vn.unsqueeze(0), rast, f)
        color = dr.antialias(color, rast, v_clip, f)
        image = color.view(1, h, w, 3)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = (image + 1) / 2.0  # map normals from [-1, 1] to [0, 1]
        image = alpha * image + (1 - alpha)  # composite over a white background
        return image, alpha
    def render_mesh(self, pose, use_sigmoid=True):
        """Render the vertex-colored mesh under `pose`.

        While the colors are being optimized they are stored as logits, so
        `use_sigmoid=True` maps them back to [0, 1] before interpolation.
        """
        h = w = self.output_size
        v = self.v
        f = self.f
        vc = torch.sigmoid(self.vc) if use_sigmoid else self.vc
        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        # transform vertices to clip space and rasterize
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T
        rast, _ = dr.rasterize(self.glctx, v_clip, f, (h, w))
        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W]; antialiasing keeps the silhouette differentiable
        # interpolate per-vertex colors and antialias the result
        color, _ = dr.interpolate(vc.unsqueeze(0), rast, f)
        color = dr.antialias(color, rast, v_clip, f)
        image = color.view(1, h, w, 3)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)  # composite over a white background
        return image, alpha
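    # Note: render_spiral() below calls render_mesh(use_sigmoid=False); this is
    # only correct after export() has folded the sigmoid back into self.vc so
    # that the colors are stored in [0, 1] again.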
    def refine_texture(self, texture_resolution: int = 512, iters: int = 5000):
        # texture_resolution is unused here: colors are stored per vertex,
        # not in a UV texture
        # reparameterize the vertex colors as logits so the optimized values
        # stay in (0, 1) after the sigmoid in render_mesh
        vc_original = self.vc.clone()
        self.vc = nn.Parameter(inverse_sigmoid(vc_original).to(self.device))
        optimizer = torch.optim.Adam(
            [
                {"params": self.vc, "lr": 1e-3},
            ]
        )
        pbar = tqdm.trange(iters)
        for i in pbar:
            # sample one of the evenly spaced optimization frames
            index = np.random.choice(self.opt_frames)
            pose = self.poses[index]
            image_gt = self.image_gt[index]
            image_pred, _ = self.render_mesh(pose)
            # if i % 1000 == 0:
            #     save_image(image_pred, f"tmp/image_pred_{i}.png")
            #     save_image(image_gt, f"tmp/image_gt_{i}.png")
            loss = F.mse_loss(image_pred, image_gt)
            if self.lpips > 0.0:
                loss += (
                    self.lpips_meter(
                        image_gt.clamp(0, 1)[None], image_pred.clamp(0, 1)[None]
                    )
                    * self.lpips
                )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(f"MSE = {loss.item():.6f}")
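    # Note: only the per-vertex colors are optimized in refine_texture; vertex
    # positions, faces, and camera poses stay fixed, so the refinement cannot
    # correct geometric errors in the input mesh.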
    @torch.no_grad()  # rendering only; also keeps .numpy() off grad-tracked tensors
    def render_spiral(self):
        images = []
        for pose in self.poses:
            image, _ = self.render_mesh(pose, use_sigmoid=False)
            images.append(image)
        images = torch.stack(images)
        images = rearrange(images.cpu().numpy(), "b c h w -> b h w c")
        Path("renders").mkdir(parents=True, exist_ok=True)
        write_video(f"renders/{self.name}.mp4", images, fps=3)

    @torch.no_grad()
    def render_normal_spiral(self):
        images = []
        for pose in self.poses:
            image, _ = self.render_normal(pose)
            images.append(image)
        images = torch.stack(images)
        images = rearrange(images.cpu().numpy(), "b c h w -> b h w c")
        Path("renders").mkdir(parents=True, exist_ok=True)
        write_video(f"renders/{self.name}_normal.mp4", images, fps=3)
    def export(self, filename):
        mesh = trimesh.Trimesh(
            vertices=self.mesh.v.cpu().numpy(),
            faces=self.mesh.f.cpu().numpy(),
            vertex_colors=torch.sigmoid(self.vc.detach()).cpu().numpy(),
        )
        # fold the sigmoid back into self.vc so later renders can use the raw values
        self.vc.data = torch.sigmoid(self.vc.detach())
        trimesh.repair.fix_inversion(mesh)
        mesh.export(filename)
def do_refine(
    mesh: str,
    scene: str,
    num_opt: int = 4,
    iters: int = 2000,
    skip_refine: bool = False,
    render_normal: bool = True,
    lpips: float = 1.0,
):
    refiner = Refiner(
        mesh,
        scene,
        num_opt=num_opt,
        lpips=lpips,
    )
    if not skip_refine:
        refiner.refine_texture(512, iters)
    save_path = Path("refined") / f"{Path(scene).stem}.obj"
    save_path.parent.mkdir(exist_ok=True, parents=True)
    refiner.export(str(save_path))
    refiner.render_spiral()
    if render_normal:
        refiner.render_normal_spiral()


if __name__ == "__main__":
    tyro.cli(do_refine)
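# Example invocation (tyro exposes the do_refine arguments as CLI flags; the
# script name and paths below are placeholders):
#   python refine.py --mesh path/to/mesh.obj --scene path/to/orbit_video.mp4 \
#       --num-opt 4 --iters 2000 --lpips 1.0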