multimodalart HF staff commited on
Commit
75a89e7
β€’
1 Parent(s): 0df5d5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -24,6 +24,7 @@ saved_names = [
24
  hf_hub_download(repo_id, filename) for _, _, repo_id, _, filename, _ in sdxl_loras
25
  ]
26
 
 
27
 
28
  def update_selection(selected_state: gr.SelectData):
29
  lora_repo = sdxl_loras[selected_state.index][2]
@@ -41,7 +42,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
41
  torch_dtype=torch.float16,
42
  ).to("cpu")
43
  original_pipe = copy.deepcopy(pipe)
44
- pipe.to("cuda")
45
 
46
  last_lora = ""
47
  last_merged = False
@@ -58,7 +59,7 @@ def run_lora(prompt, negative, weight, selected_state):
58
  if last_lora != repo_name:
59
  if last_merged:
60
  pipe = copy.deepcopy(original_pipe)
61
- pipe.to("cuda")
62
  else:
63
  pipe.unload_lora_weights()
64
  is_compatible = sdxl_loras[selected_state.index][5]
@@ -85,6 +86,7 @@ def run_lora(prompt, negative, weight, selected_state):
85
  #lora_model.merge_to(
86
  # pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
87
  #)
 
88
  last_merged = True
89
 
90
  image = pipe(
 
24
  hf_hub_download(repo_id, filename) for _, _, repo_id, _, filename, _ in sdxl_loras
25
  ]
26
 
27
+ device = "cuda" #replace this to `mps` if on a MacOS Silicon
28
 
29
  def update_selection(selected_state: gr.SelectData):
30
  lora_repo = sdxl_loras[selected_state.index][2]
 
42
  torch_dtype=torch.float16,
43
  ).to("cpu")
44
  original_pipe = copy.deepcopy(pipe)
45
+ pipe.to(device)
46
 
47
  last_lora = ""
48
  last_merged = False
 
59
  if last_lora != repo_name:
60
  if last_merged:
61
  pipe = copy.deepcopy(original_pipe)
62
+ pipe.to(device)
63
  else:
64
  pipe.unload_lora_weights()
65
  is_compatible = sdxl_loras[selected_state.index][5]
 
86
  #lora_model.merge_to(
87
  # pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
88
  #)
89
+ pipe.to(device)
90
  last_merged = True
91
 
92
  image = pipe(