# Flux1-Depth-Dev / scripts / MergeSafetensors2.py
# srcphag's picture
# New Diffusers structure
# 1bf81fc
from safetensors.torch import save_file, load_file
import torch
import os
def inspect_keys(file_path, max_keys=10):
    """Print a short summary of the tensor names stored in a safetensors file.

    Args:
        file_path: Path of the safetensors file to read.
        max_keys: Upper bound on how many key names are printed.

    Returns:
        A list of every key in the file, in the order ``load_file``
        produced them.
    """
    # Iterating the loaded dict yields its keys; only the names are kept.
    tensor_names = list(load_file(file_path))
    file_name = os.path.basename(file_path)

    print(f"\n{file_name} - Total keys: {len(tensor_names)}")
    print(f"First {max_keys} keys:")
    preview = tensor_names[:max_keys]
    for name in preview:
        print(f" {name}")

    return tensor_names
def _summarize_keys(file_path, keys, max_keys=10):
    """Print the key summary for a file whose keys are already in memory."""
    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    print(f"First {max_keys} keys:")
    for k in keys[:max_keys]:
        print(f" {k}")


def _unet_target_key(key):
    """Return the merged-checkpoint key for a UNet/transformer weight."""
    if key.startswith(("model.", "diffusion_model.")):
        return key
    return f"model.diffusion_model.{key}"


def _vae_target_key(key):
    """Return the merged-checkpoint key for a VAE weight."""
    if key.startswith(("first_stage_model.", "vae.")):
        return key
    if key.startswith(("decoder.", "encoder.")):
        return f"first_stage_model.{key}"
    return f"first_stage_model.decoder.{key}"


def _text_encoder_target_key(key, model_type):
    """Return the merged-checkpoint key for a text-encoder weight."""
    if key.startswith(("cond_stage_model.", "text_encoder.")):
        return key
    if model_type.lower() == "flux":
        return f"text_encoders.{key}"
    return f"cond_stage_model.transformer.{key}"


def _merge_component(merged_state, component_state, rename, label):
    """Copy one component into ``merged_state`` under renamed keys, warning on collisions."""
    collisions = []
    for key, value in component_state.items():
        target = rename(key)
        if target in merged_state:
            collisions.append(target)
        merged_state[target] = value
    if collisions:
        # Later components still win (as before), but no longer silently.
        print(
            f"\nWARNING: {len(collisions)} {label} key(s) overwrote existing "
            f"entries, e.g. {collisions[0]}"
        )


def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux"  # "flux", "sd15", "sdxl"
):
    """
    Merge components into a single ComfyUI-style safetensors checkpoint.

    Keys that already carry a recognised prefix are copied unchanged; all
    other keys get a prefix prepended:

    * UNet/transformer: ``model.diffusion_model.``
    * VAE: ``first_stage_model.`` for ``encoder.``/``decoder.`` keys,
      ``first_stage_model.decoder.`` for anything else
    * Text encoder: ``text_encoders.`` when ``model_type`` is "flux"
      (case-insensitive), otherwise ``cond_stage_model.transformer.``

    NOTE(review): only prefixes are added; the rest of each key is not
    renamed. Confirm the target loader accepts the resulting layout.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl)

    Returns:
        None. The merged checkpoint is written to ``output_path`` and
        progress is printed to stdout.
    """
    print("=" * 60)
    print("STEP 1: Inspecting input files...")
    print("=" * 60)

    # Each input is read from disk exactly once; the key summary is printed
    # from the loaded state rather than by loading the file a second time.
    unet_state = load_file(unet_path)
    unet_keys = list(unet_state)
    _summarize_keys(unet_path, unet_keys)

    vae_state = load_file(vae_path)
    vae_keys = list(vae_state)
    _summarize_keys(vae_path, vae_keys)

    text_encoder_state = load_file(text_encoder_path)
    text_encoder_keys = list(text_encoder_state)
    _summarize_keys(text_encoder_path, text_encoder_keys)

    print("\n" + "=" * 60)
    print("STEP 2: Loading weights...")
    print("=" * 60)
    # Nothing left to read: the weights loaded during inspection are reused.

    print("\n" + "=" * 60)
    print("STEP 3: Merging with proper key structure...")
    print("=" * 60)

    # Show one representative key per component; an empty file gets a
    # placeholder instead of raising IndexError.
    sample_unet_key = unet_keys[0] if unet_keys else "<no keys>"
    sample_vae_key = vae_keys[0] if vae_keys else "<no keys>"
    sample_te_key = text_encoder_keys[0] if text_encoder_keys else "<no keys>"
    print("\nDetected key patterns:")
    print(f" UNet: {sample_unet_key}")
    print(f" VAE: {sample_vae_key}")
    print(f" Text Encoder: {sample_te_key}")

    merged_state = {}
    _merge_component(merged_state, unet_state, _unet_target_key, "UNet")
    _merge_component(merged_state, vae_state, _vae_target_key, "VAE")
    _merge_component(
        merged_state,
        text_encoder_state,
        lambda key: _text_encoder_target_key(key, model_type),
        "text encoder",
    )
    print(f"\nMerged state contains {len(merged_state)} parameters")

    print("\n" + "=" * 60)
    print("STEP 4: Saving merged checkpoint...")
    print("=" * 60)
    save_file(merged_state, output_path)
    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    # Read the written file back so the summary reflects what is on disk.
    print("\n" + "=" * 60)
    print("STEP 5: Verifying merged file...")
    print("=" * 60)
    inspect_keys(output_path, max_keys=20)
def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Simple merge that preserves original key structure.
    Use this if the files already have proper ComfyUI keys.

    The three state dicts are combined in the order UNet, VAE, text
    encoder. If a key appears in more than one file the later file wins;
    a warning is printed so the overwrite is not silent.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder safetensors
        output_path: Path for the merged checkpoint

    Returns:
        None. The merged checkpoint is written to ``output_path``.
    """
    print("Loading all components...")
    components = (
        ("UNet", load_file(unet_path)),
        ("VAE", load_file(vae_path)),
        ("text encoder", load_file(text_encoder_path)),
    )

    print("Merging...")
    merged_state = {}
    for label, state in components:
        # Key names shared with an earlier component would be overwritten.
        overlap = merged_state.keys() & state.keys()
        if overlap:
            print(
                f"WARNING: {len(overlap)} key(s) from {label} overwrite "
                f"earlier entries, e.g. {min(overlap)}"
            )
        merged_state.update(state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
# Script entry point: runs the example merge when executed directly.
if __name__ == "__main__":
    # Option 1: prefix-aware merge of the three FLUX component files.
    flux_merge_args = {
        "unet_path": "../flux1-depth-dev.safetensors",
        "vae_path": "../vae/diffusion_pytorch_model.safetensors",
        "text_encoder_path": "../text_encoder/model.safetensors",
        "output_path": "../flux1-depth-dev_merged_model.safetensors",
        "model_type": "flux",
    }
    merge_for_comfyui(**flux_merge_args)

    # Option 2: plain merge, for files whose keys are already correct.
    # simple_merge_keep_structure(
    #     unet_path="path/to/model.safetensors",
    #     vae_path="path/to/vae.safetensors",
    #     text_encoder_path="path/to/text_encoder.safetensors",
    #     output_path="merged_checkpoint.safetensors"
    # )