import os

import torch
from safetensors.torch import save_file, load_file


def inspect_keys(file_path, max_keys=10):
    """Helper function to inspect the structure of a safetensors file."""
    state = load_file(file_path)
    keys = list(state.keys())
    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    print(f"First {max_keys} keys:")
    for k in keys[:max_keys]:
        print(f"  {k}")
    return keys


def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux",
):
    """
    Merge model components into a single ComfyUI-compatible safetensors checkpoint.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl)
    """
    print("=" * 60)
    print("STEP 1: Inspecting input files...")
    print("=" * 60)

    unet_keys = inspect_keys(unet_path)
    vae_keys = inspect_keys(vae_path)
    text_encoder_keys = inspect_keys(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 2: Loading weights...")
    print("=" * 60)

    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 3: Merging with proper key structure...")
    print("=" * 60)

    merged_state = {}

    # Peek at one key from each component to show which naming convention it uses.
    sample_unet_key = unet_keys[0]
    sample_vae_key = vae_keys[0]
    sample_te_key = text_encoder_keys[0]

    print("\nDetected key patterns:")
    print(f"  UNet: {sample_unet_key}")
    print(f"  VAE: {sample_vae_key}")
    print(f"  Text Encoder: {sample_te_key}")

    # UNet/transformer weights: keep keys that are already prefixed,
    # otherwise nest them under 'model.diffusion_model.'.
    for key, value in unet_state.items():
        if key.startswith('model.') or key.startswith('diffusion_model.'):
            merged_state[key] = value
        else:
            merged_state[f'model.diffusion_model.{key}'] = value

    # VAE weights: keep existing prefixes, map bare 'encoder.'/'decoder.' keys
    # under 'first_stage_model.', and fall back to the decoder prefix otherwise.
    for key, value in vae_state.items():
        if key.startswith('first_stage_model.') or key.startswith('vae.'):
            merged_state[key] = value
        elif key.startswith('decoder.') or key.startswith('encoder.'):
            merged_state[f'first_stage_model.{key}'] = value
        else:
            merged_state[f'first_stage_model.decoder.{key}'] = value

    # Text encoder weights: keep existing prefixes, otherwise choose the prefix
    # based on the model type.
    for key, value in text_encoder_state.items():
        if key.startswith('cond_stage_model.') or key.startswith('text_encoder.'):
            merged_state[key] = value
        else:
            if model_type.lower() == "flux":
                merged_state[f'text_encoders.{key}'] = value
            else:
                merged_state[f'cond_stage_model.transformer.{key}'] = value

    print(f"\nMerged state contains {len(merged_state)} parameters")

    print("\n" + "=" * 60)
    print("STEP 4: Saving merged checkpoint...")
    print("=" * 60)

    save_file(merged_state, output_path)

    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    print("\n" + "=" * 60)
    print("STEP 5: Verifying merged file...")
    print("=" * 60)
    inspect_keys(output_path, max_keys=20)


def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
):
    """
    Simple merge that preserves the original key structure.
    Use this if the files already have proper ComfyUI keys.
    """
    print("Loading all components...")
    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("Merging...")
    # Later .update() calls silently overwrite any keys that already exist,
    # so the three inputs are expected to use disjoint key names.
    merged_state = {}
    merged_state.update(unet_state)
    merged_state.update(vae_state)
    merged_state.update(text_encoder_state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")


if __name__ == "__main__":
    merge_for_comfyui(
        unet_path="../flux1-depth-dev.safetensors",
        vae_path="../vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="../text_encoder/model.safetensors",
        output_path="../flux1-depth-dev_merged_model.safetensors",
        model_type="flux",
    )
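
    # Alternative entry point (illustrative only; the paths below are placeholders,
    # not files referenced by the original script): if the inputs already carry
    # ComfyUI-style key prefixes, the simpler merge can be used instead.
    # simple_merge_keep_structure(
    #     unet_path="../unet.safetensors",
    #     vae_path="../vae.safetensors",
    #     text_encoder_path="../text_encoder.safetensors",
    #     output_path="../merged_model.safetensors",
    # )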