import gradio as gr
import spaces
import torch
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image

from src.config import RunConfig
from src.editor import ImageEditorDemo
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
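
# Select the compute device and load the SDXL-Turbo pipelines. Both the inversion and
# inference pipelines use this repo's custom Euler Ancestral scheduler.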
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler_class = MyEulerAncestralDiscreteScheduler
pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
# if torch.cuda.is_available():
# torch.cuda.max_memory_allocated(device=device)
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
# pipe.enable_xformers_memory_efficient_attention()
# pipe = pipe.to(device)
# else:
# pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
# pipe = pipe.to(device)
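
# Invert the input image with RNRI and re-generate it under the edit prompt.
# Runs on the Space's GPU via the @spaces.GPU decorator.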
@spaces.GPU
def infer(input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps=4,
          num_inversion_steps=4,
          inversion_max_step=0.6):
    config = RunConfig(num_inference_steps=num_inference_steps,
                       num_inversion_steps=num_inversion_steps,
                       edit_guidance_scale=edit_guidance_scale,
                       inversion_max_step=inversion_max_step)
    editor = ImageEditorDemo(pipe_inversion, pipe_inference, input_image, description_prompt, config, device)
    image = editor.edit(target_prompt)
    return image
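
# Example prompts (currently unused: referenced only by the commented-out gr.Examples block below).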
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# css = """
# #col-container-1 {
# margin: 0 auto;
# max-width: 520px;
# }
# #col-container-2 {
# margin: 0 auto;
# max-width: 520px;
# }
# """

if torch.cuda.is_available():
    power_device = "GPU"
else:
    power_device = "CPU"
# with gr.Blocks(css=css) as demo:
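
# Gradio UI: input image, description prompt, and edit prompt on the left; edited result on the right.
# Note: power_device is computed above but is not currently shown in the page text.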
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(f"""# Real Time Editing with RNRI Inversion 🍎⚡️
    This is a demo for our [paper](https://arxiv.org/abs/2312.12540) **RNRI: Regularized Newton Raphson Inversion for Text-to-Image Diffusion Models**.
    Image editing with RNRI inversion achieves a significant speed-up and improved quality compared to previous state-of-the-art methods.
    Take a look at our [project page](https://barakmam.github.io/rnri.github.io/).
    """)
    with gr.Row():
        with gr.Column(elem_id="col-container-1"):
            with gr.Row():
                input_image = gr.Image(label="Input image", sources=['upload', 'webcam'], type="pil")
            with gr.Row():
                description_prompt = gr.Text(
                    label="Image description",
                    info="Enter your image description",
                    show_label=False,
                    max_lines=1,
                    placeholder="a cake on a table",
                    container=False,
                )
            with gr.Row():
                target_prompt = gr.Text(
                    label="Edit prompt",
                    info="Enter your edit prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="an oreo cake on a table",
                    container=False,
                )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    edit_guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.2,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of RNRI iterations",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=4,
                    )
                    inversion_max_step = gr.Slider(
                        label="Inversion strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.6,
                    )
            with gr.Row():
                run_button = gr.Button("Edit", scale=1)
        with gr.Column(elem_id="col-container-2"):
            result = gr.Image(label="Result")
    # gr.Examples(
    #     examples=examples,
    #     inputs=[prompt]
    # )
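
    # Wire the Edit button to infer(). The iteration slider is passed twice so the same value is
    # used for both num_inference_steps and num_inversion_steps; inversion_max_step comes from the
    # "Inversion strength" slider.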
    run_button.click(
        fn=infer,
        inputs=[input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps,
                num_inference_steps, inversion_max_step],
        outputs=[result]
    )
demo.queue().launch()
# im = infer(input_image, description_prompt, target_prompt, edit_guidance_scale, num_inference_steps=4, num_inversion_steps=4,
# inversion_max_step=0.6)