Spaces:
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -589,10 +589,12 @@ def on_prepare_loras_click(
|
|
| 589 |
print("[LoRA] No LoRAs selected, resetting to base model weights")
|
| 590 |
try:
|
| 591 |
transformer = ledger.transformer()
|
| 592 |
-
|
|
|
|
|
|
|
|
|
|
| 593 |
transformer.load_state_dict(base_weights, strict=False)
|
| 594 |
-
|
| 595 |
-
transformer = transformer.to("cuda")
|
| 596 |
current_lora_key = key
|
| 597 |
progress(1.0, desc="Done")
|
| 598 |
return "✓ Reset to base model (no LoRAs active)"
|
|
@@ -624,23 +626,20 @@ def on_prepare_loras_click(
|
|
| 624 |
|
| 625 |
try:
|
| 626 |
transformer = ledger.transformer()
|
| 627 |
-
target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 628 |
|
| 629 |
-
#
|
| 630 |
-
|
| 631 |
-
torch.cuda.empty_cache() # Free VRAM from the CPU copy
|
| 632 |
|
| 633 |
-
# Load
|
| 634 |
-
|
|
|
|
|
|
|
|
|
|
| 635 |
if missing:
|
| 636 |
print(f"[LoRA] Warning: {len(missing)} keys not found in fused state")
|
| 637 |
if unexpected:
|
| 638 |
print(f"[LoRA] Warning: {len(unexpected)} unexpected keys in fused state")
|
| 639 |
|
| 640 |
-
# Move transformer to target device (GPU for generation)
|
| 641 |
-
if target_device.type != "cpu":
|
| 642 |
-
transformer = transformer.to(target_device)
|
| 643 |
-
|
| 644 |
current_lora_key = key
|
| 645 |
progress(1.0, desc="Done")
|
| 646 |
return f"✓ Applied {len(active_loras)} LoRA(s) successfully"
|
|
|
|
| 589 |
print("[LoRA] No LoRAs selected, resetting to base model weights")
|
| 590 |
try:
|
| 591 |
transformer = ledger.transformer()
|
| 592 |
+
target_device = next(transformer.parameters()).device
|
| 593 |
+
|
| 594 |
+
# Get base weights and keep them on the same device as transformer
|
| 595 |
+
base_weights = {k: v.to(target_device) for k, v in transformer.state_dict().items()}
|
| 596 |
transformer.load_state_dict(base_weights, strict=False)
|
| 597 |
+
|
|
|
|
| 598 |
current_lora_key = key
|
| 599 |
progress(1.0, desc="Done")
|
| 600 |
return "✓ Reset to base model (no LoRAs active)"
|
|
|
|
| 626 |
|
| 627 |
try:
|
| 628 |
transformer = ledger.transformer()
|
|
|
|
| 629 |
|
| 630 |
+
# Determine target device - the transformer should already be on GPU
|
| 631 |
+
target_device = next(transformer.parameters()).device
|
|
|
|
| 632 |
|
| 633 |
+
# Load fused state dict directly into GPU transformer
|
| 634 |
+
# Convert CPU tensors to GPU tensors inline, then load
|
| 635 |
+
fused_state_gpu = {k: v.to(target_device) for k, v in fused_state.items()}
|
| 636 |
+
|
| 637 |
+
missing, unexpected = transformer.load_state_dict(fused_state_gpu, strict=False)
|
| 638 |
if missing:
|
| 639 |
print(f"[LoRA] Warning: {len(missing)} keys not found in fused state")
|
| 640 |
if unexpected:
|
| 641 |
print(f"[LoRA] Warning: {len(unexpected)} unexpected keys in fused state")
|
| 642 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
current_lora_key = key
|
| 644 |
progress(1.0, desc="Done")
|
| 645 |
return f"✓ Applied {len(active_loras)} LoRA(s) successfully"
|