Ensure repo only accesses CublasLinear lazily
- float8_quantize.py +5 -1
- lora_loading.py +5 -4
float8_quantize.py
CHANGED
@@ -447,7 +447,11 @@ def quantize_flow_transformer_and_dispatch_float8(
             quantize_modulation=quantize_modulation,
         )
     torch.cuda.empty_cache()
-    if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
+    if (
+        swap_linears_with_cublaslinear
+        and flow_dtype == torch.float16
+        and isinstance(CublasLinear, type(torch.nn.Linear))
+    ):
         swap_to_cublaslinear(flow_model)
     elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
         logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
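For reference, `type(torch.nn.Linear)` evaluates to `nn.Linear`'s metaclass, which is plain `type`, so the new guard only lets the swap run when `CublasLinear` is bound to an actual class rather than a non-class placeholder left behind by a failed import. A minimal sketch of how that check behaves (the `FakeCublasLinear` class and the bare `None` placeholder below are hypothetical, for illustration only):

import torch

class FakeCublasLinear(torch.nn.Linear):
    """Hypothetical stand-in for cublas_ops.CublasLinear."""

# type(torch.nn.Linear) is nn.Linear's metaclass, i.e. plain `type`.
assert type(torch.nn.Linear) is type

# A successfully imported class satisfies the guard ...
assert isinstance(FakeCublasLinear, type(torch.nn.Linear))

# ... while a bare None placeholder does not, so the swap is skipped.
CublasLinear = None
assert not isinstance(CublasLinear, type(torch.nn.Linear))

# Note: a type(None) placeholder (the fallback added in lora_loading.py)
# is itself a class, so it also satisfies this particular isinstance check.
assert isinstance(type(None), type(torch.nn.Linear))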
lora_loading.py
CHANGED
@@ -1,9 +1,12 @@
 import torch
-from cublas_ops import CublasLinear
 from loguru import logger
 from safetensors.torch import load_file
 from tqdm import tqdm
 
+try:
+    from cublas_ops import CublasLinear
+except Exception as e:
+    CublasLinear = type(None)
 from float8_quantize import F8Linear
 from modules.flux_model import Flux
 
@@ -383,7 +386,7 @@ def apply_lora_weight_to_module(
 
 
 @torch.inference_mode()
-def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
+def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
     lora_weights = load_file(lora_path)
@@ -408,8 +411,6 @@ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
     ]
     logger.debug("Keys extracted")
     keys_without_ab = list(set(keys_without_ab))
-    if len(keys_without_ab) > 0:
-        logger.warning("Missing unconverted state dict keys!", len(keys_without_ab))
 
     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
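The `type(None)` fallback matters for the rest of lora_loading.py: `NoneType` is a real class, so isinstance checks against `CublasLinear` stay legal and simply never match when cublas_ops is unavailable, which is presumably why `type(None)` was chosen over a bare `None`. A rough, self-contained sketch of the pattern, not the module's actual call sites:

import torch

# Lazy optional import, mirroring the change above: if cublas_ops is not
# installed, bind CublasLinear to NoneType instead of failing at import time.
try:
    from cublas_ops import CublasLinear
except Exception:
    CublasLinear = type(None)

module = torch.nn.Linear(8, 8)

# With the NoneType sentinel this check is always valid; it just never
# matches a real layer, so handling falls through to the nn.Linear branch.
if isinstance(module, CublasLinear):
    print("cublas-backed linear layer")
elif isinstance(module, torch.nn.Linear):
    print("plain nn.Linear layer")

# A bare `CublasLinear = None` sentinel would instead raise
# "TypeError: isinstance() arg 2 must be a type ..." at the check above.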