# -*- coding: utf-8 -*-
"""Gradio demo for Michelangelo: image- and text-conditioned 3D shape generation.

Checkpoints are fetched from the Hugging Face Hub at start-up. The diffusion
models are instantiated on the first request that needs them. Each request
writes one OBJ per generated mesh plus a combined HTML preview, and exposes
them as downloads.
"""
import os
import tempfile
import time
from collections import OrderedDict
from PIL import Image
import torch
import trimesh
from typing import Optional, List
from einops import repeat, rearrange
import numpy as np

from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util

import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download


gradio_cached_dir = "./gradio_cached_dir"
os.makedirs(gradio_cached_dir, exist_ok=True)

save_mesh = False

state = ""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Half-extent of the cube passed as `bounds` to model.sample(); the preview
# layout in post_process_mesh_outputs spaces meshes by 2 * box_v.
box_v = 1.1
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

# "config" is a local YAML path; "ckpt_path" is relative to the downloaded
# Hugging Face snapshot (joined with `model_path` in load_model).
image_model_config_dict = OrderedDict({
    "ASLDM-256-obj": {
        "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
        "ckpt_path": "checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
    },
})

text_model_config_dict = OrderedDict({
    "ASLDM-256": {
        "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
        "ckpt_path": "checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
    },
})

model_path = snapshot_download(repo_id="Maikou/Michelangelo")


class InferenceModel(object):
    """Cache slot holding the currently loaded model and the config key it came from."""

    model = None
    name = ""


text2mesh_model = InferenceModel()
image2mesh_model = InferenceModel()


def set_state(s):
    """Store ``s`` in the module-level ``state`` and echo it to stdout."""
    global state
    state = s
    print(s)


def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
                         image: Optional[np.ndarray] = None, html_frame: bool = False):
    """Lay the generated meshes out side by side and render them to HTML.

    Args:
        mesh_outputs: decoded meshes; ``None`` entries are skipped.
        bbox_size: spacing; mesh ``i`` is shifted by ``i * bbox_size`` along x,
            and every mesh by ``bbox_size`` along z.
        image: optional image embedded in front of the meshes.
        html_frame: if True, wrap the result with ``html_util.to_html_frame``.

    Returns:
        The HTML string.
    """
    offset = np.max(bbox_size)
    try:
        for i, mesh in enumerate(mesh_outputs):
            if mesh is None:
                continue
            mesh_v = mesh.mesh_v.copy()  # shift a copy; the caller's vertices stay untouched
            mesh_v[:, 0] += i * offset
            mesh_v[:, 2] += offset
            viewer.add_mesh(mesh_v, mesh.mesh_f)
        mesh_tag = viewer.to_html(html_frame=False)
    finally:
        # `viewer` is a module-level singleton: always reset it so meshes added
        # before a failure are not carried over into the next call.
        viewer.reset()

    if image is not None:
        image_tag = html_util.to_image_embed_tag(image)
        frame = f"""
{image_tag} {mesh_tag}
"""
    else:
        frame = mesh_tag

    if html_frame:
        frame = html_util.to_html_frame(frame)
    return frame


def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
    """Return the model registered under ``model_name``, building it if needed.

    The model cached on ``inference_model`` is reused when its ``name`` matches.
    Otherwise the previous model is dropped and a new one is instantiated from
    the YAML config plus the checkpoint inside the downloaded snapshot, moved to
    ``device`` and switched to eval mode.

    Raises:
        ValueError: if ``model_name`` is not a key of ``model_config_dict``.
    """
    if inference_model.name == model_name and inference_model.model is not None:
        return inference_model.model

    if model_name not in model_config_dict:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {list(model_config_dict)}"
        )

    # Release the previous model and forget its name *before* building the new
    # one, so a failure below cannot leave a name pointing at no model.
    inference_model.model = None
    inference_model.name = ""

    config_ckpt_path = model_config_dict[model_name]
    model_config = get_config_from_file(config_ckpt_path["config"])
    if hasattr(model_config, "model"):
        model_config = model_config.model

    ckpt_path = os.path.join(model_path, config_ckpt_path["ckpt_path"])
    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
    model = model.to(device)
    model = model.eval()

    inference_model.model = model
    inference_model.name = model_name
    return model


def prepare_img(image: np.ndarray):
    """Convert an ``(H, W, C)`` image into a ``(C, H, W)`` float tensor in [-1, 1].

    Assumes pixel values in [0, 255] (as delivered by ``gr.Image``) — TODO confirm.
    """
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    return image_pt


def prepare_model_viewer(fp):
    """Return the model-viewer HTML snippet.

    NOTE(review): the template below contains no markup and never uses ``fp``;
    it looks like the HTML was lost — confirm the intended content.
    """
    content = f""" """
    return content


def prepare_html_frame(content):
    """Wrap ``content`` in the page-frame template."""
    frame = f""" {content} """
    return frame


def prepare_html_body(content):
    """Wrap ``content`` in the page-body template."""
    frame = f""" {content} """
    return frame


def post_process_mesh_outputs(mesh_outputs):
    """Write the HTML preview and one OBJ file per mesh.

    Returns:
        ``(iframe_tag, filelist)``: the snippet for the ``gr.HTML`` output and
        the written file paths (OBJ files first, the HTML preview last).
    """
    html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
    html_frame = prepare_html_frame(html_content)

    html_filepath = os.path.join(gradio_cached_dir, f"four-in-one-{time.time()}.html")
    with open(html_filepath, "w", encoding="utf-8") as writer:
        writer.write(html_frame)

    # Known issue: an <iframe> pointing at a local file does not work in Gradio;
    # Chrome reports "No resource with given URL found"
    # (see https://github.com/gradio-app/gradio/issues/884). Because of Gradio's
    # security restrictions, the server only serves files located next to
    # gradio_app.py, via URLs of the form "file/TARGET_FILE_PATH".
    # NOTE(review): the tag is an empty string and does not reference
    # html_filepath — confirm the intended iframe markup.
    iframe_tag = ''

    # A fresh directory per request, so the fixed OBJ names below are never
    # overwritten by a later or concurrent request while still being served.
    mesh_dir = tempfile.mkdtemp(prefix="meshes-", dir=gradio_cached_dir)

    filelist = []
    for i, mesh in enumerate(mesh_outputs):
        if mesh is None:
            # output_to_html_frame skips missing meshes as well.
            continue
        # Reverse the vertex order of every face for export. The result is kept
        # in a local variable instead of being written back onto `mesh`.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)
        filepath = os.path.join(mesh_dir, f"{i}_out_mesh.obj")
        mesh_output.export(filepath, include_normals=True)
        filelist.append(filepath)

    filelist.append(html_filepath)
    return iframe_tag, filelist


def _sample_and_package(model, sample_inputs: dict, guidance_scale: float, octree_depth: int):
    """Run ``model.sample`` once inside the ``box_v`` bounds and package the outputs for Gradio."""
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=int(octree_depth),
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
    return iframe_tag, gr.update(value=filelist, visible=True)


# NOTE(review): the default `model_name` of image2mesh/text2mesh is not a key of
# the corresponding config dict, so load_model rejects it; the UI always passes
# an explicit dropdown value.
def image2mesh(image: np.ndarray,
               model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
               num_samples: int = 4,
               guidance_scale: float = 7.5,
               octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on ``image``.

    Returns:
        ``(html, gr.update)`` for the ``[model_3d, file_out]`` outputs.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    model = load_model(model_name, image_model_config_dict, image2mesh_model)

    # The same image is repeated once per requested sample.
    image_pt = prepare_img(image)
    image_pt = repeat(image_pt, "c h w -> b c h w", b=int(num_samples))

    sample_inputs = {"image": image_pt}
    return _sample_and_package(model, sample_inputs, guidance_scale, octree_depth)


def text2mesh(text: str,
              model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
              num_samples: int = 4,
              guidance_scale: float = 7.5,
              octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on the prompt ``text``.

    Returns:
        ``(html, gr.update)`` for the ``[model_3d, file_out]`` outputs.
    """
    model = load_model(model_name, text_model_config_dict, text2mesh_model)

    # The same prompt is repeated once per requested sample.
    sample_inputs = {"text": [text] * int(num_samples)}
    return _sample_and_package(model, sample_inputs, guidance_scale, octree_depth)


example_dir = './gradio_cached_dir/example/img_example'
# Examples listed here are shown first in the gallery, in this order.
first_page_items = [
    'alita.jpg',
    'burger.jpg',
    'loopy.jpg',
    'building.jpg',
    'mario.jpg',
    'car.jpg',
    'airplane.jpg',
    'bag.jpg',
    'bench.jpg',
    'ship.jpg',
]
_example_files = sorted(os.listdir(example_dir)) if os.path.isdir(example_dir) else []
raw_example_items = [
    os.path.join(example_dir, x)
    for x in _example_files
    if x.endswith(('.jpg', '.png'))
]
_first_page_rank = {name: rank for rank, name in enumerate(first_page_items)}
example_items = sorted(
    (x for x in raw_example_items if os.path.basename(x) in _first_page_rank),
    key=lambda x: _first_page_rank[os.path.basename(x)],
) + [x for x in raw_example_items if os.path.basename(x) not in _first_page_rank]

example_text = [
    ["A 3D model of a car; Audi A6."],
    ["A 3D model of police car; Highway Patrol Charger"],
]


def set_cache(data: gr.SelectData):
    """Gallery select handler: return the clicked example's path and its file name."""
    img_name = os.path.basename(example_items[data.index])
    return os.path.join(example_dir, img_name), img_name


def disable_cache():
    """Image-upload handler: clear the hidden example-name textbox."""
    return ""


with gr.Blocks() as app:
    gr.Markdown("# Michelangelo")
    gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
    gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
    gr.Markdown("### Hint:")
    gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
    gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
    gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
    gr.Markdown("4. To make it convenient to take favor results home, we provide download buttons for each OBJ file and a combined HTML file.")
    gr.Markdown("5. Welcome to share suggestions or amazing results with us, and thanks for your interest in our work!")
    gr.Markdown("6. Please note that the model may require some time to download the weights and set up during the first launch; we are working to fix this issue.")

    with gr.Row():
        with gr.Column():
            with gr.Tab("Image to 3D"):
                img = gr.Image(label="Image")
                gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
                btn_generate_img2obj = gr.Button(value="Generate")

                # Each tab owns its own sliders; reusing one set of names made
                # the image button read the text tab's settings.
                with gr.Accordion("Advanced settings", open=False):
                    image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj", choices=list(image_model_config_dict.keys()))
                    img_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    img_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    img_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                cache_dir = gr.Textbox(value="", visible=False)
                examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")

            with gr.Tab("Text to 3D"):
                prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porche Cayenne Turbo.")
                gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
                btn_generate_txt2obj = gr.Button(value="Generate")

                with gr.Accordion("Advanced settings", open=False):
                    text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256", choices=list(text_model_config_dict.keys()))
                    txt_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    txt_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    txt_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                gr.Markdown("#### Examples:")
                gr.Markdown("1. A 3D model of an airplane; Airbus.")
                gr.Markdown("2. A 3D model of a fighter aircraft; Attack Fighter.")
                gr.Markdown("3. A 3D model of a chair; Simple Wooden Chair.")
                gr.Markdown("4. A 3D model of a laptop computer; Dell Laptop.")
                gr.Markdown("5. A 3D model of a coupe; Audi A6.")
                gr.Markdown("6. A 3D model of a motorcar; Hummer H2 SUT.")
                gr.Markdown("7. A 3D model of a lamp; Light Post.")
                gr.Markdown("8. A 3D model of a rifle; AK47.")
                gr.Markdown("9. A 3D model of a knife; Sword.")
                gr.Markdown("10. A 3D model of a vase; Plant in pot.")

        with gr.Column():
            model_3d = gr.HTML()
            file_out = gr.File(label="Files", visible=False)

    outputs = [model_3d, file_out]

    img.upload(disable_cache, outputs=cache_dir)
    examples.select(set_cache, outputs=[img, cache_dir])
    print(os.path.abspath(os.path.dirname(__file__)), flush=True)
    print(model_path, flush=True)
    fps = os.listdir(model_path)
    print(fps)
    print(f'line:404: {cache_dir}', flush=True)
    btn_generate_img2obj.click(image2mesh,
                               inputs=[img, image_dropdown_models, img_num_samples, img_guidance_scale, img_octree_depth],
                               outputs=outputs, api_name="generate_img2obj")
    btn_generate_txt2obj.click(text2mesh,
                               inputs=[prompt, text_dropdown_models, txt_num_samples, txt_guidance_scale, txt_octree_depth],
                               outputs=outputs, api_name="generate_txt2obj")

app.launch()