alvdansen commited on
Commit
c2e1982
β€’
1 Parent(s): 3abc48f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -32
app.py CHANGED
@@ -12,13 +12,29 @@ import os
12
  # Load the JSON data
13
  with open("sdxl_lora.json", "r") as file:
14
  data = json.load(file)
15
- sdxl_loras_raw = sorted(data, key=lambda x: x["likes"], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
19
 
20
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
21
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
 
22
  pipe.to(device=DEVICE, dtype=torch.float16)
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
@@ -29,10 +45,6 @@ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
29
  trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
30
  return lora_id, trigger_word
31
 
32
- def load_lora_for_style(style_repo):
33
- pipe.unload_lora_weights()
34
- pipe.load_lora_weights(style_repo, adapter_name="lora")
35
-
36
  def get_image(image_data):
37
  if isinstance(image_data, str):
38
  return image_data
@@ -44,22 +56,20 @@ def get_image(image_data):
44
  print(f"Unexpected image_data format: {type(image_data)}")
45
  return None
46
 
47
- # Try loading from local path first
48
  if local_path and os.path.exists(local_path):
49
  try:
50
- Image.open(local_path).verify() # Verify that it's a valid image
51
  return local_path
52
  except Exception as e:
53
  print(f"Error loading local image {local_path}: {e}")
54
 
55
- # If local path fails or doesn't exist, try URL
56
  if hf_url:
57
  try:
58
  response = requests.get(hf_url)
59
  if response.status_code == 200:
60
  img = Image.open(requests.get(hf_url, stream=True).raw)
61
- img.verify() # Verify that it's a valid image
62
- img.save(local_path) # Save for future use
63
  return local_path
64
  else:
65
  print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
@@ -82,7 +92,19 @@ def infer(
82
  user_lora_weight,
83
  progress=gr.Progress(track_tqdm=True),
84
  ):
85
- load_lora_for_style(user_lora_selector)
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  if randomize_seed:
88
  seed = random.randint(0, MAX_SEED)
@@ -140,18 +162,16 @@ h1, h2 {
140
  with gr.Blocks(css=css) as demo:
141
  gr.Markdown(
142
  """
143
- # ⚑ FlashDiffusion: Araminta K's FlashLoRA Showcase ⚑
144
-
145
- This interactive demo showcases [Araminta K's models](https://huggingface.co/alvdansen) using [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) technology.
146
 
147
- ## Acknowledgments
148
- - Original Flash Diffusion technology by the Jasper AI team
149
- - Based on the paper: [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin
150
- - Models showcased here are created by Araminta K at Alvdansen Labs
151
-
152
- Explore the power of FlashLoRA with Araminta K's unique artistic styles!
153
  """
154
  )
 
 
 
155
 
156
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
157
  gr_lora_id = gr.State(value="")
@@ -169,9 +189,18 @@ with gr.Blocks(css=css) as demo:
169
 
170
  user_lora_selector = gr.Textbox(
171
  label="Current Selected LoRA",
 
172
  interactive=False,
173
  )
174
 
 
 
 
 
 
 
 
 
175
  with gr.Column(scale=3):
176
  prompt = gr.Textbox(
177
  label="Prompt",
@@ -218,6 +247,12 @@ with gr.Blocks(css=css) as demo:
218
  value=1,
219
  )
220
 
 
 
 
 
 
 
221
  negative_prompt = gr.Textbox(
222
  label="Negative Prompt",
223
  placeholder="Enter a negative Prompt",
@@ -225,7 +260,15 @@ with gr.Blocks(css=css) as demo:
225
  )
226
 
227
  gr.on(
228
- [run_button.click, prompt.submit],
 
 
 
 
 
 
 
 
229
  fn=infer,
230
  inputs=[
231
  pre_prompt,
@@ -236,7 +279,7 @@ with gr.Blocks(css=css) as demo:
236
  negative_prompt,
237
  guidance_scale,
238
  user_lora_selector,
239
- gr.Slider(label="Selected LoRA Weight", minimum=0.5, maximum=3, step=0.1, value=1),
240
  ],
241
  outputs=[result],
242
  )
@@ -249,17 +292,9 @@ with gr.Blocks(css=css) as demo:
249
  outputs=[user_lora_selector, pre_prompt],
250
  )
251
 
 
252
  gr.Markdown(
253
- """
254
- ## Unleash Your Creativity!
255
-
256
- This showcase brings together the speed of Flash Diffusion and the artistic flair of Araminta K's models.
257
- Craft your prompts, adjust the settings, and watch as AI brings your ideas to life in stunning detail.
258
-
259
- Remember to use this tool ethically and respect copyright and individual privacy.
260
-
261
- Enjoy exploring these unique artistic styles!
262
- """
263
  )
264
 
265
  demo.queue().launch()
 
12
  # Load the JSON data
13
  with open("sdxl_lora.json", "r") as file:
14
  data = json.load(file)
15
+ sdxl_loras_raw = [
16
+ {
17
+ "image": item["image"],
18
+ "title": item["title"],
19
+ "repo": item["repo"],
20
+ "trigger_word": item["trigger_word"],
21
+ "weights": item["weights"],
22
+ "is_pivotal": item.get("is_pivotal", False),
23
+ "text_embedding_weights": item.get("text_embedding_weights", None),
24
+ "likes": item.get("likes", 0),
25
+ }
26
+ for item in data
27
+ ]
28
+
29
+ # Sort the loras by likes
30
+ sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
31
 
32
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
34
 
35
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
36
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
37
+ pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
38
  pipe.to(device=DEVICE, dtype=torch.float16)
39
 
40
  MAX_SEED = np.iinfo(np.int32).max
 
45
  trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
46
  return lora_id, trigger_word
47
 
 
 
 
 
48
  def get_image(image_data):
49
  if isinstance(image_data, str):
50
  return image_data
 
56
  print(f"Unexpected image_data format: {type(image_data)}")
57
  return None
58
 
 
59
  if local_path and os.path.exists(local_path):
60
  try:
61
+ Image.open(local_path).verify()
62
  return local_path
63
  except Exception as e:
64
  print(f"Error loading local image {local_path}: {e}")
65
 
 
66
  if hf_url:
67
  try:
68
  response = requests.get(hf_url)
69
  if response.status_code == 200:
70
  img = Image.open(requests.get(hf_url, stream=True).raw)
71
+ img.verify()
72
+ img.save(local_path)
73
  return local_path
74
  else:
75
  print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
 
92
  user_lora_weight,
93
  progress=gr.Progress(track_tqdm=True),
94
  ):
95
+ flash_sdxl_id = "jasperai/flash-sdxl"
96
+
97
+ new_adapter_id = user_lora_selector.replace("/", "_")
98
+ loaded_adapters = pipe.get_list_adapters()
99
+
100
+ if new_adapter_id not in loaded_adapters["unet"]:
101
+ gr.Info("Swapping LoRA")
102
+ pipe.unload_lora_weights()
103
+ pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
104
+ pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
105
+
106
+ pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
107
+ gr.Info("LoRA setup done")
108
 
109
  if randomize_seed:
110
  seed = random.randint(0, MAX_SEED)
 
162
  with gr.Blocks(css=css) as demo:
163
  gr.Markdown(
164
  """
165
+ # ⚑ FlashDiffusion: FlashLoRA ⚑
166
+ This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
 
167
 
168
+ The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
169
+ The LoRAs can be added **without** any retraining for similar results in most cases. Feel free to tweak the parameters and use your own LoRAs by giving a look at the [Github Repo](https://github.com/gojasper/flash-diffusion)
 
 
 
 
170
  """
171
  )
172
+ gr.Markdown(
173
+ "If you enjoy the space, please also promote *open-source* by giving a ⭐ to our repo [![GitHub Stars](https://img.shields.io/github/stars/gojasper/flash-diffusion?style=social)](https://github.com/gojasper/flash-diffusion)"
174
+ )
175
 
176
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
177
  gr_lora_id = gr.State(value="")
 
189
 
190
  user_lora_selector = gr.Textbox(
191
  label="Current Selected LoRA",
192
+ max_lines=1,
193
  interactive=False,
194
  )
195
 
196
+ user_lora_weight = gr.Slider(
197
+ label="Selected LoRA Weight",
198
+ minimum=0.5,
199
+ maximum=3,
200
+ step=0.1,
201
+ value=1,
202
+ )
203
+
204
  with gr.Column(scale=3):
205
  prompt = gr.Textbox(
206
  label="Prompt",
 
247
  value=1,
248
  )
249
 
250
+ hint_negative = gr.Markdown(
251
+ """πŸ’‘ _Hint : Negative Prompt will only work with Guidance > 1 but the model was
252
+ trained to be used with guidance = 1 (ie. without guidance).
253
+ Can degrade the results, use cautiously._"""
254
+ )
255
+
256
  negative_prompt = gr.Textbox(
257
  label="Negative Prompt",
258
  placeholder="Enter a negative Prompt",
 
260
  )
261
 
262
  gr.on(
263
+ [
264
+ run_button.click,
265
+ seed.change,
266
+ randomize_seed.change,
267
+ prompt.submit,
268
+ negative_prompt.change,
269
+ negative_prompt.submit,
270
+ guidance_scale.change,
271
+ ],
272
  fn=infer,
273
  inputs=[
274
  pre_prompt,
 
279
  negative_prompt,
280
  guidance_scale,
281
  user_lora_selector,
282
+ user_lora_weight,
283
  ],
284
  outputs=[result],
285
  )
 
292
  outputs=[user_lora_selector, pre_prompt],
293
  )
294
 
295
+ gr.Markdown("**Disclaimer:**")
296
  gr.Markdown(
297
+ "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
 
 
 
 
 
 
 
 
 
298
  )
299
 
300
  demo.queue().launch()