| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import argparse
|
| | import os
|
| | from typing import Dict, Tuple, List, Optional
|
| | from collections import defaultdict
|
| | import math
|
| |
|
| | import torch
|
| | from safetensors.torch import load_file, save_file
|
| | from safetensors import safe_open
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Overlap normalization switch. When True, a LoRA's scale on a base weight is
# divided by max(1.0, total effective strength of all LoRAs touching that
# weight); each LoRA's lerp value then blends this normalized scale with the
# direct (un-normalized) one. See apply_loras_streaming.
NORMALIZE_OVERLAPS = True

# Per-LoRA, per-weight cap on ||delta|| relative to ||W|| (torch's default
# tensor norm). A scaled delta whose norm exceeds CLIP_RATIO * ||W|| is shrunk
# to exactly that norm. None disables clipping.
CLIP_RATIO: Optional[float] = 1.0
|
| |
|
| |
|
| |
|
| |
|
def _parse_finite_float(text: str, field: str, line_no: int, line: str) -> float:
    """Parse one numeric field of a LoRA list line, rejecting NaN and infinities."""
    try:
        value = float(text)
    except ValueError as err:
        raise ValueError(
            f"Invalid {field} {text!r} on line {line_no} of LoRA list: {line}"
        ) from err
    # float() happily accepts "nan"/"inf"; such values would either be silently
    # clamped (lerp) or propagate NaN/inf into every merged weight (strengths).
    if not math.isfinite(value):
        raise ValueError(
            f"Non-finite {field} {text!r} on line {line_no} of LoRA list: {line}"
        )
    return value


def parse_lora_list(path: str) -> List[Tuple[str, float, float, float]]:
    """
    Parse a LoRA list file into merge specifications.

    Every line that is not blank and does not start with ``#`` has the form::

        filename.safetensors,video_strength,lerp[,audio_strength]

    for example::

        filename.safetensors,0.7,0.0
        filename2.safetensors,1.0,0.5,0.3

    Fields are comma-separated and surrounding whitespace is stripped. Fields
    beyond the fourth are ignored.

    Args:
        path: Path to the list file. It is read as UTF-8, and a leading
            byte-order mark is tolerated.

    Returns:
        A list of ``(filename, video_strength, lerp, audio_strength)`` tuples in
        file order:

        * ``video_strength``: base strength for video/shared weights.
        * ``lerp``: clamped to [0, 1]. 0.0 means fully normalized, 1.0 means
          fully direct, and values in between blend the two.
        * ``audio_strength``: base strength for audio weights. Defaults to
          ``video_strength`` when the fourth field is omitted.

        Filenames are returned exactly as written; they are not resolved
        against the list file's directory.

    Raises:
        ValueError: if a line has fewer than three fields or an empty filename,
            or if a numeric field is not a number or is NaN/infinite.
    """
    loras: List[Tuple[str, float, float, float]] = []
    # "utf-8-sig" drops a BOM (common in files saved by Windows editors), which
    # would otherwise end up as the first character of the first filename.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue

            parts = [p.strip() for p in line.split(",")]
            if len(parts) < 3:
                raise ValueError(f"Invalid LoRA line (need at least file,video_strength,lerp): {line}")

            filename = parts[0]
            if not filename:
                raise ValueError(f"Invalid LoRA line (empty filename) on line {line_no}: {line}")

            video_strength = _parse_finite_float(parts[1], "video_strength", line_no, line)
            lerp = _parse_finite_float(parts[2], "lerp", line_no, line)

            if len(parts) >= 4:
                audio_strength = _parse_finite_float(parts[3], "audio_strength", line_no, line)
            else:
                audio_strength = video_strength

            lerp = max(0.0, min(1.0, lerp))

            loras.append((filename, video_strength, lerp, audio_strength))

    return loras
|
| |
|
| |
|
| |
|
| |
|
def load_base_with_metadata(path: str):
    """
    Load every tensor of a .safetensors checkpoint onto the CPU, plus its header metadata.

    Args:
        path: Path to the .safetensors file.

    Returns:
        ``(tensors, metadata)``: the dict produced by ``load_file(path, device="cpu")``
        and the file's string metadata dict (``{}`` when the header has none or it
        is empty).
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        header_metadata = handle.metadata()
    if not header_metadata:
        header_metadata = {}
    all_tensors = load_file(path, device="cpu")
    return all_tensors, header_metadata
|
| |
|
| |
|
| |
|
| |
|
def group_lora_pairs(lora_tensors: Dict[str, torch.Tensor]):
    """
    Group LoRA tensor names by module prefix and yield the complete pairs.

    Names ending in ``.lora_A.weight``, ``.lora_B.weight`` or ``.alpha`` are
    grouped under the prefix that precedes the suffix. All other names are
    ignored. Only the mapping's keys are inspected, never the tensor values.

    Yields:
        ``(prefix, A_key, B_key, alpha_key)`` for each prefix, in first-seen
        order, that has both an A and a B tensor. ``alpha_key`` is None when the
        prefix has no ``.alpha`` entry. Prefixes missing A or B are reported
        with a printed warning and skipped.
    """
    suffix_roles = (
        (".lora_A.weight", "A"),
        (".lora_B.weight", "B"),
        (".alpha", "alpha"),
    )

    by_prefix: Dict[str, Dict[str, str]] = {}
    for name in lora_tensors:
        for suffix, role in suffix_roles:
            if name.endswith(suffix):
                by_prefix.setdefault(name[: -len(suffix)], {})[role] = name
                break  # the suffixes are mutually exclusive

    for prefix, roles in by_prefix.items():
        if "A" in roles and "B" in roles:
            yield prefix, roles["A"], roles["B"], roles.get("alpha")
        else:
            print(f"Warning: incomplete LoRA prefix {prefix}")
|
| |
|
| |
|
def find_base_weight_key(base_tensors, lora_prefix):
    """
    Resolve a LoRA module prefix to the matching key in ``base_tensors``.

    Candidates are tried in this order: ``<prefix>.weight``,
    ``model.<prefix>.weight``, ``<prefix>``, ``model.<prefix>``.

    Returns:
        The first candidate present in ``base_tensors``, or None if none match.
    """
    model_prefixed = f"model.{lora_prefix}"
    ordered_candidates = (
        f"{lora_prefix}.weight",
        f"{model_prefixed}.weight",
        lora_prefix,
        model_prefixed,
    )
    return next((key for key in ordered_candidates if key in base_tensors), None)
|
| |
|
| |
|
| |
|
| |
|
def classify_prefix(prefix: str) -> str:
    """
    Classify a LoRA prefix as 'audio', 'video', 'cross', or 'shared'.

    Matching is a case-insensitive substring test, checked in priority order:

    1. Cross-modal markers (``audio_to_video`` / ``video_to_audio``) come first,
       so that e.g. ``video_to_audio_attn`` is not counted as ``audio_attn``.
    2. Audio markers (``audio_attn``, ``audio_ff``, ``.audio_``).
    3. Video markers (``video_attn``, ``video_ff``, ``.video_``).

    Anything else is 'shared'.
    """
    lowered = prefix.lower()
    ordered_rules = (
        ("cross", ("audio_to_video", "video_to_audio")),
        ("audio", ("audio_attn", "audio_ff", ".audio_")),
        ("video", ("video_attn", "video_ff", ".video_")),
    )
    for kind, markers in ordered_rules:
        if any(marker in lowered for marker in markers):
            return kind
    return "shared"
|
| |
|
| |
|
def effective_strength_for_prefix(
    prefix: str,
    video_strength: float,
    audio_strength: float,
) -> float:
    """
    Choose the strength that applies to one LoRA module, based on classify_prefix.

    * 'audio' modules use ``audio_strength``.
    * 'video' and 'shared' modules use ``video_strength``.
    * 'cross' (audio<->video) modules use the geometric mean
      ``sqrt(video_strength * audio_strength)``, with each factor first
      clamped at 0.

    NOTE(review): because of that clamp, a negative strength on either side
    gives 0.0 for cross-modal modules rather than a negative (subtractive)
    contribution. Confirm this is intended.
    """
    kind = classify_prefix(prefix)
    if kind == "cross":
        nonneg_video = max(video_strength, 0.0)
        nonneg_audio = max(audio_strength, 0.0)
        return math.sqrt(nonneg_video * nonneg_audio)
    return audio_strength if kind == "audio" else video_strength
|
| |
|
| |
|
| |
|
| |
|
def compute_strength_sums(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
) -> Dict[str, float]:
    """
    Pass 1 of the merge: total the effective strength landing on each base weight.

    For every LoRA in ``lora_specs``, consider each complete (A, B) pair whose
    prefix resolves to a key of ``base_tensors`` (via find_base_weight_key).
    That pair adds ``effective_strength_for_prefix(prefix, video_strength,
    audio_strength)`` to the key's total. Alpha/rank scaling is not part of
    this sum.

    Only tensor *names* are needed here, so each LoRA's safetensors header is
    read with ``safe_open`` and no tensor data is loaded.

    Args:
        base_tensors: Mapping of base checkpoint keys. Only key membership is used.
        lora_specs: ``(path, video_strength, lerp, audio_strength)`` tuples as
            returned by parse_lora_list. ``lerp`` is only logged here.

    Returns:
        Mapping from base key to summed effective strength, as a
        ``defaultdict(float)``. Keys not touched by any LoRA are absent.
    """
    strength_sum: Dict[str, float] = defaultdict(float)

    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 1] Scanning {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        # Reading only the header avoids deserializing every LoRA tensor just
        # to look at names; pass 2 loads the data when it is actually needed.
        with safe_open(lora_path, framework="pt", device="cpu") as f:
            tensor_names = list(f.keys())

        # group_lora_pairs inspects only key names, so a key-only dict suffices.
        for prefix, _a_key, _b_key, _alpha_key in group_lora_pairs(dict.fromkeys(tensor_names)):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                continue

            eff_strength = effective_strength_for_prefix(prefix, video_strength, audio_strength)
            strength_sum[base_key] += eff_strength

    print(f"[Pass 1] Keys with strength contributions: {len(strength_sum)}")
    return strength_sum
|
| |
|
| |
|
| |
|
| |
|
def _merge_scale(
    eff_strength: float,
    alpha: Optional[float],
    rank: int,
    total_strength: float,
    lerp: float,
    normalize: bool,
) -> float:
    """Return the scalar multiplier for one LoRA delta (alpha/rank, overlap normalization, lerp blend)."""
    # LoRA alpha/rank scaling; without an alpha tensor the strength is used unchanged.
    if alpha is not None:
        base_scale = eff_strength * alpha / max(rank, 1)
    else:
        base_scale = eff_strength

    # Normalized variant: divide by the total strength on this weight, but never
    # amplify (the denominator is at least 1.0).
    if normalize:
        scale_norm = base_scale / max(1.0, total_strength)
    else:
        scale_norm = base_scale

    # lerp = 0 -> fully normalized, lerp = 1 -> fully direct (base_scale).
    return (1.0 - lerp) * scale_norm + lerp * base_scale


def _clip_delta_inplace(delta: torch.Tensor, base_norm: float, clip_ratio: float) -> None:
    """Shrink ``delta`` in place so that its norm is at most ``clip_ratio * base_norm``."""
    limit = clip_ratio * base_norm
    delta_norm = delta.norm().item()
    # NOTE(review): when base_norm is 0, any non-zero delta is scaled to zero.
    if delta_norm > limit and delta_norm > 0:
        delta *= limit / delta_norm


def apply_loras_streaming(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
    strength_sum: Dict[str, float],
    clip_ratio: Optional[float] = CLIP_RATIO,
):
    """
    Pass 2 of the merge: add each LoRA's scaled delta into ``base_tensors`` in place.

    LoRAs are loaded one at a time. For each complete (A, B) pair whose prefix
    resolves to a base key:

    1. ``delta = B @ A`` is computed in float32 and must match the base
       weight's shape.
    2. The scale comes from _merge_scale. It uses the strength chosen by
       effective_strength_for_prefix, the LoRA's alpha/rank, the per-key total
       from ``strength_sum`` when the module-level NORMALIZE_OVERLAPS is set,
       and the LoRA's lerp value.
    3. If ``clip_ratio`` is not None, the scaled delta is shrunk so its norm
       does not exceed ``clip_ratio`` times the norm of the current base weight.
       That weight already includes earlier LoRAs' updates.
    4. The weight is updated in float32 and cast back to its original dtype.

    Args:
        base_tensors: Mutable mapping from key to tensor; updated in place.
        lora_specs: ``(path, video_strength, lerp, audio_strength)`` tuples.
        strength_sum: Per-base-key total strength, as produced by compute_strength_sums.
        clip_ratio: Relative norm cap per LoRA per weight. None disables clipping.

    Raises:
        ValueError: if a LoRA delta's shape differs from the base weight's shape.
    """
    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 2] Applying {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        lora_tensors = load_file(lora_path, device="cpu")

        applied = 0
        skipped = 0

        for prefix, A_key, B_key, alpha_key in group_lora_pairs(lora_tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                skipped += 1
                continue

            W = base_tensors[base_key]

            A = lora_tensors[A_key].to(torch.float32)
            B = lora_tensors[B_key].to(torch.float32)
            delta = B @ A

            if delta.shape != W.shape:
                raise ValueError(
                    f"Shape mismatch for {prefix}: delta {delta.shape} vs base {W.shape}"
                )

            # Rank is the row count of a 2-D A; other layouts fall back to the element count.
            rank = A.shape[0] if A.dim() == 2 else A.numel()

            alpha = (
                float(lora_tensors[alpha_key].to(torch.float32).item())
                if alpha_key is not None
                else None
            )
            scale = _merge_scale(
                effective_strength_for_prefix(prefix, video_strength, audio_strength),
                alpha,
                rank,
                strength_sum.get(base_key, 0.0),
                lerp,
                NORMALIZE_OVERLAPS,
            )
            delta_scaled = delta * scale

            # Convert once: the same float32 view serves the clip norm and the
            # update, avoiding a second full-size temporary copy of the weight.
            Wf = W.to(torch.float32)
            if clip_ratio is not None:
                _clip_delta_inplace(delta_scaled, Wf.norm().item(), clip_ratio)

            base_tensors[base_key] = (Wf + delta_scaled).to(W.dtype)
            applied += 1

        print(f"[Pass 2] {lora_path}: applied {applied}, skipped {skipped}")
        del lora_tensors
|
| | del lora_tensors
|
| |
|
| |
|
def apply_loras_to_base(base_tensors, lora_specs):
    """
    Merge every LoRA in ``lora_specs`` into ``base_tensors`` in place, in two passes.

    Pass 1 (compute_strength_sums) totals the per-weight strengths used for
    overlap normalization. Pass 2 (apply_loras_streaming) applies the scaled
    deltas with its default clip ratio. Returns None.
    """
    per_key_totals = compute_strength_sums(base_tensors, lora_specs)
    apply_loras_streaming(base_tensors, lora_specs, strength_sum=per_key_totals)
|
| |
|
| |
|
| |
|
| |
|
def is_vae_key(key: str) -> bool:
    """Return True if ``key`` starts with one of the VAE prefixes (``first_stage_model.`` / ``vae.``, optionally under ``model.``)."""
    vae_prefixes = (
        "first_stage_model.",
        "model.first_stage_model.",
        "vae.",
        "model.vae.",
    )
    return key.startswith(vae_prefixes)
|
| |
|
| |
|
def is_text_encoder_key(key: str) -> bool:
    """Return True if ``key`` starts with one of the text-encoder prefixes (``text_encoder.`` / ``cond_stage_model.``, optionally under ``model.``)."""
    text_encoder_prefixes = (
        "text_encoder.",
        "model.text_encoder.",
        "cond_stage_model.",
        "model.cond_stage_model.",
    )
    return key.startswith(text_encoder_prefixes)
|
| |
|
| |
|
def is_unet_key(key: str) -> bool:
    """Return True if ``key`` belongs to the diffusion model (``diffusion_model.``, optionally under ``model.``)."""
    diffusion_prefixes = ("model.diffusion_model.", "diffusion_model.")
    return key.startswith(diffusion_prefixes)
|
| |
|
| |
|
def convert_to_fp8_inplace(tensors: Dict[str, torch.Tensor]):
    """
    Cast diffusion-model (UNet) and text-encoder float tensors to float8_e4m3fn, in place.

    Classification is by key prefix and is applied per tensor:

    * Non-floating tensors are left unchanged and counted as ``skipped_other``.
    * VAE tensors (is_vae_key) are left unchanged and counted as ``skipped_vae``.
      This check takes precedence over the others.
    * UNet (is_unet_key) or text-encoder (is_text_encoder_key) tensors are
      converted and counted as ``converted``.
    * Everything else is left unchanged and counted as ``skipped_other``.

    Values are clamped to the finite range of float8_e4m3fn (±448) before the
    cast. PyTorch's float -> float8_e4m3fn conversion does not saturate, so
    larger magnitudes would otherwise become NaN. Tensors already stored as
    float8_e4m3fn are kept as-is but still counted as converted.

    Prints a one-line summary of the counts.
    """
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max

    converted = 0
    skipped_vae = 0
    skipped_other = 0

    for k, v in list(tensors.items()):
        if not torch.is_floating_point(v):
            skipped_other += 1
            continue

        if is_vae_key(k):
            skipped_vae += 1
            continue

        if is_unet_key(k) or is_text_encoder_key(k):
            if v.dtype != fp8_dtype:
                # Saturate instead of letting out-of-range values turn into NaN.
                v = v.clamp(min=-fp8_max, max=fp8_max).to(fp8_dtype)
            tensors[k] = v
            converted += 1
        else:
            skipped_other += 1

    print(
        f"FP8 conversion: converted={converted}, "
        f"skipped_vae={skipped_vae}, skipped_other={skipped_other}"
    )
|
| |
|
| |
|
| |
|
| |
|
def main():
    """
    Command-line entry point: merge LoRAs into a base checkpoint and save it as FP8.

    Positional arguments are the base checkpoint, the LoRA list file (format
    described in parse_lora_list), and the output path. Steps:

    1. Validate the inputs.
    2. Load the base tensors and header metadata.
    3. Apply the LoRAs (apply_loras_to_base).
    4. Convert UNet/text-encoder tensors to FP8 (convert_to_fp8_inplace).
    5. Save the result with the original metadata, creating the output
       directory if needed.

    Raises:
        FileNotFoundError: if the base checkpoint or any listed LoRA file is missing.
        ValueError: if the LoRA list is empty, or if parse_lora_list rejects it.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Apply LTX2-style LoRAs with separate video/audio strengths, "
            "strength-weighted normalization, LERP blending, per‑LoRA clipping, "
            "FP8 conversion, and metadata preservation (streaming, memory‑efficient)."
        )
    )
    parser.add_argument("base", help="Base checkpoint (.safetensors)")
    parser.add_argument("lora_list", help="Text file: path,video_strength,lerp[,audio_strength]")
    parser.add_argument("output", help="Output FP8 checkpoint (.safetensors)")

    args = parser.parse_args()

    if not os.path.isfile(args.base):
        raise FileNotFoundError(args.base)

    lora_specs = parse_lora_list(args.lora_list)
    if not lora_specs:
        raise ValueError("No LoRAs specified.")

    # Fail fast. Otherwise a missing LoRA is only discovered in pass 1, after
    # the (potentially large) base checkpoint has already been loaded.
    missing = [lora_path for lora_path, *_ in lora_specs if not os.path.isfile(lora_path)]
    if missing:
        raise FileNotFoundError("LoRA file(s) not found: " + ", ".join(missing))

    print(f"Loading base checkpoint: {args.base}")
    base_tensors, metadata = load_base_with_metadata(args.base)
    print(f"Base checkpoint has {len(base_tensors)} tensors.")

    apply_loras_to_base(base_tensors, lora_specs)

    print("Converting UNet + text encoder to FP8 (leaving VAE untouched)...")
    convert_to_fp8_inplace(base_tensors)

    print(f"Saving merged FP8 checkpoint to: {args.output}")
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    save_file(base_tensors, args.output, metadata=metadata)
    print("Done.")
|
| |
|
| |
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
| |
|