"""Gradio demo: generate multi-view images from a text prompt or an image.

Two multi-view diffusion pipelines are loaded once at import time; each
request runs one of them and returns the generated views tiled into a grid.
"""

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Both pipelines use the same custom pipeline code and load settings; they
# differ only in which weights repo they load.
_CUSTOM_PIPELINE = "dylanebert/multi-view-diffusion"

# Sampling settings shared by the text and image generation paths.
_GENERATION_KWARGS = dict(guidance_scale=5, num_inference_steps=30, elevation=0)


def _load_pipeline(repo_id):
    """Load a multi-view diffusion pipeline in float16 and move it to CUDA."""
    # NOTE: trust_remote_code executes code fetched from the Hub repo; only
    # point this at repositories you trust.
    return DiffusionPipeline.from_pretrained(
        repo_id,
        custom_pipeline=_CUSTOM_PIPELINE,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")


# Text-to-Multi-View Diffusion pipeline
text_pipeline = _load_pipeline("dylanebert/mvdream")
# Image-to-Multi-View Diffusion pipeline
image_pipeline = _load_pipeline("dylanebert/multi-view-diffusion")


def create_image_grid(images, rows=2, cols=2):
    """Tile generated views into a single RGB image, laid out row-major.

    With the defaults, images 0 and 1 form the top row and images 2 and 3
    the bottom row. Images beyond ``rows * cols`` are ignored. Cells with no
    image are left black.

    Args:
        images: Iterable of image arrays with values in [0, 1] (assumed to be
            H x W x 3 float arrays as returned by the pipelines above — TODO
            confirm). Every tile is assumed to match the first image's size.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A ``PIL.Image.Image`` of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: If ``rows`` or ``cols`` is below 1, or ``images`` is empty.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got {rows}x{cols}")
    # Only the first rows * cols views fit in the grid; skip converting the rest.
    selected = list(images)[: rows * cols]
    if not selected:
        raise ValueError("create_image_grid() needs at least one image")
    # Clip before scaling: NumPy's float -> uint8 cast is not saturating, so
    # values even slightly outside [0, 1] would otherwise wrap around.
    tiles = [
        Image.fromarray((img.clip(0.0, 1.0) * 255).astype("uint8"))
        for img in selected
    ]
    width, height = tiles[0].size
    grid_img = Image.new("RGB", (cols * width, rows * height))
    for index, tile in enumerate(tiles):
        row, col = divmod(index, cols)
        grid_img.paste(tile, (col * width, row * height))
    return grid_img


@spaces.GPU
def text_to_mv(prompt):
    """Generate multi-view images from a text prompt.

    Args:
        prompt: Text description of the object to render.

    Returns:
        The generated views tiled by ``create_image_grid``.
    """
    images = text_pipeline(prompt, **_GENERATION_KWARGS)
    return create_image_grid(images)


def image_to_mv(image, prompt):
    """Generate multi-view images from an input image and an optional prompt.

    Args:
        image: Image array from ``gr.Image(type="numpy")`` (presumably uint8
            RGB), or ``None`` when nothing has been uploaded.
        prompt: Optional text prompt forwarded to the pipeline.

    Returns:
        The generated views tiled by ``create_image_grid``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Validate before calling the GPU-decorated function so a missing upload
    # fails fast with a readable message instead of an AttributeError.
    if image is None:
        raise gr.Error("Please upload an input image first.")
    return _image_to_mv_on_gpu(image, prompt)


@spaces.GPU
def _image_to_mv_on_gpu(image, prompt):
    """Run the image pipeline on an already-validated input image."""
    # Scale 0-255 pixel values to the [0, 1] float range fed to the pipeline.
    image = image.astype("float32") / 255.0
    images = image_pipeline(prompt, image, **_GENERATION_KWARGS)
    return create_image_grid(images)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter a prompt here (e.g. 'a cat statue')",
                )
                text_btn = gr.Button("Generate Multi-View Images")
            with gr.Tab("Image Input"):
                image_input = gr.Image(
                    label="Image Input",
                    type="numpy",
                )
                optional_text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter an optional prompt here",
                )
                image_btn = gr.Button("Generate Multi-View Images")
        with gr.Column():
            output = gr.Image(label="Generated Images")

    # Event listeners must be registered inside the Blocks context.
    text_btn.click(fn=text_to_mv, inputs=text_input, outputs=output)
    image_btn.click(
        fn=image_to_mv, inputs=[image_input, optional_text_input], outputs=output
    )

if __name__ == "__main__":
    demo.queue().launch()