File size: 2,569 Bytes
0353592
a57b575
0353592
 
 
 
 
84b9349
 
2dd3270
06bf081
84b9349
 
 
 
 
 
 
06bf081
 
84b9349
 
 
0353592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57b575
0353592
 
 
 
 
 
 
a57b575
0353592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image


def _load_multiview_pipeline(model_id):
    """Load a multi-view diffusion pipeline in float16 and move it to CUDA.

    Both pipelines in this app use the same custom pipeline code
    ("dylanebert/multi-view-diffusion"); only the weights repository
    passed as ``model_id`` differs.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        custom_pipeline="dylanebert/multi-view-diffusion",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return pipeline.to("cuda")


# Text-to-Multi-View Diffusion pipeline (weights from "dylanebert/mvdream").
text_pipeline = _load_multiview_pipeline("dylanebert/mvdream")

# Image-to-Multi-View Diffusion pipeline
# (weights from "dylanebert/multi-view-diffusion").
image_pipeline = _load_multiview_pipeline("dylanebert/multi-view-diffusion")


def create_image_grid(images, cols=2):
    """Tile a sequence of float images into a single RGB grid image.

    Args:
        images: Sequence of array images exposing ``clip`` and ``astype``
            (presumably HxWxC numpy arrays with values in [0, 1] as returned
            by the diffusion pipelines -- confirm). Every tile is placed
            using the size of the first image.
        cols: Number of tiles per row. The default of 2 reproduces the
            original 2x2 layout for four images; additional images wrap
            onto further rows.

    Returns:
        A ``PIL.Image.Image`` in RGB mode of size
        ``(cols * width, rows * height)`` where ``rows`` is
        ``ceil(len(images) / cols)``.

    Raises:
        ValueError: If ``images`` is empty or ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols!r}")

    # Clip before scaling: values slightly outside [0, 1] would otherwise
    # wrap around when cast to uint8 (e.g. 1.01 * 255 -> 257 -> 1).
    tiles = [
        Image.fromarray((img.clip(0.0, 1.0) * 255).astype("uint8"))
        for img in images
    ]
    # Check the materialized list rather than ``images`` itself, since the
    # truthiness of a multi-element numpy array is ambiguous.
    if not tiles:
        raise ValueError("create_image_grid() needs at least one image")

    width, height = tiles[0].size
    rows = -(-len(tiles) // cols)  # ceiling division
    grid_img = Image.new("RGB", (cols * width, rows * height))

    # Fill row-major: index 0 top-left, index 1 to its right, and so on.
    for index, tile in enumerate(tiles):
        row, col = divmod(index, cols)
        grid_img.paste(tile, (col * width, row * height))

    return grid_img


@spaces.GPU
def text_to_mv(prompt):
    """Generate multi-view images from a text prompt.

    Runs ``text_pipeline`` with guidance scale 5, 30 inference steps and
    elevation 0, then combines the returned views into one image with
    ``create_image_grid``.

    Note: ``@spaces.GPU`` presumably requests a GPU for the duration of the
    call (Hugging Face ZeroGPU) -- confirm against the ``spaces`` docs.
    """
    generation_kwargs = dict(
        guidance_scale=5,
        num_inference_steps=30,
        elevation=0,
    )
    views = text_pipeline(prompt, **generation_kwargs)
    grid = create_image_grid(views)
    return grid


@spaces.GPU
def image_to_mv(image, prompt):
    """Generate multi-view images conditioned on an input image.

    Args:
        image: The uploaded image as a numpy array, or ``None`` when the
            user clicked generate without uploading anything. Presumably
            uint8 in [0, 255] -- confirm against the Gradio component config.
        prompt: Optional text prompt passed alongside the image; may be an
            empty string.

    Returns:
        The grid image produced by ``create_image_grid`` from the views
        returned by ``image_pipeline``.

    Raises:
        gr.Error: If no image was provided. The UI shows the message to the
            user instead of failing with an ``AttributeError`` on ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image before generating.")

    # The pipeline is fed float32 values scaled from [0, 255] to [0, 1].
    image = image.astype("float32") / 255.0
    images = image_pipeline(
        prompt, image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    return create_image_grid(images)


# Gradio UI: a row with two columns. The first column holds the inputs
# (one tab per generation mode); the second column holds the shared output
# image.
with gr.Blocks() as demo:
    with gr.Row():
        # Input column: text-driven and image-driven generation tabs.
        with gr.Column():
            # Text-to-multi-view: a prompt box whose value is passed to
            # text_to_mv.
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter a prompt here (e.g. 'a cat statue')",
                )
                text_btn = gr.Button("Generate Multi-View Images")
            # Image-to-multi-view: an uploaded image plus an optional prompt,
            # both passed to image_to_mv.
            with gr.Tab("Image Input"):
                image_input = gr.Image(
                    label="Image Input",
                    # Deliver the upload as a numpy array; image_to_mv calls
                    # .astype() on it.
                    type="numpy",
                )
                optional_text_input = gr.Textbox(
                    lines=2,
                    show_label=False,
                    placeholder="Enter an optional prompt here",
                )
                image_btn = gr.Button("Generate Multi-View Images")
        # Output column: both modes write their grid image here.
        with gr.Column():
            output = gr.Image(label="Generated Images")

    # Wire each button to its generator function; both target `output`.
    text_btn.click(fn=text_to_mv, inputs=text_input, outputs=output)
    image_btn.click(
        fn=image_to_mv, inputs=[image_input, optional_text_input], outputs=output
    )


if __name__ == "__main__":
    # Enable Gradio's request queue on the Blocks app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()