# -*- coding: utf-8 -*-
"""Gradio demo for Michelangelo: image- and text-conditioned 3D shape generation.

The UI has two tabs ("Image to 3D" and "Text to 3D"). Each request loads (or
reuses) the selected model, samples meshes with ``model.sample``, writes them as
OBJ files plus an HTML preview page under ``gradio_cached_dir`` and returns both
to the UI.
"""
import os
import tempfile
import time
from collections import OrderedDict
from PIL import Image
import torch
import trimesh
from typing import Optional, List
from einops import repeat, rearrange
import numpy as np
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util

import gradio as gr

gradio_cached_dir = "./gradio_cached_dir"
os.makedirs(gradio_cached_dir, exist_ok=True)

save_mesh = False

state = ""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Meshes are sampled inside the cube [-box_v, box_v]^3 (see ``bounds=`` below).
box_v = 1.1
# Single shared viewer; output_to_html_frame resets it after every rendering.
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

image_model_config_dict = OrderedDict({
    "ASLDM-256-obj": {
        "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
        "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
    },
})

text_model_config_dict = OrderedDict({
    "ASLDM-256": {
        "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
        "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
    },
})


class InferenceModel(object):
    """Cache slot holding the currently loaded model and the config key it came from."""
    model = None
    name = ""


text2mesh_model = InferenceModel()
image2mesh_model = InferenceModel()


def set_state(s):
    """Store ``s`` in the module-level ``state`` and print it."""
    global state
    state = s
    print(s)


def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
                         image: Optional[np.ndarray] = None, html_frame: bool = False):
    """Render meshes side by side in the shared viewer and return the HTML markup.

    Args:
        mesh_outputs: meshes to render; ``None`` entries are skipped but still
            occupy their slot along the x axis.
        bbox_size: spacing between consecutive meshes along x; also used as a
            constant shift along z.
        image: optional image whose embed tag is placed before the viewer markup.
        html_frame: if True, wrap the result with ``html_util.to_html_frame``.

    Returns:
        The HTML string.
    """
    offset = np.max(bbox_size)
    try:
        for i, mesh in enumerate(mesh_outputs):
            if mesh is None:
                continue
            # Shift a copy so the caller's vertices are not modified.
            mesh_v = mesh.mesh_v.copy()
            mesh_v[:, 0] += i * offset
            mesh_v[:, 2] += offset
            viewer.add_mesh(mesh_v, mesh.mesh_f)

        mesh_tag = viewer.to_html(html_frame=False)
    finally:
        # The viewer is a module-level singleton: always clear it, even on
        # failure, so meshes from one request never show up in the next.
        viewer.reset()

    if image is not None:
        image_tag = html_util.to_image_embed_tag(image)
        frame = f"\n{image_tag} {mesh_tag}\n"
    else:
        frame = mesh_tag

    if html_frame:
        frame = html_util.to_html_frame(frame)

    return frame


def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
    """Return the model for ``model_name``, loading it only when not already cached.

    Args:
        model_name: key into ``model_config_dict``.
        model_config_dict: maps names to {"config": yaml path, "ckpt_path": checkpoint path}.
        inference_model: cache slot; updated in place when a new model is loaded.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: if ``model_name`` is not a key of ``model_config_dict``.
    """
    if inference_model.model is not None and inference_model.name == model_name:
        return inference_model.model

    if model_name not in model_config_dict:
        raise ValueError(
            f"Unknown model {model_name!r}; available: {list(model_config_dict.keys())}"
        )

    if inference_model.model is not None:
        # Drop the previous model and forget its name *before* loading, so a
        # failure below cannot leave the old name paired with a missing model.
        del inference_model.model
        inference_model.name = ""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    config_ckpt_path = model_config_dict[model_name]
    model_config = get_config_from_file(config_ckpt_path["config"])
    if hasattr(model_config, "model"):
        model_config = model_config.model

    model = instantiate_from_config(model_config, ckpt_path=config_ckpt_path["ckpt_path"])
    model = model.to(device)
    model = model.eval()

    inference_model.model = model
    inference_model.name = model_name
    return model


def prepare_img(image: np.ndarray):
    """Convert an (H, W, C) image array into a (C, H, W) float tensor.

    Pixel values are mapped with ``x / 255 * 2 - 1``, i.e. [0, 255] -> [-1, 1]
    (assumes 8-bit pixel values, as produced by ``gr.Image``).
    """
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    return image_pt


def prepare_model_viewer(fp):
    """Return the model-viewer HTML for ``fp``.

    NOTE(review): the template currently contains no markup and does not use
    ``fp``; the original HTML appears to be missing — confirm.
    """
    content = f""" """
    return content


def prepare_html_frame(content):
    """Wrap ``content`` in the page frame template (currently whitespace only — NOTE(review): confirm)."""
    frame = f""" {content} """
    return frame


def prepare_html_body(content):
    """Wrap ``content`` in the body template (currently whitespace only — NOTE(review): confirm)."""
    frame = f""" {content} """
    return frame


def post_process_mesh_outputs(mesh_outputs):
    """Write the sampled meshes and an HTML preview to ``gradio_cached_dir``.

    Args:
        mesh_outputs: sampled meshes; ``None`` entries are skipped, as in
            ``output_to_html_frame``.

    Returns:
        ``(iframe_tag, filelist)`` where ``filelist`` holds the exported OBJ
        paths followed by the HTML preview path.
    """
    html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
    html_frame = prepare_html_frame(html_content)

    filename = f"text-256-{time.time()}.html"
    html_filepath = os.path.join(gradio_cached_dir, filename)
    with open(html_filepath, "w", encoding="utf-8") as writer:
        writer.write(html_frame)

    # Gradio cannot display an <iframe> pointing at an arbitrary path (Chrome
    # reports "No resource with given URL found"); see
    # https://github.com/gradio-app/gradio/issues/884. For security reasons the
    # server only serves files located next to gradio_app.py, addressed as
    # "file/TARGET_FILE_PATH".
    # NOTE(review): the iframe markup is empty here, so the HTML panel renders
    # nothing; it presumably should reference f"file/{html_filepath}" — confirm.
    iframe_tag = ""

    # A fresh directory per request keeps concurrent requests from overwriting
    # each other's "<i>_out_mesh.obj" files while keeping the download names.
    mesh_dir = tempfile.mkdtemp(prefix="meshes-", dir=gradio_cached_dir)

    filelist = []
    for i, mesh in enumerate(mesh_outputs):
        if mesh is None:
            continue
        # Reverse each face's vertex order for export. Slicing yields a new
        # array, so the caller's mesh object is left untouched.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)
        filepath = os.path.join(mesh_dir, f"{i}_out_mesh.obj")
        mesh_output.export(filepath, include_normals=True)
        filelist.append(filepath)

    filelist.append(html_filepath)
    return iframe_tag, filelist


def image2mesh(image: np.ndarray,
               model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
               num_samples: int = 4,
               guidance_scale: float = 7.5,
               octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on ``image``.

    NOTE(review): the default ``model_name`` is not a key of
    ``image_model_config_dict``, so callers must pass one explicitly (the UI
    dropdown always does).

    Returns:
        ``(iframe_tag, gr.update(...))`` for the HTML panel and the file list.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image or pick one of the examples.")

    model = load_model(model_name, image_model_config_dict, image2mesh_model)

    # Slider values are cast defensively: both are used as integer counts.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    image_pt = prepare_img(image)
    image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)
    sample_inputs = {"image": image_pt}

    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
    return iframe_tag, gr.update(value=filelist, visible=True)


def text2mesh(text: str,
              model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
              num_samples: int = 4,
              guidance_scale: float = 7.5,
              octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on the prompt ``text``.

    NOTE(review): the default ``model_name`` is not a key of
    ``text_model_config_dict``, so callers must pass one explicitly (the UI
    dropdown always does).

    Returns:
        ``(iframe_tag, gr.update(...))`` for the HTML panel and the file list.
    """
    model = load_model(model_name, text_model_config_dict, text2mesh_model)

    # Slider values are cast defensively: both are used as integer counts.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    sample_inputs = {"text": [text] * num_samples}

    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
    return iframe_tag, gr.update(value=filelist, visible=True)


example_dir = './gradio_cached_dir/example/img_example'
# Examples shown first in the gallery, in this order; the rest follow sorted by name.
first_page_items = [
    'alita.jpg',
    'burger.jpg',
    'loopy.jpg',
    'building.jpg',
    'mario.jpg',
    'car.jpg',
    'airplane.jpg',
    'bag.jpg',
    'bench.jpg',
    'ship.jpg',
]

# Tolerate a missing example directory instead of crashing at startup.
if os.path.isdir(example_dir):
    _example_names = sorted(
        x for x in os.listdir(example_dir) if x.endswith(('.jpg', '.png'))
    )
else:
    _example_names = []

raw_example_items = [os.path.join(example_dir, x) for x in _example_names]
_available_examples = set(_example_names)
_first_page_set = set(first_page_items)
example_items = (
    [os.path.join(example_dir, x) for x in first_page_items if x in _available_examples]
    + [os.path.join(example_dir, x) for x in _example_names if x not in _first_page_set]
)

example_text = [
    ["A 3D model of a car; Audi A6."],
    ["A 3D model of police car; Highway Patrol Charger"],
]


def set_cache(data: gr.SelectData):
    """Gallery select handler: return the example's path (for ``img``) and its file name (for ``cache_dir``)."""
    img_name = os.path.basename(example_items[data.index])
    return os.path.join(example_dir, img_name), img_name


def disable_cache():
    """Upload handler: clear ``cache_dir`` once the user supplies their own image."""
    return ""


with gr.Blocks() as app:
    gr.Markdown("# Michelangelo")
    gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
    gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
    gr.Markdown("### Hint:")
    gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
    gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
    gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
    gr.Markdown("4. Welcome to share your amazing results with us, and thanks for your interest in our work!")

    with gr.Row():
        with gr.Column():
            with gr.Tab("Image to 3D"):
                img = gr.Image(label="Image")
                gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
                btn_generate_img2obj = gr.Button(value="Generate")

                # Each tab owns its own sliders; sharing one set of names made the
                # Text tab's sliders silently drive the Image button.
                with gr.Accordion("Advanced settings", open=False):
                    image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj", choices=list(image_model_config_dict.keys()))
                    img_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    img_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    img_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                cache_dir = gr.Textbox(value="", visible=False)
                examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")

            with gr.Tab("Text to 3D"):
                prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porche Cayenne Turbo.")
                gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
                btn_generate_txt2obj = gr.Button(value="Generate")

                with gr.Accordion("Advanced settings", open=False):
                    text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256", choices=list(text_model_config_dict.keys()))
                    txt_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    txt_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    txt_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                gr.Markdown("#### Examples:")
                gr.Markdown("1. A 3D model of an airplane; Airbus.")
                gr.Markdown("2. A 3D model of a fighter aircraft; Attack Fighter.")
                gr.Markdown("3. A 3D model of a chair; Simple Wooden Chair.")
                gr.Markdown("4. A 3D model of a laptop computer; Dell Laptop.")
                gr.Markdown("5. A 3D model of a coupe; Audi A6.")
                gr.Markdown("6. A 3D model of a motorcar; Hummer H2 SUT.")
                gr.Markdown("7. A 3D model of a lamp; ceiling light.")
                gr.Markdown("8. A 3D model of a rifle; AK47.")
                gr.Markdown("9. A 3D model of a knife; Sword.")
                gr.Markdown("10. A 3D model of a vase; Plant in pot.")

        with gr.Column():
            model_3d = gr.HTML()
            file_out = gr.File(label="Files", visible=False)

    outputs = [model_3d, file_out]

    img.upload(disable_cache, outputs=cache_dir)
    examples.select(set_cache, outputs=[img, cache_dir])

    btn_generate_img2obj.click(image2mesh,
                               inputs=[img, image_dropdown_models, img_num_samples, img_guidance_scale, img_octree_depth],
                               outputs=outputs, api_name="generate_img2obj")
    btn_generate_txt2obj.click(text2mesh,
                               inputs=[prompt, text_dropdown_models, txt_num_samples, txt_guidance_scale, txt_octree_depth],
                               outputs=outputs, api_name="generate_txt2obj")

app.launch(server_name="0.0.0.0", server_port=8008, share=False)