multimodalart HF staff commited on
Commit
0fe1449
1 Parent(s): 0801036

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, EDMEulerScheduler
3
+ from custom_pipeline import CosStableDiffusionXLInstructPix2PixPipeline
4
+ from huggingface_hub import hf_hub_download
5
+ import numpy as np
6
+ import math
7
+ import spaces
8
+ import torch
9
+ import sys
10
+ import random
11
+
12
+ from diffusers import StableDiffusionXLPipeline
13
+ from gradio_imageslider import ImageSlider
14
+
15
+ pipe = StableDiffusionXLPipeline.from_pretrained(
16
+ "stabilityai/stable-diffusion-xl-base-1.0",
17
+ custom_pipeline="multimodalart/sdxl_perturbed_attention_guidance",
18
+ torch_dtype=torch.float16
19
+ )
20
+
21
+ device="cuda"
22
+ pipe = pipe.to(device)
23
+
24
+ @spaces.GPU
25
+ def run(prompt, negative_prompt="", guidance_scale=7.0, pag_scale=3.0, randomize_seed, seed, progress=gr.Progress(track_tqdm=True)):
26
+ if(randomize_seed):
27
+ seed = random.randint(0, sys.maxsize)
28
+ if(prompt == ""):
29
+ guidance_scale = 0.0
30
+
31
+ generator = torch.Generator(device="cuda").manual_seed(seed)
32
+ image_pag = pipe(prompt, guidance_scale=guidance_scale, pag_scale=3.0, pag_applied_layers=['mid'], generator=generator, num_inference_steps=25).images[0]
33
+
34
+ generator = torch.Generator(device="cuda").manual_seed(seed)
35
+ image_normal = pipe(prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=25).images[0]
36
+ return image_pag, image_normal, seed
37
+
38
+ css = '''
39
+ .gradio-container{
40
+ max-width: 768px !important;
41
+ margin: 0 auto;
42
+ }
43
+ '''
44
+
45
+ with gr.Blocks(css=css) as demo:
46
+ gr.Markdown('''# Perturbed Attention Guidance SDXL
47
+ SDXL 🧨 [diffusers implementation](https://huggingface.co/multimodalart/sdxl_perturbed_attention_guidance) of [Perturbed-Attenton Guidance](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
48
+ ''')
49
+ with gr.Group():
50
+ with gr.Row():
51
+ prompt = gr.Textbox(show_label=False, scale=4, placeholder="Your prompt", info="Leave blank to test unconditional generation")
52
+ button = gr.Button("Generate", min_width=120)
53
+ output = ImageSlider(label="Your result image", interactive=False)
54
+ with gr.Accordion("Advanced Settings", open=False):
55
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.0)
56
+ pag_scale = gr.Number(label="Pag Scale", value=3.0)
57
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
58
+ seed = gr.Slider(minimum=1, maximum=18446744073709551615, step=1, randomize=True)
59
+ gr.Examples(["", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"])
60
+ gr.on(
61
+ triggers=[
62
+ button.click,
63
+ prompt.submit
64
+ ],
65
+ fn=run,
66
+ inputs=[prompt, guidance_scale, pag_scale, seed],
67
+ outputs=[output, seed],
68
+ )
69
+ if __name__ == "__main__":
70
+ demo.launch(share=True)