# Hugging Face Space by dylanebert (HF staff).
# Commit a57b575: "add zerogpu support" (file size 2.78 kB; HF security scan: no virus found).
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
@spaces.GPU
def load_text_pipeline():
    """Build the text-to-multi-view diffusion pipeline.

    Loads the ``ashawkey/mvdream-sd2.1-diffusers`` checkpoint through the
    ``dylanebert/multi_view_diffusion`` custom pipeline code, in float16,
    and moves it onto the CUDA device.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        "ashawkey/mvdream-sd2.1-diffusers",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        custom_pipeline="dylanebert/multi_view_diffusion",
    )
    return pipeline.to("cuda")
@spaces.GPU
def load_image_pipeline():
    """Build the image-to-multi-view diffusion pipeline.

    Loads the ``ashawkey/imagedream-ipmv-diffusers`` checkpoint through the
    same ``dylanebert/multi_view_diffusion`` custom pipeline code used for
    the text pipeline, in float16, and moves it onto the CUDA device.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        "ashawkey/imagedream-ipmv-diffusers",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        custom_pipeline="dylanebert/multi_view_diffusion",
    )
    return pipeline.to("cuda")
# Both pipelines are built once, at import time, and reused by every request
# (text pipeline first, then image pipeline).
text_pipeline, image_pipeline = load_text_pipeline(), load_image_pipeline()
def create_image_grid(images, columns=2):
    """Tile a sequence of float images into a single RGB PIL image.

    Args:
        images: Iterable of image arrays with pixel values expected in
            ``[0, 1]`` (presumably H x W x C NumPy arrays as returned by the
            multi-view pipeline -- TODO confirm). Every tile is assumed to have
            the same size as the first one.
        columns: Number of tiles per grid row. The default of 2 lays four
            views out as a 2 x 2 grid in row-major order.

    Returns:
        A ``PIL.Image.Image`` in RGB mode that is ``columns`` tiles wide and
        ``ceil(len(images) / columns)`` tiles tall.

    Raises:
        ValueError: If ``columns`` is less than 1 or ``images`` is empty.
    """
    if columns < 1:
        raise ValueError(f"columns must be >= 1, got {columns}")

    # Clip before scaling: float values outside [0, 1] do not cast to uint8
    # in a well-defined way (they can wrap around). Round instead of truncating.
    tiles = [
        Image.fromarray((img.clip(0, 1) * 255).round().astype("uint8"))
        for img in images
    ]
    if not tiles:
        raise ValueError("create_image_grid() needs at least one image")

    width, height = tiles[0].size
    rows = (len(tiles) + columns - 1) // columns  # ceiling division
    grid_img = Image.new("RGB", (columns * width, rows * height))
    for index, tile in enumerate(tiles):
        row, col = divmod(index, columns)
        grid_img.paste(tile, (col * width, row * height))
    return grid_img
@spaces.GPU
def text_to_mv(prompt):
    """Generate multi-view images from a text prompt and return them as a grid.

    Calls the text pipeline with ``guidance_scale=5``,
    ``num_inference_steps=30`` and ``elevation=0``, then tiles the returned
    images with ``create_image_grid``.
    """
    generation_kwargs = {
        "guidance_scale": 5,
        "num_inference_steps": 30,
        "elevation": 0,
    }
    views = text_pipeline(prompt, **generation_kwargs)
    return create_image_grid(views)
@spaces.GPU
def image_to_mv(image, prompt):
    """Generate multi-view images from an input image (plus optional prompt).

    Args:
        image: NumPy array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing was uploaded. Assumed to hold 0-255 pixel
            values, hence the division by 255 below.
        prompt: Optional text prompt from the companion textbox.

    Returns:
        The PIL grid image built by ``create_image_grid``.

    Raises:
        gr.Error: If no image was provided. The UI then shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before generating.")
    # The pipeline is fed float32 pixels scaled down from the 0-255 range.
    image = image.astype("float32") / 255.0
    images = image_pipeline(
        prompt, image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    return create_image_grid(images)
# User interface. Left column: two tabs (text prompt / image + optional
# prompt), each with its own generate button. Right column: one image
# component that shows the generated grid.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(
                    placeholder="Enter a prompt here (e.g. 'a cat statue')",
                    show_label=False,
                    lines=2,
                )
                text_btn = gr.Button("Generate Multi-View Images")
            with gr.Tab("Image Input"):
                image_input = gr.Image(type="numpy", label="Image Input")
                optional_text_input = gr.Textbox(
                    placeholder="Enter an optional prompt here",
                    show_label=False,
                    lines=2,
                )
                image_btn = gr.Button("Generate Multi-View Images")
        with gr.Column():
            output = gr.Image(label="Generated Images")

    # Both buttons write into the same output component. Listeners must be
    # registered inside the Blocks context.
    text_btn.click(fn=text_to_mv, inputs=[text_input], outputs=output)
    image_btn.click(
        fn=image_to_mv,
        inputs=[image_input, optional_text_input],
        outputs=output,
    )
if __name__ == "__main__":
    # Route requests through Gradio's queue, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()