# Michelangelo / inference.py
# Commit 9c3a994 ("all files first commit") by Maikou, 6.16 kB.
# -*- coding: utf-8 -*-
import os
import time
from collections import OrderedDict
from typing import Optional, List
import argparse
from functools import partial
from einops import repeat, rearrange
import numpy as np
from PIL import Image
import trimesh
import cv2
import torch
import pytorch_lightning as pl
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.models.tsal.inference_utils import extract_geometry
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util
def load_model(args):
    """Build the model described by ``args.config_path`` and load ``args.ckpt_path``.

    If the parsed config exposes a ``model`` attribute, that sub-config is the
    one instantiated. The model is moved to CUDA and put in evaluation mode
    before being returned.
    """
    cfg = get_config_from_file(args.config_path)
    # Some config files nest the model definition under a top-level `model` entry.
    model_cfg = cfg.model if hasattr(cfg, "model") else cfg
    net = instantiate_from_config(model_cfg, ckpt_path=args.ckpt_path)
    return net.cuda().eval()
def load_surface(fp, num_points=4096, device="cuda", rng: Optional[np.random.Generator] = None):
    """Load an oriented point cloud from an ``.npz`` file and subsample it.

    The archive must contain ``points`` and ``normals`` arrays with the same
    number of rows. ``num_points`` rows are drawn without replacement, and each
    sampled point's position and normal are concatenated along the last axis.

    Args:
        fp: Path to the ``.npz`` file.
        num_points: Number of points to sample (without replacement).
        device: Device the returned tensor is moved to.
        rng: Optional NumPy ``Generator`` used for sampling. When ``None`` the
            global NumPy RNG is used, so the selection follows
            ``np.random.seed`` / ``pl.seed_everything``.

    Returns:
        A float32 tensor of shape ``(1, num_points, D)``, where ``D`` is the
        width of ``points`` plus the width of ``normals``.

    Raises:
        ValueError: if the file holds fewer than ``num_points`` points.
    """
    # Read from ``fp`` (not a global) so the function also works when this
    # module is imported rather than run as a script.
    with np.load(fp) as input_pc:
        surface = input_pc["points"]
        normal = input_pc["normals"]

    n_total = surface.shape[0]
    if n_total < num_points:
        raise ValueError(
            f"Point cloud {fp!r} has {n_total} points; {num_points} are required."
        )

    # The ``np.random`` module and ``np.random.Generator`` share the
    # ``choice(a, size, replace)`` signature; the module form uses the
    # globally seeded state, which keeps --seed meaningful.
    sampler = np.random if rng is None else rng
    ind = sampler.choice(n_total, num_points, replace=False)

    surface = torch.as_tensor(surface[ind], dtype=torch.float32)
    normal = torch.as_tensor(normal[ind], dtype=torch.float32)
    return torch.cat([surface, normal], dim=-1).unsqueeze(0).to(device)
def prepare_image(args, number_samples=2):
    """Load ``args.image_path`` and turn it into a batch of normalized tensors.

    The image is read with OpenCV, converted from BGR to RGB, scaled from
    ``[0, 255]`` to ``[-1, 1]``, moved to channel-first layout, and repeated
    ``number_samples`` times along a new leading batch axis.

    Args:
        args: Namespace providing ``image_path``.
        number_samples: Size of the batch dimension.

    Returns:
        A float tensor of shape ``(number_samples, 3, H, W)`` on the CPU.

    Raises:
        ValueError: if ``args.image_path`` is not set.
        FileNotFoundError: if OpenCV cannot read the image.
    """
    image_path = args.image_path
    if image_path is None:
        raise ValueError("image2mesh requires --image_path")

    image = cv2.imread(str(image_path))
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
    return image_pt
def save_output(args, mesh_outputs):
    """Export each generated mesh as ``<i>_out_mesh.obj`` under ``args.output_dir``.

    The column order of every face row is reversed before export. The input
    mesh objects are left unmodified.

    Args:
        args: Namespace providing ``output_dir`` (created if missing).
        mesh_outputs: Iterable of objects exposing ``mesh_v`` and ``mesh_f``
            arrays (presumably ``Latent2MeshOutput`` instances; TODO confirm).

    Returns:
        0, matching the other task functions.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    for i, mesh in enumerate(mesh_outputs):
        # Flip into a local rather than assigning back to ``mesh.mesh_f``:
        # writing back would mutate the caller's objects, and saving the same
        # outputs twice would then alternate the face winding.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)
        name = str(i) + "_out_mesh.obj"
        mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
    print('-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {args.output_dir}')
    print('-----------------------------------------------------------------------------')
    return 0
def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
    """Encode the point cloud at ``args.pointcloud_path`` and decode it back to a mesh.

    The surface is encoded via ``model.model.encode_shape_embed`` followed by
    ``shape_model.encode_kl_embed``, then decoded with ``shape_model.decode``.
    The resulting latents drive ``shape_model.query_geometry`` inside
    ``extract_geometry``. The first extracted mesh is written to
    ``<args.output_dir>/reconstruction.obj``.

    Args:
        args: Namespace providing ``pointcloud_path`` and ``output_dir``.
        model: Loaded model exposing ``model.model.encode_shape_embed`` and
            ``model.model.shape_model``.
        bounds: Forwarded to ``extract_geometry`` as ``bounds``.
        octree_depth: Forwarded to ``extract_geometry``.
        num_chunks: Forwarded to ``extract_geometry``.

    Returns:
        0 on completion.
    """
    surface = load_surface(args.pointcloud_path)

    # Pure inference: disable autograd so no graph is retained for the encoder,
    # the decoder, or the geometry queries issued by extract_geometry.
    with torch.no_grad():
        # encoding
        _, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
        shape_zq, _ = model.model.shape_model.encode_kl_embed(shape_latents)

        # decoding
        latents = model.model.shape_model.decode(shape_zq)
        geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)

        # reconstruction
        mesh_v_f, _ = extract_geometry(
            geometric_func=geometric_func,
            device=surface.device,
            batch_size=surface.shape[0],
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
        )

    # load_surface yields a batch of one, so only the first result is exported.
    recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])

    # save
    out_path = os.path.join(args.output_dir, 'reconstruction.obj')
    os.makedirs(args.output_dir, exist_ok=True)
    recon_mesh.export(out_path)
    print('-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {out_path}')
    print('-----------------------------------------------------------------------------')
    return 0
def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7, num_samples=2):
    """Generate meshes conditioned on the image at ``args.image_path`` and save them.

    Args:
        args: Namespace providing ``image_path`` and ``output_dir``.
        model: Model exposing ``sample(...)``.
        guidance_scale: Forwarded to ``model.sample``.
        box_v: Sets ``bounds`` to ``[-box_v, -box_v, -box_v, box_v, box_v, box_v]``.
        octree_depth: Forwarded to ``model.sample``.
        num_samples: Number of copies of the image placed in the batch (kept
            last in the signature so existing positional calls still work;
            mirrors ``text2mesh``'s ``num_samples``).

    Returns:
        0 on completion.
    """
    sample_inputs = {
        "image": prepare_image(args, number_samples=num_samples)
    }
    # sample_times=1; keep the first element of what model.sample returns.
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    save_output(args, mesh_outputs)
    return 0
def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
    """Generate meshes conditioned on the prompt in ``args.text`` and save them.

    Args:
        args: Namespace providing ``text`` and ``output_dir``.
        model: Model exposing ``sample(...)``.
        num_samples: Number of copies of the prompt placed in the batch.
        guidance_scale: Forwarded to ``model.sample``.
        box_v: Sets ``bounds`` to ``[-box_v, -box_v, -box_v, box_v, box_v, box_v]``.
        octree_depth: Forwarded to ``model.sample``.

    Returns:
        0 on completion.

    Raises:
        ValueError: if ``args.text`` is not set (``--text`` is optional at the
            argparse level).
    """
    if args.text is None:
        raise ValueError("text2mesh requires --text")

    sample_inputs = {
        "text": [args.text] * num_samples
    }
    # sample_times=1; keep the first element of what model.sample returns.
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    save_output(args, mesh_outputs)
    return 0
# Maps each --task value to the function that runs it. Every entry takes
# ``(args, model)`` positionally.
TASK_REGISTRY = {
    'reconstruction': reconstruction,
    'image2mesh': image2mesh,
    'text2mesh': text2mesh,
}
# Backward-compatible alias for the original (misspelled) name.
task_dick = TASK_REGISTRY
if __name__ == "__main__":
    # Command-line entry point. Supported tasks:
    #   1. reconstruction - reconstruct a point cloud
    #   2. image2mesh     - image-conditioned generation
    #   3. text2mesh      - text-conditioned generation
    parser = argparse.ArgumentParser()
    # Derive choices from the registry so the two cannot drift apart.
    parser.add_argument("--task", type=str, choices=list(task_dick), required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
    parser.add_argument("--image_path", type=str, help='Path to the input image')
    parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
    parser.add_argument("--output_dir", type=str, default='./output')
    parser.add_argument("-s", "--seed", type=int, default=0)
    args = parser.parse_args()

    # Validate task-specific inputs up front, so a missing flag is reported
    # immediately instead of after load_model has moved the model to CUDA.
    if args.task == 'image2mesh' and args.image_path is None:
        parser.error("--image_path is required when --task is image2mesh")
    if args.task == 'text2mesh' and args.text is None:
        parser.error("--text is required when --task is text2mesh")

    pl.seed_everything(args.seed)

    print('-----------------------------------------------------------------------------')
    print(f'>>> Running {args.task}')
    # Each task writes into its own subdirectory of --output_dir.
    args.output_dir = os.path.join(args.output_dir, args.task)
    print(f'>>> Output directory: {args.output_dir}')
    print('-----------------------------------------------------------------------------')

    task_dick[args.task](args, load_model(args))