dagloop5 committed on
Commit
3cc8e3c
·
verified ·
1 Parent(s): f366d6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -589,10 +589,12 @@ def on_prepare_loras_click(
589
  print("[LoRA] No LoRAs selected, resetting to base model weights")
590
  try:
591
  transformer = ledger.transformer()
592
- base_weights = {k: v.cpu() for k, v in transformer.state_dict().items()}
 
 
 
593
  transformer.load_state_dict(base_weights, strict=False)
594
- if torch.cuda.is_available():
595
- transformer = transformer.to("cuda")
596
  current_lora_key = key
597
  progress(1.0, desc="Done")
598
  return "✓ Reset to base model (no LoRAs active)"
@@ -624,23 +626,20 @@ def on_prepare_loras_click(
624
 
625
  try:
626
  transformer = ledger.transformer()
627
- target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
628
 
629
- # Move transformer to CPU for loading (avoids device mismatch)
630
- transformer = transformer.to("cpu")
631
- torch.cuda.empty_cache() # Free VRAM from the CPU copy
632
 
633
- # Load the fused state dict
634
- missing, unexpected = transformer.load_state_dict(fused_state, strict=False)
 
 
 
635
  if missing:
636
  print(f"[LoRA] Warning: {len(missing)} keys not found in fused state")
637
  if unexpected:
638
  print(f"[LoRA] Warning: {len(unexpected)} unexpected keys in fused state")
639
 
640
- # Move transformer to target device (GPU for generation)
641
- if target_device.type != "cpu":
642
- transformer = transformer.to(target_device)
643
-
644
  current_lora_key = key
645
  progress(1.0, desc="Done")
646
  return f"✓ Applied {len(active_loras)} LoRA(s) successfully"
 
589
  print("[LoRA] No LoRAs selected, resetting to base model weights")
590
  try:
591
  transformer = ledger.transformer()
592
+ target_device = next(transformer.parameters()).device
593
+
594
+ # Get base weights and keep them on the same device as transformer
595
+ base_weights = {k: v.to(target_device) for k, v in transformer.state_dict().items()}
596
  transformer.load_state_dict(base_weights, strict=False)
597
+
 
598
  current_lora_key = key
599
  progress(1.0, desc="Done")
600
  return "✓ Reset to base model (no LoRAs active)"
 
626
 
627
  try:
628
  transformer = ledger.transformer()
 
629
 
630
+ # Determine target device - the transformer should already be on GPU
631
+ target_device = next(transformer.parameters()).device
 
632
 
633
+ # Load fused state dict directly into GPU transformer
634
+ # Convert CPU tensors to GPU tensors inline, then load
635
+ fused_state_gpu = {k: v.to(target_device) for k, v in fused_state.items()}
636
+
637
+ missing, unexpected = transformer.load_state_dict(fused_state_gpu, strict=False)
638
  if missing:
639
  print(f"[LoRA] Warning: {len(missing)} keys not found in fused state")
640
  if unexpected:
641
  print(f"[LoRA] Warning: {len(unexpected)} unexpected keys in fused state")
642
 
 
 
 
 
643
  current_lora_key = key
644
  progress(1.0, desc="Done")
645
  return f"✓ Applied {len(active_loras)} LoRA(s) successfully"