import sys
from pathlib import Path
from typing import List, Optional

import gradio as gr
from PIL import Image
from diffusers.training_utils import set_seed

from appearance_transfer_model import AppearanceTransferModel
from run import run_appearance_transfer
from utils.latent_utils import load_latents_or_invert_images
from utils.model_utils import get_stable_diffusion_model

# The search path is extended before importing `config`; the statement order
# is kept exactly as in the original script.
sys.path.append(".")
sys.path.append("..")
from config import RunConfig

# Markdown text rendered at the top of the demo page.
DESCRIPTION = '''

Cross-Image Attention for Zero-Shot Appearance Transfer

This is a demo for our paper: ''Cross-Image Attention for Zero-Shot Appearance Transfer''.
Given two images depicting a source structure and a target appearance, our method generates an image merging the structure of one image with the appearance of the other.
We do so in a zero-shot manner, with no optimization or model training required while supporting appearance transfer across images that may differ in size and shape.

'''

# Created once at import time and handed to every AppearanceTransferModel
# built in `main_pipeline`.
pipe = get_stable_diffusion_model()


def main_pipeline(app_image_path: Optional[str],
                  struct_image_path: Optional[str],
                  domain_name: str,
                  seed: int,
                  prompt: Optional[str] = None) -> List[Image.Image]:
    """Run appearance transfer for one pair of images and return the result.

    Args:
        app_image_path: File path of the appearance image.
        struct_image_path: File path of the structure image.
        domain_name: Domain of the depicted objects (e.g. 'animal').
        seed: Random seed passed to ``set_seed``.
        prompt: Optional prompt forwarded to ``RunConfig``. An empty or
            whitespace-only prompt is forwarded as ``None``.

    Returns:
        A one-element list holding the first image produced by
        ``run_appearance_transfer``; the list form is what the ``gr.Gallery``
        output component is given.

    Raises:
        gr.Error: If either image path is missing, i.e. the user pressed
            "Generate" before uploading both images.
    """
    # Without this guard a missing upload reaches ``Path(None)`` and surfaces
    # as a bare TypeError instead of a readable message in the UI.
    if not app_image_path or not struct_image_path:
        raise gr.Error("Please upload both an appearance image and a structure image.")

    # Treat a blank prompt as "not provided".
    if prompt is not None and not prompt.strip():
        prompt = None

    config = RunConfig(
        app_image_path=Path(app_image_path),
        struct_image_path=Path(struct_image_path),
        domain_name=domain_name,
        prompt=prompt,
        seed=seed,
        load_latents=False
    )
    print(config)
    set_seed(config.seed)

    model = AppearanceTransferModel(config=config, pipe=pipe)
    latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=config)
    model.set_latents(latents_app, latents_struct)
    model.set_noise(noise_app, noise_struct)

    print("Running appearance transfer...")
    images = run_appearance_transfer(model=model, cfg=config)
    print("Done.")

    # Only the first returned image is shown, wrapped in a list for the gallery.
    return [images[0]]


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''Duplicate Space''')

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            app_image_path = gr.Image(label="Upload appearance image", type="filepath")
            struct_image_path = gr.Image(label="Upload structure image", type="filepath")
            domain_name = gr.Text(label="Domain name", max_lines=1,
                                  info="Specifies the domain the objects are coming from (e.g., 'animal', 'building', etc).")
            prompt = gr.Text(label="Prompt to use for inversion.", value='',
                             info='If this kept empty, we will use the domain name to define '
                                  'the prompt as "A photo of a ".')
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            run_button = gr.Button('Generate')

        # Right column: generated output.
        with gr.Column():
            result = gr.Gallery(label='Result')

    # Component order must match the positional parameters of `main_pipeline`.
    inputs = [app_image_path, struct_image_path, domain_name, random_seed, prompt]
    outputs = [result]
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
        # Each row: appearance image, structure image, domain, seed, prompt.
        examples = [
            ['inputs/zebra.png', 'inputs/giraffe.png', 'animal', 20, None],
            ['inputs/taj_mahal.jpg', 'inputs/duomo.png', 'building', 42, None],
            ['inputs/red_velvet_cake.jpg', 'inputs/chocolate_cake.jpg', 'cake', 42, 'A photo of cake'],
        ]
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=outputs,
                    fn=main_pipeline,
                    cache_examples=True)

demo.queue(max_size=50).launch(share=False)