callum-canavan's picture
Update app
1a5a17f
raw
history blame
3.27 kB
import argparse
from pathlib import Path
import gradio as gr
print("hello")
from icecream import ic
import torch
from diffusers import DiffusionPipeline
from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import add_args, save_illusion, save_metadata
from visual_anagrams.animate import animate_two_view
# Both DeepFloyd IF pipelines are loaded once, at import time, using the
# fp16 weight variant and float16 tensors.
_HALF_PRECISION_KWARGS = {"variant": "fp16", "torch_dtype": torch.float16}

# Stage 1 is loaded with its default components; generate_content calls
# stage_1.encode_prompt to embed the prompts.
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    **_HALF_PRECISION_KWARGS,
)

# Stage 2 is loaded with text_encoder=None: it is only ever handed the
# prompt embeddings that stage 1 already produced.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None,
    **_HALF_PRECISION_KWARGS,
)

# Turn on diffusers' model CPU offload for each pipeline, stage 1 first.
for _pipeline in (stage_1, stage_2):
    _pipeline.enable_model_cpu_offload()
def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed,
):
    """Generate a two-view illusion and return the names of the result files.

    The two prompts are each prefixed with ``style``, embedded with
    ``stage_1.encode_prompt``, and sampled through ``sample_stage_1`` and then
    ``sample_stage_2`` under two views: ``'identity'`` and the view selected
    by ``transformation``. The result is handed to ``save_illusion`` and
    ``animate_two_view``.

    Args:
        style: Text prepended to both prompts; may be empty.
        prompt_for_original: Prompt for the ``'identity'`` view.
        prompt_for_transformed: Prompt for the transformed view.
        transformation: Key looked up in ``VIEW_MAP_NAMES`` to pick the
            second view.
        num_inference_steps: Number of diffusion steps, used for both stages.
            Converted with ``int()`` before use.
        seed: Seed passed to ``torch.manual_seed``. Converted with ``int()``
            before use.

    Returns:
        A tuple ``(video_name, image_name, views_name)`` of relative file
        names: ``"illusion.mp4"``, ``"sample_<size>.png"`` and
        ``"sample_<size>.views.png"``, where ``<size>`` is the last dimension
        of the stage-2 output.
    """
    # The gr.Number inputs feeding this function are declared without
    # precision=0, so values can arrive as floats (e.g. 75.0). Both values
    # are meant to be whole numbers; normalise them once, up front.
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # .strip() removes the leading space left behind when style is empty.
    prompts = [
        f'{style} {prompt}'.strip()
        for prompt in (prompt_for_original, prompt_for_transformed)
    ]

    # encode_prompt returns a (prompt_embeds, negative_prompt_embeds) pair per
    # prompt; regroup the pairs into two tuples and concatenate each one.
    encoded = [stage_1.encode_prompt(prompt) for prompt in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*encoded)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

    views = get_views(['identity', VIEW_MAP_NAMES[transformation]])

    # The same seeded generator object is shared by both sampling stages.
    generator = torch.manual_seed(seed)

    print("Sample stage 1")
    image = sample_stage_1(
        stage_1,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )

    print("Sample stage 2")
    image = sample_stage_2(
        stage_2,
        image,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )

    # NOTE(review): save_illusion is presumably what writes the
    # sample_<size>.png / sample_<size>.views.png files referenced below,
    # into the directory given here (the current one) -- confirm against
    # visual_anagrams.utils.
    save_illusion(image, views, Path(""))

    output_name = "illusion.mp4"
    size = image.shape[-1]
    image_name = f"sample_{size}.png"

    animate_two_view(
        image_name,
        views[1],
        prompts[0],
        prompts[1],
        save_video_path=output_name,
    )

    return output_name, image_name, f"sample_{size}.views.png"
# The dropdown offers the keys of VIEW_MAP_NAMES; generate_content looks the
# selected key up in VIEW_MAP_NAMES again to pick the view.
choices = list(VIEW_MAP_NAMES.keys())

# Input widgets, in the same order as generate_content's parameters.
_input_components = [
    gr.Textbox(label="Style", placeholder="an oil painting of"),
    gr.Textbox(label="Prompt for original view", placeholder="a dress"),
    gr.Textbox(label="Prompt for transformed view", placeholder="an old man"),
    gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
    gr.Number(label="Number of diffusion steps", value=75, step=1, minimum=1, maximum=300),
    gr.Number(label="Random seed", value=0, step=1, minimum=0, maximum=100000),
]

# Output widgets, in the same order as generate_content's return tuple:
# the video, the sample image, then the views image.
_output_components = [
    gr.Video(label="Illusion"),
    gr.Image(label="Original"),
    gr.Image(label="Transformed"),
]

gradio_app = gr.Interface(
    fn=generate_content,
    title="Multi-View Illusion Diffusion",
    inputs=_input_components,
    outputs=_output_components,
)

# Launch the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    # launch() also accepts server_name="0.0.0.0" if the app should be
    # reachable from other machines.
    gradio_app.launch()