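"""Gradio demo for Visual Anagrams: two-view illusion generation with DeepFloyd IF.

Generates a single image that matches one prompt when viewed normally and a
second prompt under a chosen view transformation, then renders a short video
animating between the two views.
"""
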
from pathlib import Path

import gradio as gr
import torch
from diffusers import DiffusionPipeline

from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import save_illusion
from visual_anagrams.animate import animate_two_view

# DeepFloyd IF is a two-stage cascade: stage 1 generates a 64x64 base image and
# stage 2 super-resolves it to 256x256. Both run in fp16, and stage 2 reuses
# stage 1's prompt embeddings, so its text encoder can be dropped.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)
# Offload idle submodules to CPU so both pipelines fit in limited GPU memory.
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()


def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed,
):
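    """Build a two-view illusion from two prompts and a view transformation.

    Returns the path of the rendered animation (mp4) and the path of a
    side-by-side image of the two views.
    """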
    # gr.Number returns floats by default; cast to the ints the samplers expect.
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # One prompt per view, each prefixed with the shared style string.
    prompts = [f'{style} {p}'.strip() for p in [prompt_for_original, prompt_for_transformed]]

    # Encode both prompts with stage 1's text encoder and stack the embeddings.
    prompt_embeds = [stage_1.encode_prompt(p) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    # The first view is always the identity; the second is the user's choice.
    views = get_views(['identity', VIEW_MAP_NAMES[transformation]])

    generator = torch.manual_seed(seed + 42)

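    # Stage 1: multi-view denoising at 64x64. At each step, noise estimates are
    # computed under every view and averaged, steering the image toward
    # satisfying both prompts at once.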
    print("Sample stage 1")
    image = sample_stage_1(stage_1,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

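    # Stage 2: super-resolve the 64x64 illusion to 256x256 under the same views.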
    print("Sample stage 2")
    image = sample_stage_2(stage_2,
                           image,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)
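    # Writes sample_{size}.png and sample_{size}.views.png (the image rendered
    # under each view) to the current working directory.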
    save_illusion(image, views, Path(""))

    # Render a short video animating between the two views of the illusion.
    output_name = "illusion.mp4"
    size = image.shape[-1]
    animate_two_view(
        f"sample_{size}.png",
        views[1],
        prompts[0],
        prompts[1],
        save_video_path=output_name,
    )
    return output_name, f"sample_{size}.views.png"


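# A minimal sketch of driving the generator without the UI (hypothetical input
# values; the transformation must be a key of VIEW_MAP_NAMES):
#
#     video_path, views_path = generate_content(
#         "an oil painting of", "a dress", "an old man",
#         list(VIEW_MAP_NAMES.keys())[0],
#         num_inference_steps=50, seed=0)
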
with open("description.txt") as f:
    description = f.read()

choices = list(VIEW_MAP_NAMES.keys())
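# Wire the generator into a simple form UI; the transformation dropdown is
# populated from the view registry, so newly registered views appear
# automatically.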
gradio_app = gr.Interface(
    fn=generate_content,
    title="Multi-View Illusion Diffusion",
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a dress"),
        gr.Textbox(label="Prompt for transformed view", placeholder="an old man"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        gr.Number(label="Number of diffusion steps", value=50, step=1, minimum=1, maximum=300),
        gr.Number(label="Random seed", value=0, step=1, minimum=0, maximum=100000),
    ],
    outputs=[gr.Video(label="Illusion"), gr.Image(label="Before and After")],
    description=description,
)


if __name__ == "__main__":
    # Pass server_name="0.0.0.0" to launch() to listen on all interfaces
    # (useful when running inside a container).
    gradio_app.launch()