Spaces:
Paused
Paused
Commit
•
75a89e7
1
Parent(s):
0df5d5a
Update app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,7 @@ saved_names = [
|
|
24 |
hf_hub_download(repo_id, filename) for _, _, repo_id, _, filename, _ in sdxl_loras
|
25 |
]
|
26 |
|
|
|
27 |
|
28 |
def update_selection(selected_state: gr.SelectData):
|
29 |
lora_repo = sdxl_loras[selected_state.index][2]
|
@@ -41,7 +42,7 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
41 |
torch_dtype=torch.float16,
|
42 |
).to("cpu")
|
43 |
original_pipe = copy.deepcopy(pipe)
|
44 |
-
pipe.to(
|
45 |
|
46 |
last_lora = ""
|
47 |
last_merged = False
|
@@ -58,7 +59,7 @@ def run_lora(prompt, negative, weight, selected_state):
|
|
58 |
if last_lora != repo_name:
|
59 |
if last_merged:
|
60 |
pipe = copy.deepcopy(original_pipe)
|
61 |
-
pipe.to(
|
62 |
else:
|
63 |
pipe.unload_lora_weights()
|
64 |
is_compatible = sdxl_loras[selected_state.index][5]
|
@@ -85,6 +86,7 @@ def run_lora(prompt, negative, weight, selected_state):
|
|
85 |
#lora_model.merge_to(
|
86 |
# pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
|
87 |
#)
|
|
|
88 |
last_merged = True
|
89 |
|
90 |
image = pipe(
|
|
|
24 |
hf_hub_download(repo_id, filename) for _, _, repo_id, _, filename, _ in sdxl_loras
|
25 |
]
|
26 |
|
27 |
+
device = "cuda" #replace this to `mps` if on a MacOS Silicon
|
28 |
|
29 |
def update_selection(selected_state: gr.SelectData):
|
30 |
lora_repo = sdxl_loras[selected_state.index][2]
|
|
|
42 |
torch_dtype=torch.float16,
|
43 |
).to("cpu")
|
44 |
original_pipe = copy.deepcopy(pipe)
|
45 |
+
pipe.to(device)
|
46 |
|
47 |
last_lora = ""
|
48 |
last_merged = False
|
|
|
59 |
if last_lora != repo_name:
|
60 |
if last_merged:
|
61 |
pipe = copy.deepcopy(original_pipe)
|
62 |
+
pipe.to(device)
|
63 |
else:
|
64 |
pipe.unload_lora_weights()
|
65 |
is_compatible = sdxl_loras[selected_state.index][5]
|
|
|
86 |
#lora_model.merge_to(
|
87 |
# pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
|
88 |
#)
|
89 |
+
pipe.to(device)
|
90 |
last_merged = True
|
91 |
|
92 |
image = pipe(
|