fffiloni commited on
Commit
08a0da5
1 Parent(s): 51dd65f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -53,21 +53,17 @@ def clear_gpu():
53
 
54
  def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
55
 
56
-
57
  lora_path = "checkpoints/"
58
  if orbit_type == "Left":
59
  weight_name = "orbit_left_lora_weights.safetensors"
 
60
  elif orbit_type == "Up":
61
  weight_name = "orbit_up_lora_weights.safetensors"
 
62
  lora_rank = 256
63
 
64
- pipe.unload_lora_weights()
65
-
66
- # Generate a timestamp for adapter_name
67
- adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
68
-
69
  # Load LoRA weights on CPU, move to GPU afterward
70
- pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"{adapter_timestamp}")
71
  pipe.fuse_lora(lora_scale=1 / lora_rank)
72
 
73
  # Move the pipeline to GPU for inference
@@ -87,6 +83,10 @@ def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True))
87
  use_dynamic_cfg=True,
88
  generator=torch.Generator(device="cpu").manual_seed(seed)
89
  )
 
 
 
 
90
 
91
  # Generate and save output video
92
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
53
 
54
  def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
55
 
 
56
  lora_path = "checkpoints/"
57
  if orbit_type == "Left":
58
  weight_name = "orbit_left_lora_weights.safetensors"
59
+ adapter_name = "orbit_left_lora_weights"
60
  elif orbit_type == "Up":
61
  weight_name = "orbit_up_lora_weights.safetensors"
62
+ adapter_name = "orbit_lup_lora_weights"
63
  lora_rank = 256
64
 
 
 
 
 
 
65
  # Load LoRA weights on CPU, move to GPU afterward
66
+ pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=adapter_name)
67
  pipe.fuse_lora(lora_scale=1 / lora_rank)
68
 
69
  # Move the pipeline to GPU for inference
 
83
  use_dynamic_cfg=True,
84
  generator=torch.Generator(device="cpu").manual_seed(seed)
85
  )
86
+
87
+ pipe.unfuse_lora()
88
+ pipe.unload_lora_weights()
89
+
90
 
91
  # Generate and save output video
92
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")