# dreamgaussian4d/lgm/infer_demo.py
import os
import tyro
import glob
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
from core.options import AllConfigs, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
import cv2
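
# ImageNet channel statistics, used to normalize the reconstructor's inputs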
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# opt = tyro.cli(AllConfigs)

# # model
# model = LGM(opt)

# # resume pretrained checkpoint
# if opt.resume is not None:
#     if opt.resume.endswith('safetensors'):
#         ckpt = load_file(opt.resume, device='cpu')
#     else:
#         ckpt = torch.load(opt.resume, map_location='cpu')
#     model.load_state_dict(ckpt, strict=False)
#     print(f'[INFO] Loaded checkpoint from {opt.resume}')
# else:
#     print(f'[WARN] model randomly initialized, are you sure?')

# # device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.half().to(device)
# model.eval()
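
# NOTE: in this demo, the reconstructor and the diffusion pipeline are built by
# the caller and passed into process(); a hypothetical driver along these lines
# is sketched at the bottom of this file.
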
# process function
def process(opt: Options, path, pipe, model, rays_embeddings, seed):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
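
    # perspective projection for the gaussian rasterizer; the matrix is stored
    # transposed (note proj_matrix[2, 3] = 1) so that row-vector poses can be
    # right-multiplied: cam_view @ proj_matrix below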
    tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
    proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
    proj_matrix[0, 0] = 1 / tan_half_fov
    proj_matrix[1, 1] = 1 / tan_half_fov
    proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
    proj_matrix[2, 3] = 1

    name = os.path.splitext(os.path.basename(path))[0]
    print(f'[INFO] Processing {path} --> {name}')
    os.makedirs('vis_data', exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    image = kiui.read_image(path, mode='uint8')
    image = image.astype(np.float32) / 255.0

    # rgba to rgb, composited onto a white background
    if image.shape[-1] == 4:
        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
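
    # multi-view generation: MVDream synthesizes four views of the object
    # (azimuths 90 degrees apart at elevation 0); the fixed seed makes the
    # diffusion sampling reproducible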
    generator = torch.manual_seed(seed)
    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0, generator=generator)
    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)  # reorder views; [4, 256, 256, 3], float32

    # pack the four views into the reconstructor's input format
    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
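    # append the per-pixel ray embeddings (6 Plücker channels) to the 3 RGB
    # channels, giving 9 channels per view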
    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]

    with torch.inference_mode():
        ############## align azimuth #####################
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image)

        best_azi = 0
        best_diff = 1e8
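
        # brute-force alignment: render the gaussians at every integer azimuth
        # and keep the angle whose rendering best matches the input image (MSE)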
        for azi in np.arange(-180, 180, 1):
            cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
            cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)
            rendered_image = result['image']
            rendered_image = rendered_image.squeeze(1).permute(0, 2, 3, 1).squeeze(0).contiguous().float().cpu().numpy()
            # cv2.resize expects dsize as (width, height)
            rendered_image = cv2.resize(rendered_image, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_AREA)
            diff = np.mean((rendered_image - image) ** 2)
            if diff < best_diff:
                best_diff = diff
                best_azi = azi
print("Best aligned azimuth: ", best_azi)
        mv_image = []
        for azi in [0, 90, 180, 270]:
            cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
            cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)
            rendered_image = result['image']
            rendered_image = rendered_image.squeeze(1)
            rendered_image = F.interpolate(rendered_image, (256, 256))
            rendered_image = rendered_image.permute(0, 2, 3, 1).contiguous().float().cpu().numpy()
            mv_image.append(rendered_image)
        mv_image = np.concatenate(mv_image, axis=0)

        # re-encode the aligned views, exactly as before
        input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device)  # [4, 3, 256, 256]
        input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
        input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0)  # [1, 4, 9, H, W]
        ##################################################

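        # second pass: reconstruct gaussians from the aligned views; the
        # downsampled set is written to disk while the original-resolution set
        # drives the turntable video below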
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians, gaussians_orig_res = model.forward_gaussians_downsample(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))

        # render 360 video
        images = []
        elevation = 0
        azimuth = np.arange(0, 360, 2, dtype=np.int32)
        for azi in tqdm.tqdm(azimuth):
            cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
            cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

            frame = model.gs.render(gaussians_orig_res, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
            images.append((frame.squeeze(1).permute(0, 2, 3, 1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

        images = np.concatenate(images, axis=0)
        imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
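

# A minimal driver sketch (an assumption, not part of the upstream module): it
# rebuilds the model along the lines of the commented-out block above and runs
# process() on opt.test_path. The pipeline repo id and the seed value are
# illustrative placeholders.
if __name__ == '__main__':
    opt = tyro.cli(AllConfigs)

    model = LGM(opt)
    if opt.resume is not None:
        if opt.resume.endswith('safetensors'):
            ckpt = load_file(opt.resume, device='cpu')
        else:
            ckpt = torch.load(opt.resume, map_location='cpu')
        model.load_state_dict(ckpt, strict=False)
        print(f'[INFO] Loaded checkpoint from {opt.resume}')
    else:
        print(f'[WARN] model randomly initialized, are you sure?')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.half().to(device)
    model.eval()

    # per-pixel ray (Plücker) embeddings for the four canonical input views
    rays_embeddings = model.prepare_default_rays(device)

    # image-conditioned multi-view diffusion pipeline; this repo id follows the
    # upstream LGM demo and may need to be adapted
    pipe = MVDreamPipeline.from_pretrained(
        'ashawkey/imagedream-ipmv-diffusers',
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    pipe = pipe.to(device)

    # opt.test_path may point at a single image or a directory of images
    if os.path.isdir(opt.test_path):
        file_paths = glob.glob(os.path.join(opt.test_path, '*'))
    else:
        file_paths = [opt.test_path]
    for path in file_paths:
        process(opt, path, pipe, model, rays_embeddings, seed=42)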