import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Text-to-Multi-View Diffusion pipeline
text_pipeline = DiffusionPipeline.from_pretrained(
    "dylanebert/mvdream",
    custom_pipeline="dylanebert/multi-view-diffusion",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

# Image-to-Multi-View Diffusion pipeline
image_pipeline = DiffusionPipeline.from_pretrained(
    "dylanebert/multi-view-diffusion",
    custom_pipeline="dylanebert/multi-view-diffusion",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")
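

# Convert the four generated views (float arrays in [0, 1]) to a single 2x2 PIL grid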
def create_image_grid(images):
    images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
    width, height = images[0].size
    grid_img = Image.new("RGB", (2 * width, 2 * height))
    grid_img.paste(images[0], (0, 0))
    grid_img.paste(images[1], (width, 0))
    grid_img.paste(images[2], (0, height))
    grid_img.paste(images[3], (width, height))
    return grid_img
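

# Generate four views of an object described by a text prompt and return them as a grid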
@spaces.GPU
def text_to_mv(prompt):
    images = text_pipeline(
        prompt, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    return create_image_grid(images)
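

# Generate four views conditioned on an input image plus an optional text prompt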
@spaces.GPU
def image_to_mv(image, prompt):
    image = image.astype("float32") / 255.0
    images = image_pipeline(
        prompt, image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    return create_image_grid(images)
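

# Gradio UI: text or image input on the left, 2x2 multi-view grid output on the right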
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter a prompt here (e.g. 'a cat statue')",
                )
                text_btn = gr.Button("Generate Multi-View Images")
            with gr.Tab("Image Input"):
                image_input = gr.Image(
                    label="Image Input",
                    type="numpy",
                )
                optional_text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter an optional prompt here",
                )
                image_btn = gr.Button("Generate Multi-View Images")
        with gr.Column():
            output = gr.Image(label="Generated Images")

    text_btn.click(fn=text_to_mv, inputs=text_input, outputs=output)
    image_btn.click(
        fn=image_to_mv, inputs=[image_input, optional_text_input], outputs=output
    )

if __name__ == "__main__":
    demo.queue().launch()