Spaces: Running on A100
add randomize seed and sfast
- app.py +44 -27
- requirements.txt +3 -0
app.py
CHANGED
@@ -12,10 +12,13 @@ from PIL import Image
 import numpy as np
 import gradio as gr
 import psutil
+from sfast.compilers.stable_diffusion_pipeline_compiler import (
+    compile,
+    CompilationConfig,
+)
 
 
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
-TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 # check if MPS is available OSX only M1/M2/M3 chips
 mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -27,7 +30,6 @@ torch_device = device
 torch_dtype = torch.float16
 
 print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
-print(f"TORCH_COMPILE: {TORCH_COMPILE}")
 print(f"device: {device}")
 
 if mps_available:
@@ -43,24 +45,21 @@ else:
     pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
 
 pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
-pipe.to(device=torch_device, dtype=torch_dtype).to(device)
-pipe.unet.to(memory_format=torch.channels_last)
-
-# check if computer has less than 64GB of RAM using sys or os
-if psutil.virtual_memory().total < 64 * 1024**3:
-    pipe.enable_attention_slicing()
-
-if TORCH_COMPILE:
-    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
-
-    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
-
-# Load LCM LoRA
 pipe.load_lora_weights(
     "latent-consistency/lcm-lora-sdxl",
     use_auth_token=HF_TOKEN,
 )
+if device.type != "mps":
+    pipe.unet.to(memory_format=torch.channels_last)
+pipe.to(device=torch_device, dtype=torch_dtype).to(device)
+
+# Load LCM LoRA
+
+config = CompilationConfig.Default()
+config.enable_xformers = True
+config.enable_triton = True
+config.enable_cuda_graph = True
+pipe = compile(pipe, config=config)
 
 compel_proc = Compel(
     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
@@ -71,8 +70,15 @@ compel_proc = Compel(
 
 
 def predict(
-    prompt,
+    prompt,
+    guidance,
+    steps,
+    seed=1231231,
+    randomize_bt=False,
+    progress=gr.Progress(track_tqdm=True),
 ):
+    if randomize_bt:
+        seed = np.random.randint(0, 2**32 - 1)
     generator = torch.manual_seed(seed)
     prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
 
@@ -94,7 +100,7 @@ def predict(
     )
     if nsfw_content_detected:
         raise gr.Error("NSFW content detected.")
-    return results.images[0]
+    return results.images[0], seed
 
 
 css = """
@@ -122,18 +128,28 @@ with gr.Blocks(css=css) as demo:
                 placeholder="Insert your prompt here:", scale=5, container=False
             )
            generate_bt = gr.Button("Generate", scale=1)
-
+
        image = gr.Image(type="filepath")
        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
-
-
-
+            with gr.Row():
+                seed = gr.Slider(
+                    randomize=True,
+                    minimum=0,
+                    maximum=12013012031030,
+                    label="Seed",
+                    step=1,
+                    scale=5,
+                )
+                with gr.Group():
+                    randomize_bt = gr.Checkbox(label="Randomize", value=False)
+                    random_seed = gr.Textbox(show_label=False)
        with gr.Accordion("Run with diffusers"):
-            gr.Markdown(
+            gr.Markdown(
+                """## Running LCM-LoRAs it with `diffusers`
            ```bash
            pip install diffusers==0.23.0
            ```
@@ -151,10 +167,11 @@ with gr.Blocks(css=css) as demo:
            )
            results.images[0]
            ```
-
-
-
-
+            """
+            )
+
+        inputs = [prompt, guidance, steps, seed, randomize_bt]
+        generate_bt.click(fn=predict, inputs=inputs, outputs=[image, random_seed])
 
demo.queue()
demo.launch()
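For reference, a minimal sketch of the stable-fast path this commit switches to, exercised on its own, assuming a CUDA GPU with the pinned `stable_fast` wheel plus `xformers` and `triton` installed; the model id, dtype, and warmup settings below are illustrative rather than copied from the Space:

```python
import torch
from diffusers import DiffusionPipeline
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile,
    CompilationConfig,
)

# Illustrative checkpoint; the Space loads its own SDXL model and LCM LoRA.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.to("cuda")

# Same configuration the commit uses before compiling the pipeline.
config = CompilationConfig.Default()
config.enable_xformers = True   # needs xformers installed
config.enable_triton = True     # needs triton installed
config.enable_cuda_graph = True
pipe = compile(pipe, config=config)

# First calls are slow while stable-fast traces and captures CUDA graphs;
# a short warmup run (as the removed torch.compile branch did) keeps that
# cost out of the first user request. Step count here is arbitrary.
pipe(prompt="warmup", num_inference_steps=4, guidance_scale=1.0)
```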
requirements.txt
CHANGED
@@ -11,3 +11,6 @@ accelerate==0.24.0
 compel==2.0.2
 controlnet-aux==0.0.7
 peft==0.6.0
+stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v0.0.15.post1/stable_fast-0.0.15.post1+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl
+xformers
+triton
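Note that the pinned `stable_fast` wheel encodes its build targets in its name (`torch211cu121-cp310`): torch 2.1.1, CUDA 12.1, CPython 3.10. Presumably the Space's A100 runtime matches that stack; a different torch or Python version would need another wheel (or a source build) from the same release page. `xformers` and `triton` are added unpinned.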