# Source: Hugging Face Space file by callum-canavan,
# commit a65ed45 ("Update app to illusion generation"), 2.8 kB.
import argparse
from pathlib import Path
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import add_args, save_illusion, save_metadata
# Load both DeepFloyd IF cascade stages with fp16 weights. Stage 2 is loaded
# without a text encoder (text_encoder=None): the prompt embeddings that both
# stages consume are produced by stage_1.encode_prompt in generate_content.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    text_encoder=None,
)

# diffusers model CPU offload: sub-models are moved to the accelerator only
# while they run, trading some speed for lower device memory use.
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()
def _to_display_image(image):
    """Convert one generated image tensor into an ``(H, W, C)`` uint8 array.

    Assumes *image* is a ``(C, H, W)`` tensor with values in ``[-1, 1]``
    (TODO confirm against the visual_anagrams samplers); values outside that
    range are clamped before quantisation.
    """
    image = (image.detach().float() / 2 + 0.5).clamp(0, 1)
    return (image * 255).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()


def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed,
):
    """Generate a two-view illusion with the DeepFloyd IF two-stage cascade.

    Args:
        style: Style prefix prepended to both prompts (e.g. "an oil painting
            of"); may be empty.
        prompt_for_original: Prompt the image should depict when viewed as-is.
        prompt_for_transformed: Prompt the image should depict after
            ``transformation`` is applied.
        transformation: Name of the second view, forwarded to ``get_views``.
        num_inference_steps: Diffusion steps per stage; coerced to ``int``
            because ``gr.Number`` may deliver floats.
        seed: RNG seed passed to ``torch.manual_seed``; coerced to ``int``.

    Returns:
        ``(illusion, original_view, transformed_view)`` as ``(H, W, C)`` uint8
        arrays, one per output image of the Gradio interface.
    """
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # Encode each prompt on its own (with the shared style prefix), then
    # concatenate so that embedding row i belongs to view i.
    prompts = [prompt_for_original, prompt_for_transformed]
    encoded = [stage_1.encode_prompt(f'{style} {p}'.strip()) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*encoded)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    # NOTE(review): `transformation` comes from a dropdown whose choices are the
    # keys of VIEW_MAP_NAMES. If those keys are display labels rather than view
    # identifiers, this should be VIEW_MAP_NAMES[transformation]. Confirm.
    views = get_views(['identity', transformation])

    # A single generator is shared by both stages, so one seed reproduces the
    # whole cascade.
    generator = torch.manual_seed(seed)

    image = sample_stage_1(stage_1,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)
    image = sample_stage_2(stage_2,
                           image,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

    # Assumes the sampler returns a batch holding a single image. TODO confirm.
    illusion = image[0]
    original_view = views[0].view(illusion)
    transformed_view = views[1].view(illusion)
    return (
        _to_display_image(illusion),
        _to_display_image(original_view),
        _to_display_image(transformed_view),
    )
# Dropdown options for the second view (the first view used by
# generate_content is always 'identity').
choices = list(VIEW_MAP_NAMES.keys())

gradio_app = gr.Interface(
    fn=generate_content,
    # Input order must match generate_content's positional parameters.
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a penguin"),
        gr.Textbox(label="Prompt for transformed view", placeholder="a giraffe"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        # precision=0 makes gr.Number round its value and deliver an int; both
        # fields are integer quantities (step count, RNG seed).
        gr.Number(
            label="Number of diffusion steps",
            value=30, step=1, minimum=1, maximum=100, precision=0,
        ),
        gr.Number(
            label="Random seed",
            value=0, step=1, minimum=0, maximum=100000, precision=0,
        ),
    ],
    # One output image per value returned by generate_content.
    outputs=[
        gr.Image(label="Illusion"),
        gr.Image(label="Original"),
        gr.Image(label="Transformed"),
    ],
)
if __name__ == "__main__":
    # Listen on all network interfaces so the app is reachable from outside
    # the host or container, not only from localhost.
    gradio_app.launch(server_name="0.0.0.0")