|
import argparse |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import torch |
|
from diffusers import DiffusionPipeline |
|
|
|
from visual_anagrams.views import get_views, VIEW_MAP_NAMES |
|
from visual_anagrams.samplers import sample_stage_1, sample_stage_2 |
|
from visual_anagrams.utils import add_args, save_illusion, save_metadata |
|
|
|
# Keyword arguments shared by both DeepFloyd IF stages: fp16 weights and dtype.
_HALF_PRECISION = {"variant": "fp16", "torch_dtype": torch.float16}

# Both pipelines are created once, at import time, and reused for every request.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    **_HALF_PRECISION,
)
# Stage 2 is loaded with text_encoder=None; it is only ever given the
# embeddings that stage 1 has already computed.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None,
    **_HALF_PRECISION,
)

# Turn on diffusers' model CPU offload for each pipeline.
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()
|
|
|
|
|
def _to_display_image(image):
    """Convert a single image tensor into an HWC uint8 NumPy array.

    NOTE(review): assumes `image` is a (channels, height, width) tensor with
    values in [-1, 1], as is conventional for DeepFloyd IF outputs -- confirm
    against `visual_anagrams.samplers`.
    """
    # Work in float32: the pipelines run in float16, and some element-wise
    # ops are not available for half tensors on CPU.
    image = (image.detach().float() / 2 + 0.5).clamp(0, 1)
    image = (image * 255).round().to(torch.uint8)
    return image.permute(1, 2, 0).cpu().numpy()


def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed
):
    """Generate a two-view illusion and return it with both of its views.

    Args:
        style: Text prepended to each prompt (may be empty).
        prompt_for_original: Prompt for the untransformed ('identity') view.
        prompt_for_transformed: Prompt for the transformed view.
        transformation: Name of the view transformation, passed to `get_views`.
        num_inference_steps: Number of diffusion steps used for both stages.
        seed: Seed passed to `torch.manual_seed`.

    Returns:
        A 3-tuple of HWC uint8 arrays: the illusion, its first ('identity')
        view and its second (transformed) view.
    """
    # UI number widgets may deliver floats; both consumers need integers.
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # Encode each prompt separately (one embedding pair per view). Iterating
    # over `[prompts]` instead would encode the repr of the whole list once.
    prompts = [prompt_for_original, prompt_for_transformed]
    encoded = [stage_1.encode_prompt(f'{style} {p}'.strip()) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*encoded)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    # The first view is always the identity; the second is user-selected.
    views = get_views(['identity', transformation])

    # torch.manual_seed seeds the global RNG and returns the default generator.
    generator = torch.manual_seed(seed)

    image = sample_stage_1(stage_1,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

    image = sample_stage_2(stage_2,
                           image,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

    # NOTE(review): assumes the sampler returns a batch of one image and that
    # each view object exposes `.view(image)` -- confirm against
    # `visual_anagrams.views`.
    illusion = image[0]
    original, transformed = (
        _to_display_image(view.view(illusion)) for view in views
    )
    return _to_display_image(illusion), original, transformed
|
|
|
|
|
# Names of the available view transformations, offered in the dropdown.
choices = list(VIEW_MAP_NAMES.keys())

gradio_app = gr.Interface(
    fn=generate_content,
    # Input order must match the positional parameters of `generate_content`.
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a penguin"),
        gr.Textbox(label="Prompt for transformed view", placeholder="a giraffe"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        # precision=0 makes Gradio round to the nearest integer and pass an
        # int, so a fractional entry cannot reach the scheduler or the seed.
        gr.Number(
            label="Number of diffusion steps",
            value=30,
            step=1,
            minimum=1,
            maximum=100,
            precision=0,
        ),
        gr.Number(
            label="Random seed",
            value=0,
            step=1,
            minimum=0,
            maximum=100000,
            precision=0,
        ),
    ],
    # Output order must match the tuple returned by `generate_content`.
    outputs=[
        gr.Image(label="Illusion"),
        gr.Image(label="Original"),
        gr.Image(label="Transformed"),
    ],
)


if __name__ == "__main__":
    # "0.0.0.0" binds to all network interfaces, so the app is reachable from
    # other machines; restrict this if the host is on an untrusted network.
    gradio_app.launch(server_name="0.0.0.0")
|
|