Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import spaces
|
|
7 |
import torch
|
8 |
from diffusers import DiffusionPipeline, LCMScheduler
|
9 |
|
|
|
10 |
with open("sdxl_lora.json", "r") as file:
|
11 |
data = json.load(file)
|
12 |
sdxl_loras_raw = [
|
@@ -31,24 +32,19 @@ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
|
31 |
|
32 |
pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
|
33 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
34 |
-
pipe.load_lora_weights("alvdansen/flash-lora-araminta-k-styles", adapter_name="lora")
|
35 |
pipe.to(device=DEVICE, dtype=torch.float16)
|
36 |
|
37 |
-
|
38 |
MAX_SEED = np.iinfo(np.int32).max
|
39 |
MAX_IMAGE_SIZE = 1024
|
40 |
|
41 |
-
|
42 |
-
def update_selection(
|
43 |
-
selected_state: gr.SelectData,
|
44 |
-
gr_sdxl_loras,
|
45 |
-
):
|
46 |
-
|
47 |
lora_id = gr_sdxl_loras[selected_state.index]["repo"]
|
48 |
trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
|
49 |
-
|
50 |
return lora_id, trigger_word
|
51 |
|
|
|
|
|
|
|
52 |
|
53 |
@spaces.GPU
|
54 |
def infer(
|
@@ -63,19 +59,8 @@ def infer(
|
|
63 |
user_lora_weight,
|
64 |
progress=gr.Progress(track_tqdm=True),
|
65 |
):
|
66 |
-
|
67 |
-
|
68 |
-
new_adapter_id = user_lora_selector.replace("/", "_")
|
69 |
-
loaded_adapters = pipe.get_list_adapters()
|
70 |
-
|
71 |
-
if new_adapter_id not in loaded_adapters["unet"]:
|
72 |
-
gr.Info("Swapping LoRA")
|
73 |
-
pipe.unload_lora_weights()
|
74 |
-
pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
|
75 |
-
pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
|
76 |
-
|
77 |
-
pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
|
78 |
-
gr.Info("LoRA setup done")
|
79 |
|
80 |
if randomize_seed:
|
81 |
seed = random.randint(0, MAX_SEED)
|
@@ -95,19 +80,15 @@ def infer(
|
|
95 |
|
96 |
return image
|
97 |
|
98 |
-
|
99 |
css = """
|
100 |
-
|
101 |
h1 {
|
102 |
text-align: center;
|
103 |
display:block;
|
104 |
}
|
105 |
-
|
106 |
p {
|
107 |
text-align: justify;
|
108 |
display:block;
|
109 |
}
|
110 |
-
|
111 |
"""
|
112 |
|
113 |
if torch.cuda.is_available():
|
@@ -116,11 +97,9 @@ else:
|
|
116 |
power_device = "CPU"
|
117 |
|
118 |
with gr.Blocks(css=css) as demo:
|
119 |
-
|
120 |
gr.Markdown(
|
121 |
f"""
|
122 |
# β‘ FlashDiffusion: FlashLoRA β‘
|
123 |
-
|
124 |
This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
|
125 |
|
126 |
The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
|
@@ -137,11 +116,8 @@ with gr.Blocks(css=css) as demo:
|
|
137 |
gr_lora_id = gr.State(value="")
|
138 |
|
139 |
with gr.Row():
|
140 |
-
|
141 |
with gr.Blocks():
|
142 |
-
|
143 |
with gr.Column():
|
144 |
-
|
145 |
user_lora_selector = gr.Textbox(
|
146 |
label="Current Selected LoRA",
|
147 |
max_lines=1,
|
@@ -166,9 +142,7 @@ with gr.Blocks(css=css) as demo:
|
|
166 |
)
|
167 |
|
168 |
with gr.Column():
|
169 |
-
|
170 |
with gr.Row():
|
171 |
-
|
172 |
prompt = gr.Text(
|
173 |
label="Prompt",
|
174 |
show_label=False,
|
@@ -183,7 +157,6 @@ with gr.Blocks(css=css) as demo:
|
|
183 |
result = gr.Image(label="Result", show_label=False)
|
184 |
|
185 |
with gr.Accordion("Advanced Settings", open=False):
|
186 |
-
|
187 |
pre_prompt = gr.Text(
|
188 |
label="Pre-Prompt",
|
189 |
show_label=True,
|
@@ -204,7 +177,6 @@ with gr.Blocks(css=css) as demo:
|
|
204 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
205 |
|
206 |
with gr.Row():
|
207 |
-
|
208 |
num_inference_steps = gr.Slider(
|
209 |
label="Number of inference steps",
|
210 |
minimum=4,
|
@@ -214,7 +186,6 @@ with gr.Blocks(css=css) as demo:
|
|
214 |
)
|
215 |
|
216 |
with gr.Row():
|
217 |
-
|
218 |
guidance_scale = gr.Slider(
|
219 |
label="Guidance Scale",
|
220 |
minimum=1,
|
@@ -242,7 +213,6 @@ with gr.Blocks(css=css) as demo:
|
|
242 |
run_button.click,
|
243 |
seed.change,
|
244 |
randomize_seed.change,
|
245 |
-
# prompt.change,
|
246 |
prompt.submit,
|
247 |
negative_prompt.change,
|
248 |
negative_prompt.submit,
|
@@ -261,7 +231,6 @@ with gr.Blocks(css=css) as demo:
|
|
261 |
user_lora_weight,
|
262 |
],
|
263 |
outputs=[result],
|
264 |
-
# show_progress="full",
|
265 |
)
|
266 |
|
267 |
gallery.select(
|
@@ -279,5 +248,4 @@ with gr.Blocks(css=css) as demo:
|
|
279 |
"This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
|
280 |
)
|
281 |
|
282 |
-
|
283 |
-
demo.queue().launch()
|
|
|
7 |
import torch
|
8 |
from diffusers import DiffusionPipeline, LCMScheduler
|
9 |
|
10 |
+
# Load the JSON data
|
11 |
with open("sdxl_lora.json", "r") as file:
|
12 |
data = json.load(file)
|
13 |
sdxl_loras_raw = [
|
|
|
32 |
|
33 |
pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
|
34 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
|
|
35 |
pipe.to(device=DEVICE, dtype=torch.float16)
|
36 |
|
|
|
37 |
MAX_SEED = np.iinfo(np.int32).max
|
38 |
MAX_IMAGE_SIZE = 1024
|
39 |
|
40 |
+
def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
|
|
|
|
|
|
|
|
|
|
|
41 |
lora_id = gr_sdxl_loras[selected_state.index]["repo"]
|
42 |
trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
|
|
|
43 |
return lora_id, trigger_word
|
44 |
|
45 |
+
def load_lora_for_style(style_repo):
|
46 |
+
pipe.unload_lora_weights() # Unload any previously loaded weights
|
47 |
+
pipe.load_lora_weights(style_repo, adapter_name="lora")
|
48 |
|
49 |
@spaces.GPU
|
50 |
def infer(
|
|
|
59 |
user_lora_weight,
|
60 |
progress=gr.Progress(track_tqdm=True),
|
61 |
):
|
62 |
+
# Load the appropriate LoRA weights
|
63 |
+
load_lora_for_style(user_lora_selector)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
if randomize_seed:
|
66 |
seed = random.randint(0, MAX_SEED)
|
|
|
80 |
|
81 |
return image
|
82 |
|
|
|
83 |
css = """
|
|
|
84 |
h1 {
|
85 |
text-align: center;
|
86 |
display:block;
|
87 |
}
|
|
|
88 |
p {
|
89 |
text-align: justify;
|
90 |
display:block;
|
91 |
}
|
|
|
92 |
"""
|
93 |
|
94 |
if torch.cuda.is_available():
|
|
|
97 |
power_device = "CPU"
|
98 |
|
99 |
with gr.Blocks(css=css) as demo:
|
|
|
100 |
gr.Markdown(
|
101 |
f"""
|
102 |
# β‘ FlashDiffusion: FlashLoRA β‘
|
|
|
103 |
This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
|
104 |
|
105 |
The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
|
|
|
116 |
gr_lora_id = gr.State(value="")
|
117 |
|
118 |
with gr.Row():
|
|
|
119 |
with gr.Blocks():
|
|
|
120 |
with gr.Column():
|
|
|
121 |
user_lora_selector = gr.Textbox(
|
122 |
label="Current Selected LoRA",
|
123 |
max_lines=1,
|
|
|
142 |
)
|
143 |
|
144 |
with gr.Column():
|
|
|
145 |
with gr.Row():
|
|
|
146 |
prompt = gr.Text(
|
147 |
label="Prompt",
|
148 |
show_label=False,
|
|
|
157 |
result = gr.Image(label="Result", show_label=False)
|
158 |
|
159 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
160 |
pre_prompt = gr.Text(
|
161 |
label="Pre-Prompt",
|
162 |
show_label=True,
|
|
|
177 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
178 |
|
179 |
with gr.Row():
|
|
|
180 |
num_inference_steps = gr.Slider(
|
181 |
label="Number of inference steps",
|
182 |
minimum=4,
|
|
|
186 |
)
|
187 |
|
188 |
with gr.Row():
|
|
|
189 |
guidance_scale = gr.Slider(
|
190 |
label="Guidance Scale",
|
191 |
minimum=1,
|
|
|
213 |
run_button.click,
|
214 |
seed.change,
|
215 |
randomize_seed.change,
|
|
|
216 |
prompt.submit,
|
217 |
negative_prompt.change,
|
218 |
negative_prompt.submit,
|
|
|
231 |
user_lora_weight,
|
232 |
],
|
233 |
outputs=[result],
|
|
|
234 |
)
|
235 |
|
236 |
gallery.select(
|
|
|
248 |
"This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
|
249 |
)
|
250 |
|
251 |
+
demo.queue().launch()
|
|