"""Gradio demo for generating visual anagrams with DeepFloyd IF.

The user supplies a shared style, one prompt per view and a view transformation.
The app samples an image with DeepFloyd IF stage I followed by stage II, and
displays the illusion together with its original and transformed views.
"""

import argparse
from pathlib import Path

import gradio as gr
import torch
from diffusers import DiffusionPipeline

from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import add_args, save_illusion, save_metadata

# Both DeepFloyd IF stages are loaded once, at import time, and shared by all requests.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    # No text encoder: stage 2 is fed the embeddings produced by stage_1.encode_prompt.
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()


def _to_image_array(im):
    """Convert a (C, H, W) image tensor into an (H, W, C) uint8 array for gr.Image.

    Assumes pixel values lie in [-1, 1] (TODO confirm against the samplers).
    They are mapped to [0, 1], clamped, and scaled to 0..255.
    """
    im = (im.detach().float().cpu() / 2 + 0.5).clamp(0, 1)
    return (im * 255).round().to(torch.uint8).permute(1, 2, 0).numpy()


def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed,
):
    """Sample a two-view visual anagram and return images for the three outputs.

    Args:
        style: Text prepended to both prompts (may be empty).
        prompt_for_original: Prompt for the untransformed ('identity') view.
        prompt_for_transformed: Prompt for the transformed view.
        transformation: Value selected in the "View transformation" dropdown.
        num_inference_steps: Diffusion steps used by both stages.
        seed: Seed for the torch RNG shared by both sampling stages.

    Returns:
        A tuple ``(illusion, original_view, transformed_view)`` of uint8
        (H, W, C) arrays, matching the interface's three gr.Image outputs.
    """
    prompts = [prompt_for_original, prompt_for_transformed]

    # Encode each prompt on its own so every view gets its own conditioning.
    # strip() drops the leading space when no style is given.
    encoded = [stage_1.encode_prompt(f'{style} {p}'.strip()) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*encoded)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    # NOTE(review): the dropdown offers the keys of VIEW_MAP_NAMES, and they are
    # passed to get_views unchanged. If those keys are display labels rather than
    # view identifiers, this should be VIEW_MAP_NAMES[transformation] — confirm.
    views = get_views(['identity', transformation])

    # gr.Number can deliver floats (e.g. 30.0); the step count and seed must be ints.
    num_inference_steps = int(num_inference_steps)
    generator = torch.manual_seed(int(seed))

    image = sample_stage_1(stage_1,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

    image = sample_stage_2(stage_2,
                           image,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

    # Assumes a batched (1, C, H, W) result and that each view object exposes
    # .view(im) to render the image as seen through that view — TODO confirm.
    illusion = _to_image_array(image[0])
    original_view, transformed_view = (
        _to_image_array(view.view(image[0])) for view in views
    )
    return illusion, original_view, transformed_view


choices = list(VIEW_MAP_NAMES.keys())

gradio_app = gr.Interface(
    fn=generate_content,
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a penguin"),
        gr.Textbox(label="Prompt for transformed view", placeholder="a giraffe"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        gr.Number(label="Number of diffusion steps", value=30, step=1, minimum=1, maximum=100),
        gr.Number(label="Random seed", value=0, step=1, minimum=0, maximum=100000),
    ],
    outputs=[gr.Image(label="Illusion"), gr.Image(label="Original"), gr.Image(label="Transformed")],
)


if __name__ == "__main__":
    gradio_app.launch(server_name="0.0.0.0")