import argparse
import math
import os
from functools import partial

import imageio
import kiui
import mcubes
import numpy as np
import torch
import trimesh
from kiui.op import recenter
from rembg import new_session, remove
from torchvision.utils import save_image

from .models.generator import LRMGenerator
from .cam_utils import build_camera_principle, build_camera_standard, center_looking_at_camera_pose


class LRMInferrer:
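    """Single-image LRM inference wrapper.

    Builds an LRMGenerator, loads checkpoint weights, and exports a
    turntable video and/or a textured mesh for one input image.
    """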

    def __init__(self, model_name: str, resume: str):
        print("Initializing LRMInferrer")
        self.model_name = model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _model_kwargs = {
            'camera_embed_dim': 1024,
            'rendering_samples_per_ray': 128,
            'transformer_dim': 1024,
            'transformer_layers': 16,
            'transformer_heads': 16,
            'triplane_low_res': 32,
            'triplane_high_res': 64,
            'triplane_dim': 80,
            'encoder_freeze': False,
        }

        self.model = self._build_model(_model_kwargs).eval().to(self.device)

        # Load weights on CPU, copy them into the model, then free the
        # checkpoint so it does not linger alongside the live model.
        checkpoint = torch.load(resume, map_location='cpu')
        state_dict = checkpoint['model_state_dict']
        self.model.load_state_dict(state_dict)
        del checkpoint, state_dict
        torch.cuda.empty_cache()

    def __enter__(self):
        print("Entering context")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting context")
        if exc_type:
            print(f"Exception type: {exc_type}")
            print(f"Exception value: {exc_val}")
            print(f"Traceback: {exc_tb}")

    def _build_model(self, model_kwargs):
        print("Building model")
        # Only instantiates the architecture; checkpoint weights are loaded
        # in __init__.
        model = LRMGenerator(**model_kwargs).to(self.device)
        return model

    @staticmethod
    def get_surrounding_views(M, radius, elevation):
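        """Return (M, 3, 4) extrinsics for M cameras evenly spaced on a
        circle of the given radius at the given elevation (in degrees),
        all looking at the origin. The start angle is jittered by up to
        one degree.
        """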
        camera_positions = []
        rand_theta = np.random.uniform(0, np.pi / 180)
        elevation = math.radians(elevation)
        for i in range(M):
            theta = 2 * math.pi * i / M + rand_theta
            x = radius * math.cos(theta) * math.cos(elevation)
            y = radius * math.sin(theta) * math.cos(elevation)
            z = radius * math.sin(elevation)
            camera_positions.append([x, y, z])
        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
        extrinsics = center_looking_at_camera_pose(camera_positions)
        return extrinsics

    @staticmethod
    def _default_intrinsics():
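        """Default pinhole intrinsics packed as [[fx, fy], [cx, cy], [w, h]]."""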
        fx = fy = 384
        cx = cy = 256
        w = h = 512
        intrinsics = torch.tensor([
            [fx, fy],
            [cx, cy],
            [w, h],
        ], dtype=torch.float32)
        return intrinsics

    def _default_source_camera(self, batch_size: int = 1):
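        """Canonical conditioning camera at unit distance on the +x axis,
        looking at the origin, encoded by build_camera_principle as a
        flattened extrinsics-plus-intrinsics vector."""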
        canonical_camera_extrinsics = torch.tensor([[
            [0, 0, 1, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ]], dtype=torch.float32)
        canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
        source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
        return source_camera.repeat(batch_size, 1)

    def _default_render_cameras(self, batch_size: int = 1):
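        """Turntable render cameras: 160 views at radius 1.5, zero elevation."""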
        render_camera_extrinsics = self.get_surrounding_views(160, 1.5, 0)
        render_camera_intrinsics = self._default_intrinsics().unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
        render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)

    @staticmethod
    def images_to_video(images, output_path, fps, verbose=False):
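        """Encode a (N, C, H, W) float tensor in [0, 1] as a video file."""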
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        frames = []
        for i in range(images.shape[0]):
            frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
                f"Frame shape mismatch: {frame.shape} vs {images.shape}"
            assert frame.min() >= 0 and frame.max() <= 255, \
                f"Frame value out of range: {frame.min()} ~ {frame.max()}"
            frames.append(frame)
        imageio.mimwrite(output_path, np.stack(frames), fps=fps)
        if verbose:
            print(f"Saved video to {output_path}")

    def infer_single(self, image: torch.Tensor, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
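        """Run reconstruction for a single preprocessed image.

        image: (1, 3, H, W) float tensor in [0, 1] on self.device.
        Returns a dict that may contain 'frames' (per-view render outputs)
        and/or 'mesh' (a trimesh.Trimesh), depending on the export flags.
        """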
        print("infer_single called")
        mesh_thres = 1.0
        chunk_size = 2  # views rendered per synthesizer call, to bound GPU memory
        batch_size = 1

        source_camera = self._default_source_camera(batch_size).to(self.device)
        render_cameras = self._default_render_cameras(batch_size).to(self.device)

        with torch.no_grad():
            planes = self.model.forward(image, source_camera)
            results = {}

            if export_video:
                print("Starting export_video")
                frames = []
                for i in range(0, render_cameras.shape[1], chunk_size):
                    print(f"Processing chunk {i} to {i + chunk_size}")
                    frames.append(
                        self.model.synthesizer(
                            planes,
                            render_cameras[:, i:i+chunk_size],
                            render_size,
                            render_size,
                            0,
                            0,
                        )
                    )
                # Merge the per-chunk dicts, concatenating along the view dimension.
                frames = {
                    k: torch.cat([r[k] for r in frames], dim=1)
                    for k in frames[0].keys()
                }
                results.update({
                    'frames': frames,
                })
                print("Finished export_video")

            if export_mesh:
                print("Starting export_mesh")
                # Sample the density field on a dense grid and extract an isosurface.
                grid_out = self.model.synthesizer.forward_grid(
                    planes=planes,
                    grid_size=mesh_size,
                )
                vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
                # Map vertices from grid indices back to the [-1, 1]^3 cube.
                vtx = vtx / (mesh_size - 1) * 2 - 1
                vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
                # Query the triplanes for per-vertex colors.
                vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
                vtx_colors = (vtx_colors * 255).astype(np.uint8)
                mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
                results.update({
                    'mesh': mesh,
                })
                print("Finished export_mesh")

        return results

    def infer(self, source_image: str, dump_path: str, source_size: int, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
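        """Preprocess an image file and export the requested artifacts.

        Removes the background, recenters the object, composites onto white,
        resizes to source_size, then calls infer_single and writes the
        results under dump_path.
        """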
        print("infer called")
        session = new_session("isnet-general-use")
        rembg_remove = partial(remove, session=session)
        image_name = os.path.basename(source_image)
        uid = os.path.splitext(image_name)[0]

        image = kiui.read_image(source_image, mode='uint8')
        image = rembg_remove(image)   # RGBA with the background carved out
        mask = image[..., -1] > 0     # foreground mask from the alpha channel
        image = recenter(image, mask, border_ratio=0.2)
        os.makedirs(dump_path, exist_ok=True)

        image = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:
            # Composite RGBA onto a white background.
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
        image = torch.clamp(image, 0, 1)
        save_image(image, os.path.join(dump_path, f'{uid}.png'))

        results = self.infer_single(
            image.to(self.device),
            render_size=render_size,
            mesh_size=mesh_size,
            export_video=export_video,
            export_mesh=export_mesh,
        )

        if 'frames' in results:
            renderings = results['frames']
            for k, v in renderings.items():
                if k == 'images_rgb':
                    self.images_to_video(
                        v[0],
                        os.path.join(dump_path, f'{uid}.mp4'),
                        fps=40,
                    )
            print(f"Exported video to {dump_path}")

        if 'mesh' in results:
            mesh = results['mesh']
            mesh.export(os.path.join(dump_path, f'{uid}.obj'), 'obj')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='lrm-base-obj-v1')
    parser.add_argument('--source_path', type=str, default='./assets/cat.png')
    parser.add_argument('--dump_path', type=str, default='./results/single_image')
    parser.add_argument('--source_size', type=int, default=512)
    parser.add_argument('--render_size', type=int, default=384)
    parser.add_argument('--mesh_size', type=int, default=512)
    parser.add_argument('--export_video', action='store_true')
    parser.add_argument('--export_mesh', action='store_true')
    parser.add_argument('--resume', type=str, required=True, help='Path to the model checkpoint to load')
    args = parser.parse_args()

    with LRMInferrer(model_name=args.model_name, resume=args.resume) as inferrer:
        # Note: autocast with dtype=torch.float32 effectively disables mixed
        # precision (PyTorch warns and runs in full precision).
        with torch.autocast(device_type="cuda", cache_enabled=False, dtype=torch.float32):
            print("Start inference for image:", args.source_path)
            inferrer.infer(
                source_image=args.source_path,
                dump_path=args.dump_path,
                source_size=args.source_size,
                render_size=args.render_size,
                mesh_size=args.mesh_size,
                export_video=args.export_video,
                export_mesh=args.export_mesh,
            )
            print("Finished inference for image:", args.source_path)
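# Example invocation (the module path is illustrative; this file uses relative
# imports, so it must be run with -m from the package root):
#   python -m <package>.inferrer --resume ./checkpoints/model.pth \
#       --source_path ./assets/cat.png --export_video --export_mesh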