multimodalart's picture
Update app.py
30bf7bc verified
raw
history blame
2.76 kB
import gradio as gr
from diffusers import StableDiffusionXLPipeline
import numpy as np
import math
import spaces
import torch
import sys
import random
from gradio_imageslider import ImageSlider
# Base Gradio theme: Google-hosted Libre Franklin / Public Sans fonts,
# falling back to the platform UI font and then any sans-serif font.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Libre Franklin"),
        gr.themes.GoogleFont("Public Sans"),
        "system-ui",
        "sans-serif",
    ],
)
# Target device for the model (and, in `run`, for the seeded generators).
device = "cuda"

# Load SDXL base weights in float16 through the custom pipeline named by
# `custom_pipeline`, then move the whole pipeline onto `device`.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="multimodalart/sdxl_perturbed_attention_guidance",
    torch_dtype=torch.float16,
).to(device)
@spaces.GPU
def run(prompt, negative_prompt="", guidance_scale=7.0, pag_scale=3.0, randomize_seed=True, seed=42, progress=gr.Progress(track_tqdm=True)):
    """Generate one prompt with and without Perturbed-Attention Guidance (PAG).

    Both images come from generators seeded with the same value, so they can
    be compared side by side.

    Args:
        prompt: Text prompt. An empty prompt forces ``guidance_scale`` to 0
            (used to test unconditional generation).
        negative_prompt: Optional negative prompt. An empty string is passed
            to the pipeline as ``None``, i.e. "no negative prompt".
        guidance_scale: Classifier-free guidance scale for both images.
        pag_scale: PAG scale used for the PAG image only.
        randomize_seed: If true, ``seed`` is ignored and a fresh seed is drawn
            from ``[0, sys.maxsize]``.
        seed: Seed used for both generators when ``randomize_seed`` is false.
        progress: Gradio progress tracker (tqdm progress is tracked).

    Returns:
        ``(image_pag, image_normal, seed)``, where ``seed`` is the integer
        seed that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, sys.maxsize)
    # Gradio sliders may deliver the value as a float; the generator's
    # manual_seed needs an int.
    seed = int(seed)
    if not prompt:
        # Unconditional generation: switch classifier-free guidance off.
        guidance_scale = 0.0
    # "" from the UI means "no negative prompt". Pass None so the pipeline
    # keeps its default handling of a missing negative prompt.
    negative = negative_prompt or None

    def _sample(**extra):
        """Run one 25-step generation from a freshly seeded generator."""
        generator = torch.Generator(device=device).manual_seed(seed)
        return pipe(
            prompt,
            negative_prompt=negative,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=25,
            **extra,
        ).images[0]

    # Previously pag_scale was hard-coded to 3.0 here, ignoring the argument.
    image_pag = _sample(pag_scale=pag_scale, pag_applied_layers=['mid'])
    image_normal = _sample()
    return image_pag, image_normal, seed
# Custom CSS handed to gr.Blocks(css=...): caps the app container at 768px
# wide and centres it horizontally.
css = '''
.gradio-container{
max-width: 768px !important;
margin: 0 auto;
}
'''
with gr.Blocks(css=css, theme=theme) as demo:
    gr.Markdown('''# Perturbed Attention Guidance SDXL
SDXL 🧨 [diffusers implementation](https://huggingface.co/multimodalart/sdxl_perturbed_attention_guidance) of [Perturbed-Attention Guidance](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
''')
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(show_label=False, scale=4, placeholder="Your prompt", info="Leave blank to test unconditional generation")
            button = gr.Button("Generate", min_width=120)
        # Shows (PAG image, non-PAG image) as a before/after slider.
        output = ImageSlider(label="Your result image", interactive=False)
    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt", placeholder="Optional")
        guidance_scale = gr.Number(label="Guidance Scale", value=7.0)
        pag_scale = gr.Number(label="Pag Scale", value=3.0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Same range that `run` draws random seeds from.
        seed = gr.Slider(label="Seed", minimum=0, maximum=sys.maxsize, step=1, randomize=True)
    # gr.Examples requires `inputs`; the examples fill the prompt box.
    gr.Examples(
        examples=["", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"],
        inputs=[prompt],
    )

    def _generate(prompt_text, negative_text, cfg_scale, pag_value, randomize, seed_value, progress=gr.Progress(track_tqdm=True)):
        """Call `run` and reshape its result for this UI's outputs.

        `run` returns (image_pag, image_normal, seed). The ImageSlider needs
        both images as one tuple, and the seed that was used goes back into
        the seed slider.
        """
        image_pag, image_normal, used_seed = run(
            prompt_text,
            negative_text,
            cfg_scale,
            pag_value,
            randomize,
            seed_value,
            progress=progress,
        )
        return (image_pag, image_normal), used_seed

    gr.on(
        triggers=[button.click, prompt.submit],
        fn=_generate,
        # Order must match _generate's / run's positional parameters.
        inputs=[prompt, negative_prompt, guidance_scale, pag_scale, randomize_seed, seed],
        outputs=[output, seed],
    )
def _launch():
    """Start the Gradio server for `demo`, requesting a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _launch()