File size: 4,085 Bytes
0fe1449
30bf7bc
0fe1449
 
 
 
 
 
 
 
b1b860e
 
 
 
0fe1449
 
 
 
 
 
 
 
 
 
71d6339
48df6f9
d94d7b7
f9b09af
0fe1449
4a1f895
d94d7b7
 
0fe1449
71d6339
d7795fb
9c48cce
71d6339
 
f9b09af
0fe1449
fd9e499
0fe1449
 
fd9e499
f9b09af
acb6168
0fe1449
 
 
 
 
 
 
 
b1b860e
855f3b2
0fe1449
 
c4f32a7
0fe1449
 
 
477a209
0fe1449
fd9e499
44f95c7
0fe1449
b12e444
0fe1449
4a1f895
0b9fed4
6560c53
0fe1449
 
 
 
 
 
71d6339
0fe1449
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from diffusers import StableDiffusionXLPipeline
import numpy as np
import math
import spaces 
import torch 
import random

from gradio_imageslider import ImageSlider

# Gradio theme: the stock Base theme with a custom font fallback chain
# (two Google fonts first, then generic system fonts).
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont('Libre Franklin'),
        gr.themes.GoogleFont('Public Sans'),
        'system-ui',
        'sans-serif',
    ],
)

# Device the pipeline is moved to at import time.
device = "cuda"

# SDXL base checkpoint in float16, loaded through the community
# perturbed-attention-guidance custom pipeline and moved to `device`.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="multimodalart/sdxl_perturbed_attention_guidance",
    torch_dtype=torch.float16,
).to(device)

@spaces.GPU
def run(prompt, negative_prompt=None, guidance_scale=7.0, pag_scale=3.0, pag_layers=None, randomize_seed=True, seed=42, lora=None, progress=gr.Progress(track_tqdm=True)):
    """Render one prompt twice: with and without Perturbed-Attention Guidance.

    Both renders use the same seed, step count, and CFG settings so the only
    difference between the two images is the PAG arguments.

    Args:
        prompt: Text prompt. Blank (or whitespace-only) is allowed.
        negative_prompt: Optional negative prompt; blank is treated as None.
        guidance_scale: CFG scale. Forced to 0.0 when both prompt and
            negative prompt are empty.
        pag_scale: PAG scale passed to the PAG render only.
        pag_layers: Layers PAG is applied to. Defaults to ``["mid"]``.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed used for both renders when ``randomize_seed`` is false.
        lora: Optional LoRA identifier handed to ``pipe.load_lora_weights``;
            blank is treated as "no LoRA".
        progress: Not read in the body; the ``gr.Progress(track_tqdm=True)``
            default is kept in the signature for Gradio.

    Returns:
        ``((image_pag, image_normal), seed)`` - the image pair and the seed
        that was actually used.
    """
    # Normalise text inputs: empty or whitespace-only values count as absent.
    prompt = (prompt or "").strip()
    negative_prompt = (negative_prompt or "").strip() or None
    lora = (lora or "").strip() or None
    # Fresh list per call instead of a shared mutable default argument.
    if pag_layers is None:
        pag_layers = ["mid"]

    print(f"Initial seed for prompt `{prompt}`", seed)
    if randomize_seed:
        seed = random.randint(0, 9007199254740991)
    # UI widgets may hand back a float; manual_seed needs an integer.
    seed = int(seed)

    # Nothing to condition on: disable classifier-free guidance entirely.
    if not prompt and not negative_prompt:
        guidance_scale = 0.0

    # Undo any LoRA left over from a previous request. Unfuse BEFORE
    # unloading: fusing merges the adapter into the base weights, and once
    # the adapter has been unloaded there is nothing left to unfuse with, so
    # the old LoRA would otherwise stay baked into the model.
    pipe.unfuse_lora()
    pipe.unload_lora_weights()
    if lora:
        pipe.load_lora_weights(lora, adapter_name="custom")
        pipe.fuse_lora(lora_scale=0.9)

    print(f"Seed before sending to generator for prompt: `{prompt}`", seed)

    def seeded_generator():
        # A new generator per render so both start from identical noise.
        return torch.Generator(device=device).manual_seed(seed)

    # Arguments common to both renders.
    shared_kwargs = dict(
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=25,
    )

    image_pag = pipe(
        prompt,
        pag_scale=pag_scale,
        pag_applied_layers=pag_layers,
        generator=seeded_generator(),
        **shared_kwargs,
    ).images[0]

    image_normal = pipe(
        prompt,
        generator=seeded_generator(),
        **shared_kwargs,
    ).images[0]

    print(f"Seed at the end of generation for prompt: `{prompt}`", seed)
    return (image_pag, image_normal), seed

# Extra stylesheet handed to gr.Blocks: caps the app container at 768px wide
# and centers it horizontally.
css = """
.gradio-container{
max-width: 768px !important;
margin: 0 auto;
}
"""

# UI definition: a prompt row, a before/after image slider, and an accordion
# of advanced settings, all wired to `run`.
with gr.Blocks(css=css, theme=theme) as demo:
    gr.Markdown('''# Perturbed-Attention Guidance SDXL
    SDXL 🧨 [diffusers implementation](https://huggingface.co/multimodalart/sdxl_perturbed_attention_guidance) of [Perturbed-Attention Guidance](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
    ''')
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(show_label=False, scale=4, placeholder="Your prompt", info="Leave blank to test unconditional generation")
            button = gr.Button("Generate", min_width=120)
        # Shows the (PAG, no-PAG) image pair returned by `run`.
        output = ImageSlider(label="Left: PAG, Right: No PAG", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            guidance_scale = gr.Number(label="CFG Guidance Scale", info="The guidance scale for CFG, ignored if no prompt is entered (unconditional generation)", value=7.0)
            negative_prompt = gr.Textbox(label="Negative prompt", info="Is only applied for the CFG part, leave blank for unconditional generation")
            pag_scale = gr.Number(label="Pag Scale", value=3.0)
            # Multiselect dropdown: the default is given as a list, matching
            # the list value the component produces.
            pag_layers = gr.Dropdown(label="Model layers to apply Pag to", info="mid is the one used on the paper, up and down blocks seem unstable", choices=["up", "mid", "down"], multiselect=True, value=["mid"])
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # Range matches random.randint(0, 9007199254740991) in `run`,
            # whose result is written back to this slider.
            seed = gr.Slider(label="Seed", minimum=0, maximum=9007199254740991, step=1, randomize=True)
            lora = gr.Textbox(label="Custom LoRA path", info="Load a custom LoRA from Hugging Face to use PAG with")
    # Examples only supply the prompt; every other `run` argument uses its default.
    gr.Examples(fn=run, examples=[" ", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"], inputs=prompt, outputs=[output, seed], cache_examples="lazy")
    # Clicking the button and submitting the prompt box both trigger a generation.
    gr.on(
        triggers=[
            button.click,
            prompt.submit
        ],
        fn=run,
        inputs=[prompt, negative_prompt, guidance_scale, pag_scale, pag_layers, randomize_seed, seed, lora],
        outputs=[output, seed],
    )
if __name__ == "__main__":
    demo.launch(share=True)