eyal.benaroche commited on
Commit
7b0a20d
β€’
1 Parent(s): 561be1b

fix device

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -26,22 +26,20 @@ with open("sdxl_lora.json", "r") as file:
26
  # Sort the loras by likes
27
  sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
28
 
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
31
 
32
  if gr.NO_RELOAD:
33
- torch.cuda.max_memory_allocated(device=device)
34
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
35
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
36
  pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
37
- pipe.to(device="cuda", dtype=torch.float16)
38
 
39
 
40
  MAX_SEED = np.iinfo(np.int32).max
41
  MAX_IMAGE_SIZE = 1024
42
 
43
 
44
- @spaces.GPU
45
  def check_and_load_lora_user(user_lora_selector, user_lora_weight, gr_lora_loaded):
46
  flash_sdxl_id = "jasperai/flash-sdxl"
47
 
@@ -277,7 +275,7 @@ with gr.Blocks(css=css) as demo:
277
  guidance_scale,
278
  ],
279
  outputs=[result],
280
- show_progress=True,
281
  )
282
 
283
  user_lora_weight.change(
@@ -294,7 +292,7 @@ with gr.Blocks(css=css) as demo:
294
  user_lora_selector,
295
  pre_prompt,
296
  ],
297
- show_progress=False,
298
  )
299
 
300
  gr.Markdown("**Disclaimer:**")
 
26
  # Sort the loras by likes
27
  sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
28
 
29
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
31
 
32
  if gr.NO_RELOAD:
 
33
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
34
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
35
  pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
36
+ pipe.to(device=DEVICE, dtype=torch.float16)
37
 
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
40
  MAX_IMAGE_SIZE = 1024
41
 
42
 
 
43
  def check_and_load_lora_user(user_lora_selector, user_lora_weight, gr_lora_loaded):
44
  flash_sdxl_id = "jasperai/flash-sdxl"
45
 
 
275
  guidance_scale,
276
  ],
277
  outputs=[result],
278
+ show_progress="minimal",
279
  )
280
 
281
  user_lora_weight.change(
 
292
  user_lora_selector,
293
  pre_prompt,
294
  ],
295
+ show_progress="hidden",
296
  )
297
 
298
  gr.Markdown("**Disclaimer:**")