|
|
|
import os |
|
import time |
|
from collections import OrderedDict |
|
from PIL import Image |
|
import torch |
|
import trimesh |
|
from typing import Optional, List |
|
from einops import repeat, rearrange |
|
import numpy as np |
|
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput |
|
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config |
|
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer |
|
from michelangelo.utils.visualizers import html_util |
|
|
|
import gradio as gr |
|
|
|
|
|
# Directory (relative to the working directory) that receives generated
# HTML pages and mesh files; created up front so later writes find it.
gradio_cached_dir = "./gradio_cached_dir"
os.makedirs(gradio_cached_dir, exist_ok=True)

save_mesh = False

# Most recent message recorded through set_state().
state = ""

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Scalar from which the sampling bounds and the viewer offsets are derived.
box_v = 1.1

# Single viewer instance shared by every request.
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
|
|
|
# Selectable image-conditioned models: display name -> config/checkpoint paths.
image_model_config_dict = OrderedDict([
    (
        "ASLDM-256-obj",
        {
            "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
            "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
        },
    ),
])

# Selectable text-conditioned models: display name -> config/checkpoint paths.
text_model_config_dict = OrderedDict([
    (
        "ASLDM-256",
        {
            "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
            "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
        },
    ),
])
|
|
|
|
|
class InferenceModel:
    """Slot holding one loaded model together with the name it was loaded as."""

    # Loaded model object; stays None until something is stored here.
    model = None
    # Name associated with the stored model; empty while nothing is stored.
    name = ""


# One independent slot per generation mode.
text2mesh_model, image2mesh_model = InferenceModel(), InferenceModel()
|
|
|
|
|
def set_state(s):
    """Store ``s`` in the module-level ``state`` variable and print it."""
    global state
    state = s
    # Echo the value that was just recorded.
    print(state)
|
|
|
|
|
def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
                         image: Optional[np.ndarray] = None,
                         html_frame: bool = False):
    """Render the given meshes (and optionally an image) as an HTML snippet.

    Each non-``None`` mesh is copied, shifted and added to the shared
    module-level ``viewer``; the viewer's HTML is then optionally placed in a
    two-cell table next to the embedded image.

    Args:
        mesh_outputs: Mesh results to display; ``None`` entries are skipped.
        bbox_size: Size used to space the meshes; ``np.max`` of it is the
            offset, so a scalar or an array-like is accepted.
        image: Optional image embedded to the left of the meshes.
        html_frame: When true, wrap the result with ``html_util.to_html_frame``.

    Returns:
        The HTML produced for the meshes (and image, if given).
    """
    # Loop-invariant spacing between neighbouring meshes.
    offset = np.max(bbox_size)

    try:
        for i, mesh in enumerate(mesh_outputs):
            if mesh is None:
                continue

            # Work on a copy so the caller's vertex array is left untouched.
            mesh_v = mesh.mesh_v.copy()
            mesh_v[:, 0] += i * offset  # spread meshes along the first coordinate
            mesh_v[:, 2] += offset      # constant shift along the third coordinate
            viewer.add_mesh(mesh_v, mesh.mesh_f)

        mesh_tag = viewer.to_html(html_frame=False)

        if image is not None:
            image_tag = html_util.to_image_embed_tag(image)
            frame = f"""
            <table border = "1">
                <tr>
                    <td>{image_tag}</td>
                    <td>{mesh_tag}</td>
                </tr>
            </table>
            """
        else:
            frame = mesh_tag

        if html_frame:
            frame = html_util.to_html_frame(frame)
    finally:
        # The viewer is shared across requests: always clear it, even when
        # rendering fails, so stale meshes never leak into the next call.
        viewer.reset()

    return frame
|
|
|
|
|
def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
    """Return the model called ``model_name``, loading it only when necessary.

    ``inference_model`` acts as a one-entry cache: if it already holds a model
    under ``model_name`` that model is returned as-is. Otherwise the previous
    model is released, the requested one is built from its config and
    checkpoint, moved to the module-level ``device``, switched to eval mode
    and stored back into ``inference_model``.

    Args:
        model_name: Key into ``model_config_dict``.
        model_config_dict: Mapping of model name to a dict with ``"config"``
            and ``"ckpt_path"`` entries.
        inference_model: Cache slot that is read and updated in place.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_config_dict``.
    """
    # Cache hit. The extra None check protects against a slot whose name is
    # set but whose model is missing.
    if inference_model.name == model_name and inference_model.model is not None:
        return inference_model.model

    # Explicit validation instead of ``assert`` (asserts vanish under ``-O``).
    if model_name not in model_config_dict:
        raise ValueError(
            f"Unknown model {model_name!r}; available: {list(model_config_dict)}"
        )

    # Release the previous model *and* forget its name before loading. If the
    # load below fails, the slot must not keep claiming to hold the old model.
    inference_model.model = None
    inference_model.name = ""

    config_ckpt_path = model_config_dict[model_name]

    model_config = get_config_from_file(config_ckpt_path["config"])
    # Some config files nest the model definition under a top-level "model".
    if hasattr(model_config, "model"):
        model_config = model_config.model

    model = instantiate_from_config(model_config, ckpt_path=config_ckpt_path["ckpt_path"])
    model = model.to(device)
    model = model.eval()

    # Publish to the cache only after the model is fully ready.
    inference_model.model = model
    inference_model.name = model_name

    return model
|
|
|
|
|
def prepare_img(image: np.ndarray):
    """Turn an ``(h, w, c)`` image array into a ``(c, h, w)`` float tensor.

    Values are divided by 255, doubled and shifted down by one, so inputs in
    [0, 255] end up in [-1, 1] (assumes 8-bit pixel values; confirm with caller).
    """
    scaled = torch.tensor(image).float() / 255 * 2 - 1
    # Move the channel axis to the front.
    return rearrange(scaled, "h w c -> c h w")
|
|
|
def prepare_model_viewer(fp):
    """Return HTML embedding a ``<model-viewer>`` element for a cached file.

    ``fp`` is interpolated into the ``src`` attribute, relative to
    ``file/gradio_cached_dir/``.
    """
    return f"""
    <head>
      <script
        type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
      </script>
    </head>
    <body>
      <model-viewer
        style="height: 150px; width: 150px;"
        rotation-per-second="10deg"
        id="t1"
        src="file/gradio_cached_dir/{fp}"
        environment-image="neutral"
        camera-target="0m 0m 0m"
        orientation="0deg 90deg 170deg"
        shadow-intensity="1"
        ar:true
        auto-rotate
        camera-controls>
      </model-viewer>
    </body>
    """
|
|
|
def prepare_html_frame(content):
    """Wrap ``content`` in ``<html><body>…</body></html>`` and return the page."""
    return f"""
    <html>
      <body>
        {content}
      </body>
    </html>
    """
|
|
|
def prepare_html_body(content):
    """Wrap ``content`` in a ``<body>`` element and return the markup."""
    return f"""
    <body>
      {content}
    </body>
    """
|
|
|
def post_process_mesh_outputs(mesh_outputs):
    """Write the viewer page and one OBJ file per mesh into the cache directory.

    Args:
        mesh_outputs: Sequence of mesh results exposing ``mesh_v`` and
            ``mesh_f``; ``None`` entries are skipped, matching how the HTML
            preview treats them.

    Returns:
        A tuple ``(iframe_tag, filelist)``: the ``<iframe>`` markup pointing
        at the written HTML page, and the paths of every written file (the
        OBJ files first, the HTML page last).
    """
    html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
    html_frame = prepare_html_frame(html_content)

    # The timestamp keeps page names from different requests apart.
    filename = f"text-256-{time.time()}.html"
    html_filepath = os.path.join(gradio_cached_dir, filename)
    with open(html_filepath, "w", encoding="utf-8") as writer:
        writer.write(html_frame)

    # Known issue: an iframe pointing at an arbitrary path does not work in
    # Gradio (Chrome reports "No resource with given URL found").
    # See https://github.com/gradio-app/gradio/issues/884: for security
    # reasons the server only serves files located alongside gradio_app.py,
    # addressed as "file/TARGET_FILE_PATH".
    iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="600%" height="400" frameborder="0"></iframe>'

    filelist = []
    for i, mesh in enumerate(mesh_outputs):
        if mesh is None:
            continue

        # Reverse the vertex order of every face for export. A local view is
        # used so the caller's mesh object is not modified.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)

        filepath = os.path.join(gradio_cached_dir, f"{i}_out_mesh.obj")
        mesh_output.export(filepath, include_normals=True)
        filelist.append(filepath)

    filelist.append(html_filepath)
    return iframe_tag, filelist
|
|
|
def image2mesh(image: np.ndarray,
               model_name: str = "ASLDM-256-obj",
               num_samples: int = 4,
               guidance_scale: float = 7.5,
               octree_depth: int = 7):
    """Generate meshes conditioned on an image.

    Args:
        image: Input image as an ``(h, w, c)`` array.
        model_name: Key of ``image_model_config_dict`` selecting the model.
            The default is the key that dictionary actually contains.
        num_samples: Number of samples generated in one batch.
        guidance_scale: Forwarded to ``model.sample``.
        octree_depth: Forwarded to ``model.sample``.

    Returns:
        A tuple ``(iframe_tag, file_update)``: the preview ``<iframe>`` markup
        and a ``gr.update`` that fills and reveals the file component.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Pressing "Generate" without an image would otherwise fail deep
        # inside tensor conversion with an unhelpful message.
        raise gr.Error("Please upload or select an image before generating.")

    # Slider values may arrive as floats; these arguments must be integers.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    model = load_model(model_name, image_model_config_dict, image2mesh_model)

    image_pt = prepare_img(image)
    # Replicate the single image once per requested sample.
    image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)

    sample_inputs = {
        "image": image_pt
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)

    return iframe_tag, gr.update(value=filelist, visible=True)
|
|
|
|
|
def text2mesh(text: str,
              model_name: str = "ASLDM-256",
              num_samples: int = 4,
              guidance_scale: float = 7.5,
              octree_depth: int = 7):
    """Generate meshes conditioned on a text prompt.

    Args:
        text: Prompt; it is repeated once per requested sample.
        model_name: Key of ``text_model_config_dict`` selecting the model.
            The default is the key that dictionary actually contains.
        num_samples: Number of samples generated in one batch.
        guidance_scale: Forwarded to ``model.sample``.
        octree_depth: Forwarded to ``model.sample``.

    Returns:
        A tuple ``(iframe_tag, file_update)``: the preview ``<iframe>`` markup
        and a ``gr.update`` that fills and reveals the file component.
    """
    # Slider values may arrive as floats; these arguments must be integers.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    model = load_model(model_name, text_model_config_dict, text2mesh_model)

    sample_inputs = {
        "text": [text] * num_samples
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)

    return iframe_tag, gr.update(value=filelist, visible=True)
|
|
|
# Directory scanned for example images shown in the gallery.
example_dir = './gradio_cached_dir/example/img_example'

# File names that are placed ahead of all other examples. Every entry needs
# its own trailing comma: adjacent string literals are silently concatenated.
first_page_items = [
    'alita.jpg',
    'burger.jpg',
    'loopy.jpg',
    'building.jpg',
    'mario.jpg',
    'car.jpg',
    'airplane.jpg',
    'bag.jpg',
    'bench.jpg',
    'ship.jpg',
]

# A missing example directory simply means "no examples" rather than a crash
# while the module is being loaded.
if os.path.isdir(example_dir):
    raw_example_items = [
        os.path.join(example_dir, x)
        for x in os.listdir(example_dir)
        if x.endswith(('.jpg', '.png'))
    ]
else:
    raw_example_items = []

# First-page images first, everything else afterwards (directory order is
# kept within each group).
example_items = (
    [x for x in raw_example_items if os.path.basename(x) in first_page_items]
    + [x for x in raw_example_items if os.path.basename(x) not in first_page_items]
)

# A list of single-prompt rows (no trailing comma after the closing bracket,
# which would wrap the list in a tuple).
example_text = [
    ["A 3D model of a car; Audi A6."],
    ["A 3D model of police car; Highway Patrol Charger"],
]
|
|
|
def set_cache(data: gr.SelectData):
    """Translate a gallery selection into ``(image path, image file name)``.

    ``data.index`` selects the entry of ``example_items``; the path is rebuilt
    from ``example_dir`` and the entry's base name.
    """
    selected = example_items[data.index]
    img_name = os.path.basename(selected)
    img_path = os.path.join(example_dir, img_name)
    # Joining a single path component yields that component unchanged.
    return img_path, img_name
|
|
|
def disable_cache():
    """Return an empty string; wired to image upload to blank the hidden cache textbox."""
    return str()
|
|
|
# Build the Gradio UI: two input tabs (image / text) on the left, the 3D
# preview and downloadable files on the right.
with gr.Blocks() as app:
    gr.Markdown("# Michelangelo")
    gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
    gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
    gr.Markdown("### Hint:")
    gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
    gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
    gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
    gr.Markdown("4. Welcome to share your amazing results with us, and thanks for your interest in our work!")

    with gr.Row():
        with gr.Column():

            with gr.Tab("Image to 3D"):
                img = gr.Image(label="Image")
                gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
                btn_generate_img2obj = gr.Button(value="Generate")

                # Each tab keeps its own slider variables. Sharing names
                # between the tabs would leave only the last-created sliders
                # reachable when the buttons are wired up below.
                with gr.Accordion("Advanced settings", open=False):
                    image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj", choices=list(image_model_config_dict.keys()))
                    img_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    img_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    img_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                # Hidden field holding the file name of the selected example.
                cache_dir = gr.Textbox(value="", visible=False)
                examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")

            with gr.Tab("Text to 3D"):
                prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porche Cayenne Turbo.")
                gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
                btn_generate_txt2obj = gr.Button(value="Generate")

                with gr.Accordion("Advanced settings", open=False):
                    text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256", choices=list(text_model_config_dict.keys()))
                    txt_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    txt_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    txt_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                gr.Markdown("#### Examples:")
                gr.Markdown("1. A 3D model of a coupe; Audi A6.")
                gr.Markdown("2. A 3D model of a motorcar; Hummer H2 SUT.")
                gr.Markdown("3. A 3D model of an airplane; Airbus.")
                gr.Markdown("4. A 3D model of a fighter aircraft; Attack Fighter.")
                gr.Markdown("5. A 3D model of a chair; Simple Wooden Chair.")
                gr.Markdown("6. A 3D model of a laptop computer; Dell Laptop.")
                gr.Markdown("7. A 3D model of a lamp; ceiling light.")
                gr.Markdown("8. A 3D model of a rifle; AK47.")
                gr.Markdown("9. A 3D model of a knife; Sword.")
                gr.Markdown("10. A 3D model of a vase; Plant in pot.")

        with gr.Column():
            model_3d = gr.HTML()
            file_out = gr.File(label="Files", visible=False)

    # Both generators write to the same preview and file components.
    outputs = [model_3d, file_out]

    # A fresh upload clears the cached example name; picking an example sets
    # both the image and the cached name.
    img.upload(disable_cache, outputs=cache_dir)
    examples.select(set_cache, outputs=[img, cache_dir])

    btn_generate_img2obj.click(image2mesh, inputs=[img, image_dropdown_models, img_num_samples,
                                                   img_guidance_scale, img_octree_depth],
                               outputs=outputs, api_name="generate_img2obj")

    btn_generate_txt2obj.click(text2mesh, inputs=[prompt, text_dropdown_models, txt_num_samples,
                                                  txt_guidance_scale, txt_octree_depth],
                               outputs=outputs, api_name="generate_txt2obj")

app.launch(server_name="0.0.0.0", server_port=8008, share=False)