Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -53,21 +53,17 @@ def clear_gpu():
|
|
53 |
|
54 |
def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
|
55 |
|
56 |
-
|
57 |
lora_path = "checkpoints/"
|
58 |
if orbit_type == "Left":
|
59 |
weight_name = "orbit_left_lora_weights.safetensors"
|
|
|
60 |
elif orbit_type == "Up":
|
61 |
weight_name = "orbit_up_lora_weights.safetensors"
|
|
|
62 |
lora_rank = 256
|
63 |
|
64 |
-
pipe.unload_lora_weights()
|
65 |
-
|
66 |
-
# Generate a timestamp for adapter_name
|
67 |
-
adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
68 |
-
|
69 |
# Load LoRA weights on CPU, move to GPU afterward
|
70 |
-
pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=
|
71 |
pipe.fuse_lora(lora_scale=1 / lora_rank)
|
72 |
|
73 |
# Move the pipeline to GPU for inference
|
@@ -87,6 +83,10 @@ def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True))
|
|
87 |
use_dynamic_cfg=True,
|
88 |
generator=torch.Generator(device="cpu").manual_seed(seed)
|
89 |
)
|
|
|
|
|
|
|
|
|
90 |
|
91 |
# Generate and save output video
|
92 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
53 |
|
54 |
def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
|
55 |
|
|
|
56 |
lora_path = "checkpoints/"
|
57 |
if orbit_type == "Left":
|
58 |
weight_name = "orbit_left_lora_weights.safetensors"
|
59 |
+
adapter_name = "orbit_left_lora_weights"
|
60 |
elif orbit_type == "Up":
|
61 |
weight_name = "orbit_up_lora_weights.safetensors"
|
62 |
+
adapter_name = "orbit_lup_lora_weights"
|
63 |
lora_rank = 256
|
64 |
|
|
|
|
|
|
|
|
|
|
|
65 |
# Load LoRA weights on CPU, move to GPU afterward
|
66 |
+
pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=adapter_name)
|
67 |
pipe.fuse_lora(lora_scale=1 / lora_rank)
|
68 |
|
69 |
# Move the pipeline to GPU for inference
|
|
|
83 |
use_dynamic_cfg=True,
|
84 |
generator=torch.Generator(device="cpu").manual_seed(seed)
|
85 |
)
|
86 |
+
|
87 |
+
pipe.unfuse_lora()
|
88 |
+
pipe.unload_lora_weights()
|
89 |
+
|
90 |
|
91 |
# Generate and save output video
|
92 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|