Spaces:
dylanebert
/
Running on Zero

LGM-mini / app.py
dylanebert's picture
dylanebert HF staff
install from wheel
1803df3
raw
history blame
No virus
3.5 kB
import os
import shlex
import subprocess
import sys

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from gradio_client import Client, file
# Install the prebuilt rasterizer wheel before the pipelines below are loaded.
# ``sys.executable -m pip`` targets the interpreter that is running this app
# (a bare ``pip`` on PATH may belong to a different environment), and
# ``check=True`` aborts startup with a clear CalledProcessError instead of
# surfacing later as a confusing failure (presumably when the LGM custom
# pipeline tries to use the rasterizer — TODO confirm).
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    ],
    check=True,
)
# Directory where generated artifacts (the output .ply) are written.
# Created eagerly at import time; a pre-existing directory is fine.
TMP_DIR = "/tmp"
os.makedirs(name=TMP_DIR, exist_ok=True)
def _load_cuda_pipeline(repo_id, custom_pipeline):
    """Load a remote-code diffusers pipeline in float16 and move it to CUDA.

    Args:
        repo_id: Hugging Face repo holding the pipeline weights.
        custom_pipeline: Repo providing the custom pipeline code.

    Returns:
        The pipeline instance after ``.to("cuda")``.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        repo_id,
        custom_pipeline=custom_pipeline,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return pipeline.to("cuda")


# Image-conditioned generator whose output is fed to ``splat_pipeline`` in
# ``run`` (presumably multiple views of the object, per the pipeline name).
image_pipeline = _load_cuda_pipeline(
    "ashawkey/imagedream-ipmv-diffusers", "dylanebert/multi_view_diffusion"
)
# LGM pipeline: turns the generated images into Gaussians that ``run`` saves
# as a .ply file.
splat_pipeline = _load_cuda_pipeline("dylanebert/LGM", "dylanebert/LGM")
@spaces.GPU
def run(input_image, convert):
    """Generate 3D Gaussians from a single image, optionally converting to a mesh.

    Args:
        input_image: Array from the ``gr.Image(type="numpy")`` component, or
            ``None`` when no image was supplied. Pixel values are assumed to
            be in [0, 255] (TODO confirm dtype / channel layout).
        convert: When truthy, the saved .ply is sent through
            ``convert_to_mesh`` and that result is returned instead.

    Returns:
        Path to the generated .ply file, or the result of ``convert_to_mesh``
        when ``convert`` is set.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking "Generate" with an empty image component passes None; report
    # it to the user instead of crashing with AttributeError on ``.astype``.
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Rescale 0-255 pixel values to float32 in [0, 1].
    input_image = input_image.astype("float32") / 255.0
    # The first positional argument is presumably the text prompt; it is left
    # empty so only the image conditions generation (TODO confirm).
    images = image_pipeline(
        "", input_image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    gaussians = splat_pipeline(images)

    output_ply_path = os.path.join(TMP_DIR, "output.ply")
    splat_pipeline.save_ply(gaussians, output_ply_path)

    if not convert:
        return output_ply_path
    return convert_to_mesh(output_ply_path)
def convert_to_mesh(input_ply):
    """Convert a Gaussian-splat .ply into a mesh via the remote splat-to-mesh Space.

    Args:
        input_ply: Local path to the .ply file to upload.

    Returns:
        Whatever the remote ``/run`` endpoint returns — presumably a local
        path to the downloaded mesh file (TODO confirm against the Space API).
    """
    client = Client("https://dylanebert-splat-to-mesh.hf.space/")
    try:
        return client.predict(file(input_ply), api_name="/run")
    finally:
        # Close the client even when the remote call raises, so it is not leaked.
        client.close()
_TITLE = """LGM Mini"""
_DESCRIPTION = """
<div>
A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
</div>
"""
css = """
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
def update_warning(checked):
    """Return HTML for the mesh-conversion warning.

    Args:
        checked: Current value of the "Convert to Mesh" checkbox.

    Returns:
        A red warning ``<span>`` when ``checked`` is truthy, otherwise "".
    """
    return (
        '<span style="color: #ff0000;">Warning: Mesh conversion takes several minutes</span>'
        if checked
        else ""
    )


# Sample images (paths relative to the working directory) offered as
# one-click examples; ``cache_examples=True`` below has Gradio cache outputs.
_EXAMPLE_IMAGES = [
    "data_test/frog_sweater.jpg",
    "data_test/bird.jpg",
    "data_test/boy.jpg",
    "data_test/cat_statue.jpg",
    "data_test/dragontoy.jpg",
    "data_test/gso_rabbit.jpg",
]

block = gr.Blocks(title=_TITLE, css=css)
with block:
    # Styled by the ``#duplicate-button`` rule in ``css``.
    gr.DuplicateButton(
        value="Duplicate Space for private use", elem_id="duplicate-button"
    )

    # Header row: title and description.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {_TITLE}")
            gr.Markdown(_DESCRIPTION)

    with gr.Row(variant="panel"):
        # Left column: input image, options and the generate button.
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type="numpy")
            convert_checkbox = gr.Checkbox(label="Convert to Mesh")
            warning = gr.HTML()
            # Refresh the warning text whenever the checkbox changes.
            convert_checkbox.change(
                fn=update_warning, inputs=[convert_checkbox], outputs=[warning]
            )
            button_gen = gr.Button("Generate")

        # Right column: 3D viewer for the generated result.
        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")

        button_gen.click(
            fn=run, inputs=[input_image, convert_checkbox], outputs=[output_splat]
        )

    # Examples always run without mesh conversion.
    gr.Examples(
        examples=_EXAMPLE_IMAGES,
        inputs=[input_image],
        outputs=[output_splat],
        fn=lambda x: run(input_image=x, convert=False),
        cache_examples=True,
        label="Image-to-3D Examples",
    )
# Start the server only when executed as a script (``python app.py``), so that
# importing this module (e.g. from tests or tooling) does not block in launch().
if __name__ == "__main__":
    block.queue().launch(debug=True, share=True)