tiny-flux-deep / scripts /convert_v3_to_v4.py
AbstractPhil's picture
Rename convert_v3_to_v4.py to scripts/convert_v3_to_v4.py
d048753 verified
"""
TinyFlux-Deep Weight Converter: v3 → v4
Converts v3 checkpoints to v4.1 architecture without destroying pretrained weights.
Changes from v3 → v4:
- expert_predictor → lune_predictor (rename)
- expert_gate: raw value → logit space (sigmoid(0)=0.5 preserved)
- NEW: sol_prior (attention statistics predictor, 70% geometric prior)
- NEW: t5_pool + text_balance (T5 vec pathway, 50/50 init)
- NEW: spatial_to_mod per attention layer (zero-init = identity)
All new modules initialize to zero-effect, so converted model behaves
identically to v3 on first forward pass.
Colab:
from convert_v3_to_v4 import run
run(401434)
API:
from convert_v3_to_v4 import convert_checkpoint, load_config
config = load_config("path/to/config.json")
result = convert_checkpoint(step=401434, config=config)
CLI:
python convert_v3_to_v4.py --step 401434
python convert_v3_to_v4.py --step 401434 --config my_config.json
"""
# Converter version; run() records it as "converter_version" in the saved config JSON.
__version__ = "4.1.0"
import json
import math
import os
import re
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
# =============================================================================
# Configuration
# =============================================================================
@dataclass
class TinyFluxConfig:
    """
    Architecture settings for TinyFlux-Deep v4.1.

    One instance fully describes the model layout. It is used to build a new
    model, to drive checkpoint conversion between versions, and to compare a
    state dict against the expected layout (``validate_checkpoint``).

    ``__post_init__`` enforces, raising ``ValueError`` on violation:
      * hidden_size == num_attention_heads * attention_head_dim
      * sum(axes_dims_rope) == attention_head_dim (a list is coerced to a tuple)
      * 0 <= sol_geometric_weight <= 1
    """

    # -- Core transformer layout --
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768  # T5 sequence dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # -- Lune expert predictor (trajectory guidance) --
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280  # SD1.5 mid-block dimension
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # -- Sol attention prior (structural guidance) --
    use_sol_prior: bool = True
    sol_spatial_size: int = 8
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # -- T5 vec enhancement --
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # -- Loss configuration (training only) --
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # -- Legacy --
    guidance_embeds: bool = False

    def __post_init__(self):
        """Check the dimensional invariants listed in the class docstring."""
        heads_times_dim = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != heads_times_dim:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({heads_times_dim})"
            )

        # A JSON round-trip turns the tuple into a list; normalise it back.
        if isinstance(self.axes_dims_rope, list):
            self.axes_dims_rope = tuple(self.axes_dims_rope)
        rope_total = sum(self.axes_dims_rope)
        if rope_total != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({rope_total}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )

        # Chained comparison (rather than `< 0 or > 1`) so NaN is rejected too.
        weight = self.sol_geometric_weight
        if not (0.0 <= weight <= 1.0):
            raise ValueError(f"sol_geometric_weight must be in [0, 1], got {weight}")

    # -- Aliases read by the conversion helpers --
    @property
    def time_dim(self) -> int:
        """Alias for ``hidden_size``."""
        return self.hidden_size

    @property
    def clip_dim(self) -> int:
        """Alias for ``pooled_projection_dim``."""
        return self.pooled_projection_dim

    @property
    def num_heads(self) -> int:
        """Alias for ``num_attention_heads``."""
        return self.num_attention_heads

    @property
    def num_double_blocks(self) -> int:
        """Alias for ``num_double_layers``."""
        return self.num_double_layers

    @property
    def num_single_blocks(self) -> int:
        """Alias for ``num_single_layers``."""
        return self.num_single_layers

    def to_dict(self) -> Dict:
        """Return the fields as a JSON-serialisable dict (``axes_dims_rope`` as a list)."""
        payload = asdict(self)
        payload["axes_dims_rope"] = list(payload["axes_dims_rope"])
        return payload

    @classmethod
    def from_dict(cls, d: Dict) -> "TinyFluxConfig":
        """Build a config from ``d``, dropping unknown keys and ``_``-prefixed metadata."""
        allowed = set(cls.__dataclass_fields__)
        kwargs = {
            key: value
            for key, value in d.items()
            if key in allowed and not key.startswith("_")
        }
        return cls(**kwargs)

    def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
        """
        Compare a state dict against this config.

        Block counts are inferred as ``max(i) + 1`` over ``double_blocks.<i>.``
        and ``single_blocks.<i>.`` keys (0 if none). The width is read from
        ``img_embed.proj.weight`` when that key exists.

        Returns:
            Mismatch messages, in the order double blocks, single blocks,
            hidden size. The list is empty when everything agrees.
        """

        def block_count(prefix: str) -> int:
            found = [int(name.split(".")[1]) + 1 for name in state_dict if name.startswith(prefix)]
            return max([0] + found)

        issues = []
        expected_counts = (
            ("double_blocks", self.num_double_layers),
            ("single_blocks", self.num_single_layers),
        )
        for label, expected in expected_counts:
            have = block_count(label + ".")
            if have != expected:
                issues.append(f"{label}: checkpoint has {have}, config expects {expected}")

        proj_key = "img_embed.proj.weight"
        if proj_key in state_dict:
            width = state_dict[proj_key].shape[0]
            if width != self.hidden_size:
                issues.append(f"hidden_size: checkpoint has {width}, config expects {self.hidden_size}")
        return issues
def load_config(path: Union[str, Path]) -> TinyFluxConfig:
    """
    Load a TinyFluxConfig from a JSON file.

    Keys that are not config fields are ignored, as are ``_``-prefixed
    metadata keys such as ``_conversion_info``.

    Args:
        path: Path to the config JSON file.

    Returns:
        TinyFluxConfig instance, validated by its ``__post_init__``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the values violate the config's dimension constraints.
    """
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return TinyFluxConfig.from_dict(data)
def save_config(config: "TinyFluxConfig", path: Union[str, Path], conversion_info: Optional[Dict] = None):
    """
    Write a config to a JSON file (indent=2), creating parent directories as needed.

    Args:
        config: Object exposing ``to_dict()``, normally a TinyFluxConfig.
        path: Output path. Missing parent directories are created.
        conversion_info: Optional metadata stored under the ``_conversion_info``
            key. It is skipped when None or empty.
    """
    payload = config.to_dict()
    if conversion_info:
        payload["_conversion_info"] = conversion_info
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the file reads back identically on any platform.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
# Default configuration: a TinyFluxConfig built from the dataclass defaults.
# run(), convert_state_dict() and convert_checkpoint() fall back to this shared
# instance when no config is supplied.
# NOTE(review): the instance is mutable and shared by all callers; treat it as read-only.
DEFAULT_CONFIG = TinyFluxConfig()
# =============================================================================
# Checkpoint Analysis
# =============================================================================
@dataclass
class CheckpointInfo:
    """
    Summary produced by ``analyze_checkpoint``.

    Holds the detected version, which optional module families are present,
    the block counts, the total parameter count and the tensor dtype.
    """

    # Version label; stays "unknown" until analyze_checkpoint assigns one.
    version: str = field(default="unknown")
    # Presence flags for the key families analyze_checkpoint scans for.
    has_expert_predictor: bool = field(default=False)
    has_lune_predictor: bool = field(default=False)
    has_sol_prior: bool = field(default=False)
    has_t5_pool: bool = field(default=False)
    has_spatial_to_mod: bool = field(default=False)
    # Number of transformer blocks of each kind (highest index + 1).
    num_double_blocks: int = field(default=0)
    num_single_blocks: int = field(default=0)
    # Sum of element counts over every tensor.
    total_params: int = field(default=0)
    # Dtype name without the "torch." prefix.
    dtype: str = field(default="float32")
def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
    """
    Inspect a state dict and classify which TinyFlux-Deep version it is.

    Detection is purely key-based:
      * the prefixes ``expert_predictor.``, ``lune_predictor.``, ``sol_prior.``
        and ``t5_pool.`` set the matching ``has_*`` flag;
      * any key containing ``spatial_to_mod`` sets ``has_spatial_to_mod``;
      * ``double_blocks.<i>.`` / ``single_blocks.<i>.`` keys give the block
        counts as ``max(i) + 1``.

    Version rules:
      * lune + sol + t5 -> "v4.1"
      * lune + sol -> "v4.0"
      * expert_predictor -> "v3"
      * lune only -> "v3.5"
      * otherwise "v2_or_earlier"

    Args:
        state_dict: Model state dictionary.

    Returns:
        A populated CheckpointInfo. ``dtype`` is taken from the first tensor and
        keeps its "float32" default for an empty dict.
    """
    info = CheckpointInfo()
    info.total_params = sum(tensor.numel() for tensor in state_dict.values())
    if state_dict:
        first_tensor = next(iter(state_dict.values()))
        info.dtype = str(first_tensor.dtype).replace("torch.", "")

    flag_for_prefix = (
        ("expert_predictor.", "has_expert_predictor"),
        ("lune_predictor.", "has_lune_predictor"),
        ("sol_prior.", "has_sol_prior"),
        ("t5_pool.", "has_t5_pool"),
    )
    for key in state_dict:
        for prefix, flag in flag_for_prefix:
            if key.startswith(prefix):
                setattr(info, flag, True)
        if "spatial_to_mod" in key:
            info.has_spatial_to_mod = True
        if key.startswith("double_blocks."):
            info.num_double_blocks = max(info.num_double_blocks, int(key.split(".")[1]) + 1)
        if key.startswith("single_blocks."):
            info.num_single_blocks = max(info.num_single_blocks, int(key.split(".")[1]) + 1)

    lune, sol = info.has_lune_predictor, info.has_sol_prior
    if lune and sol:
        info.version = "v4.1" if info.has_t5_pool else "v4.0"
    elif info.has_expert_predictor:
        info.version = "v3"
    elif lune:
        info.version = "v3.5"
    else:
        info.version = "v2_or_earlier"
    return info
# =============================================================================
# Conversion Result
# =============================================================================
@dataclass
class ConversionResult:
    """
    Outcome of a conversion run: success flag, output file paths, version
    labels, parameter statistics and an error message on failure.
    """

    # True only when the main conversion and its outputs were written.
    success: bool
    # Output files; each stays None when that file was not produced.
    model_path: Optional[str] = field(default=None)
    ema_path: Optional[str] = field(default=None)
    ema_secondary_path: Optional[str] = field(default=None)
    config_path: Optional[str] = field(default=None)
    # Version labels of the input checkpoint and of the produced one.
    source_version: str = field(default="unknown")
    target_version: str = field(default="v4.1")
    # Parameter counts before/after conversion and their difference.
    source_params: int = field(default=0)
    target_params: int = field(default=0)
    params_added: int = field(default=0)
    # Failure description when success is False.
    error: Optional[str] = field(default=None)
# =============================================================================
# Colab Entry Point
# =============================================================================
def _resolve_config(config: Optional[Union[TinyFluxConfig, Dict, str, Path]]) -> TinyFluxConfig:
    """Turn any accepted ``config`` form (None, instance, dict, JSON path) into a TinyFluxConfig."""
    if config is None:
        return DEFAULT_CONFIG
    if isinstance(config, TinyFluxConfig):
        return config
    if isinstance(config, dict):
        return TinyFluxConfig.from_dict(config)
    if isinstance(config, (str, Path)):
        return load_config(config)
    raise TypeError(f"config must be TinyFluxConfig, dict, path, or None, got {type(config)}")


def run(
    step: int = 401434,
    name: str = "lailah",
    output_dir: str = "checkpoint_runs/v4_init",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    upload_repo: Optional[str] = "AbstractPhil/tiny-flux-deep",
    upload_subdir: str = "checkpoint_runs/v4_init",
    config: Optional[Union[TinyFluxConfig, Dict, str, Path]] = None,
    checkpoint_dir: str = "checkpoints",
):
    """
    One-liner for Colab: download, convert, save locally, then upload to the Hub.

    Args:
        step: Checkpoint step number to download.
        name: Model name prefix for output files.
        output_dir: Local output directory.
        repo_id: HuggingFace repo to download from.
        upload_repo: HuggingFace repo to upload to. Pass None or "" to skip
            the upload and keep the results local only.
        upload_subdir: Subdirectory in the upload repo. Surrounding "/" are
            ignored, and "" uploads to the repo root.
        config: Model config. It can be:
            - None (use DEFAULT_CONFIG)
            - a TinyFluxConfig instance
            - a dict of config values (unknown keys are ignored)
            - a path to a config JSON file
        checkpoint_dir: Subdirectory of ``repo_id`` holding the source checkpoints.

    Returns:
        The ConversionResult from convert_checkpoint. ``config_path`` is set
        when the conversion succeeded.

    Raises:
        TypeError: if ``config`` is not one of the accepted forms.

    Usage:
        from convert_v3_to_v4 import run
        run(401434)
        # With custom config
        run(401434, config={"hidden_size": 768, ...})
        run(401434, config="path/to/config.json")
        # Local only, no upload
        run(401434, upload_repo=None)
    """
    cfg = _resolve_config(config)

    print("TinyFlux-Deep v3 → v4.1 Converter")
    print("=" * 50)
    print(f"Config: hidden_size={cfg.hidden_size}, heads={cfg.num_attention_heads}")
    print(f" double_layers={cfg.num_double_layers}, single_layers={cfg.num_single_layers}")

    result = convert_checkpoint(
        step=step,
        model_name=name,
        output_dir=output_dir,
        repo_id=repo_id,
        checkpoint_dir=checkpoint_dir,
        config=cfg,
        verbose=True,
    )
    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        return result

    print("\n✅ Conversion complete!")
    print(f" Source: {result.source_version} ({result.source_params:,} params)")
    print(f" Target: {result.target_version} ({result.target_params:,} params)")
    print(f" Added: {result.params_added:,} params")

    # Save the config next to the weights, with provenance metadata.
    config_path = os.path.join(output_dir, f"{name}_{step}_v4_config.json")
    conversion_info = {
        "source_step": step,
        "source_repo": repo_id,
        "source_version": result.source_version,
        "target_version": result.target_version,
        "source_params": result.source_params,
        "target_params": result.target_params,
        "params_added": result.params_added,
        "converter_version": __version__,
        "files": {
            "model": os.path.basename(result.model_path) if result.model_path else None,
            "ema": os.path.basename(result.ema_path) if result.ema_path else None,
            "ema_secondary": os.path.basename(result.ema_secondary_path) if result.ema_secondary_path else None,
        },
    }
    save_config(cfg, config_path, conversion_info)
    result.config_path = config_path
    print(f"💾 Config: {config_path}")

    if not upload_repo:
        print("\n⏭️ Upload skipped (no upload_repo given)")
        return result

    # Imported lazily so local-only use does not require huggingface_hub.
    from huggingface_hub import HfApi

    api = HfApi()
    subdir = upload_subdir.strip("/")
    destination = f"{upload_repo}/{subdir}" if subdir else upload_repo
    print(f"\n📤 Uploading to {destination}/...")
    for local_path in (result.model_path, result.ema_path, result.ema_secondary_path, config_path):
        if not local_path or not os.path.exists(local_path):
            continue
        filename = os.path.basename(local_path)
        remote_path = f"{subdir}/{filename}" if subdir else filename
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=upload_repo,
        )
        print(f" ✓ {remote_path}")
    print(f"\n✅ Uploaded to {destination}/")
    return result
# =============================================================================
# Weight Initialization Functions
# =============================================================================
def to_logit(p: float) -> float:
    """
    Map a probability to its logit, the inverse of the sigmoid.

    ``p`` is first clamped to ``[1e-4, 1 - 1e-4]``, so 0 and 1 (and anything
    outside the unit interval) give a large but finite result.
    """
    eps = 1e-4
    clamped = max(eps, min(p, 1 - eps))
    odds = clamped / (1 - clamped)
    return math.log(odds)
def create_sol_prior_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Build fresh parameters for the SolAttentionPrior module (``sol_prior.*``).

    Initialization:
      * Linear weights of ``stat_predictor``, ``spatial_predictor`` and
        ``stat_to_temperature`` use Xavier-uniform with gain 0.1, and their
        biases are zero.
      * ``stat_to_temperature.2.bias`` is instead filled with 0.54.
      * ``spatial_to_qk_scale`` has zero weight and unit bias.
      * ``blend_gate`` holds ``to_logit(config.sol_geometric_weight)``.

    Random tensors are drawn in a fixed order, so the result is reproducible
    for a given torch RNG state.
    """
    hid = config.sol_hidden_dim
    cond_dim = config.time_dim + config.clip_dim
    heads = config.num_heads
    cells = config.sol_spatial_size * config.sol_spatial_size

    def small_xavier(rows: int, cols: int) -> torch.Tensor:
        weight = torch.empty(rows, cols, dtype=dtype)
        nn.init.xavier_uniform_(weight, gain=0.1)
        return weight

    # (submodule, Sequential index, out_features, in_features) in creation order.
    layers = [
        ("stat_predictor", 0, hid, cond_dim),
        ("stat_predictor", 2, hid, hid),
        ("stat_predictor", 4, 3, hid),
        ("spatial_predictor", 0, hid, cond_dim),
        ("spatial_predictor", 2, hid, hid),
        ("spatial_predictor", 4, cells, hid),
        ("stat_to_temperature", 0, hid // 2, 3),
        ("stat_to_temperature", 2, heads, hid // 2),
    ]
    params: Dict[str, torch.Tensor] = {}
    for module, idx, out_features, in_features in layers:
        stem = f"sol_prior.{module}.{idx}"
        params[f"{stem}.weight"] = small_xavier(out_features, in_features)
        params[f"{stem}.bias"] = torch.zeros(out_features, dtype=dtype)

    # The temperature head starts from a constant bias rather than zero.
    # NOTE(review): the intended meaning of 0.54 is not visible in this file.
    params["sol_prior.stat_to_temperature.2.bias"] = torch.full((heads,), 0.54, dtype=dtype)

    params["sol_prior.spatial_to_qk_scale.weight"] = torch.zeros(heads, 1, dtype=dtype)
    params["sol_prior.spatial_to_qk_scale.bias"] = torch.ones(heads, dtype=dtype)
    params["sol_prior.blend_gate"] = torch.tensor(to_logit(config.sol_geometric_weight), dtype=dtype)
    return params
def create_t5_pool_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Build fresh parameters for the T5 pooling MLP and the ``text_balance`` scalar.

    ``t5_pool.0`` maps ``joint_attention_dim -> hidden_size`` and ``t5_pool.2``
    maps ``hidden_size -> hidden_size``. Both weights are Xavier-uniform
    (default gain) and both biases are zero.

    ``text_balance`` is a 0-d tensor set to 0.0. The module docstring calls
    this a 50/50 init; presumably it is passed through a sigmoid (TODO confirm).
    """
    width = config.hidden_size
    shapes = (
        ("t5_pool.0", width, config.joint_attention_dim),
        ("t5_pool.2", width, width),
    )
    params = {}
    for stem, rows, cols in shapes:
        weight = torch.empty(rows, cols, dtype=dtype)
        nn.init.xavier_uniform_(weight)
        params[stem + ".weight"] = weight
        params[stem + ".bias"] = torch.zeros(rows, dtype=dtype)
    params["text_balance"] = torch.tensor(0.0, dtype=dtype)
    return params
def create_spatial_to_mod_init(
    num_heads: int = 4,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Return all-zero parameters for one ``spatial_to_mod`` layer.

    The weight has shape ``(num_heads, 1, 1, 1)`` (a 1x1 Conv2d from one
    channel to ``num_heads`` channels) and the bias has shape ``(num_heads,)``.
    """
    weight = torch.zeros((num_heads, 1, 1, 1), dtype=dtype)
    bias = torch.zeros((num_heads,), dtype=dtype)
    return dict(weight=weight, bias=bias)
def convert_state_dict(
    v3_state: Dict[str, torch.Tensor],
    config: Optional[TinyFluxConfig] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """
    Convert a v3 state dict to the v4.1 layout.

    Steps, in order:
      1. Rename ``expert_predictor.*`` keys to ``lune_predictor.*``.
      2. If ``lune_predictor.expert_gate`` is a single value within 0.3 of 0.5,
         treat it as a raw probability and replace it with its logit. The
         original tensor's shape and dtype are kept.
      3. Add SolAttentionPrior parameters if missing and ``config.use_sol_prior``.
      4. Add T5 pool parameters if missing and ``config.use_t5_vec``.
      5. Add zero ``attn.spatial_to_mod`` parameters to every double and single
         block (counts taken from ``config``) if missing and ``config.use_sol_prior``.

    Checkpoints already detected as v4.0/v4.1 are returned unchanged.

    Args:
        v3_state: Source state dict. Its tensors are reused, not copied, in the output.
        config: Target architecture; DEFAULT_CONFIG when None.

    Returns:
        ``(v4_state, report)``. ``report`` always has ``status``,
        ``source_version``, ``source_params``, ``target_params`` and
        ``params_added``. Converted checkpoints also list ``renamed``,
        ``initialized``, ``modified`` and ``warnings``.

    Raises:
        ValueError: if ``v3_state`` is empty.
    """
    cfg = config if config is not None else DEFAULT_CONFIG
    if not v3_state:
        raise ValueError("Cannot convert an empty state dict")

    v3_info = analyze_checkpoint(v3_state)
    if v3_info.version in ("v4.0", "v4.1"):
        # Pass through unchanged, but still report sizes so callers can print them.
        return v3_state, {
            'status': 'already_v4',
            'source_version': v3_info.version,
            'source_params': v3_info.total_params,
            'target_params': v3_info.total_params,
            'params_added': 0,
        }

    # Validate config matches checkpoint structure
    config_warnings = cfg.validate_checkpoint(v3_state)
    if config_warnings:
        print("⚠️ Config validation warnings:")
        for w in config_warnings:
            print(f" - {w}")

    # New parameters follow the checkpoint's floating-point dtype. Integer/bool
    # buffers are skipped because the init helpers need a float dtype.
    dtype = next(
        (t.dtype for t in v3_state.values() if t.is_floating_point()),
        torch.float32,
    )

    report = {
        'status': 'converted',
        'source_version': v3_info.version,
        'source_params': v3_info.total_params,
        'renamed': [],
        'initialized': [],
        'modified': [],
        'warnings': config_warnings,
    }

    # Step 1: Rename expert_predictor → lune_predictor (prefix only).
    v4_state: Dict[str, torch.Tensor] = {}
    for key, value in v3_state.items():
        if key.startswith('expert_predictor.'):
            new_key = key.replace('expert_predictor.', 'lune_predictor.', 1)
            v4_state[new_key] = value
            report['renamed'].append((key, new_key))
        else:
            v4_state[key] = value

    # Step 2: Fix expert_gate value (raw → logit space)
    gate_key = 'lune_predictor.expert_gate'
    gate = v4_state.get(gate_key)
    if gate is not None:
        if gate.numel() != 1:
            report['warnings'].append(
                f"{gate_key}: expected a single value, got shape {tuple(gate.shape)}; left unchanged"
            )
        else:
            old_val = gate.item()
            if abs(old_val - 0.5) < 0.3:  # Looks like raw probability, not logit
                new_val = to_logit(old_val)
                # full_like keeps the stored shape ((), (1,), ...) and dtype so the
                # key still loads into the model.
                v4_state[gate_key] = torch.full_like(gate, new_val)
                report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))

    # Step 3: Initialize SolAttentionPrior (if missing)
    if not v3_info.has_sol_prior and cfg.use_sol_prior:
        sol_init = create_sol_prior_init(cfg, dtype)
        v4_state.update(sol_init)
        report['initialized'].extend(sol_init.keys())

    # Step 4: Initialize T5 pool (if missing)
    if not v3_info.has_t5_pool and cfg.use_t5_vec:
        t5_init = create_t5_pool_init(cfg, dtype)
        v4_state.update(t5_init)
        report['initialized'].extend(t5_init.keys())

    # Step 5: Initialize spatial_to_mod in attention layers (if missing)
    if not v3_info.has_spatial_to_mod and cfg.use_sol_prior:
        spatial_init = create_spatial_to_mod_init(cfg.num_heads, dtype)
        block_groups = (
            ('double_blocks', cfg.num_double_blocks),
            ('single_blocks', cfg.num_single_blocks),
        )
        for group, count in block_groups:
            for i in range(count):
                prefix = f'{group}.{i}.attn.spatial_to_mod.'
                for suffix in ('weight', 'bias'):
                    # Clone so each block owns its tensor (no shared storage on save).
                    v4_state[prefix + suffix] = spatial_init[suffix].clone()
                    report['initialized'].append(prefix + suffix)

    report['target_params'] = sum(p.numel() for p in v4_state.values())
    report['params_added'] = report['target_params'] - report['source_params']
    return v4_state, report
# =============================================================================
# High-Level API
# =============================================================================
def download_from_hf(
    step: int,
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    local_dir: str = "./downloads",
    include_ema: bool = True,
) -> Tuple[str, Optional[str]]:
    """
    Download a checkpoint, and optionally its EMA, from the HuggingFace Hub.

    Files are fetched as ``{checkpoint_dir}/step_{step}.safetensors`` and
    ``{checkpoint_dir}/step_{step}_ema.safetensors``.

    Args:
        step: Step number to download.
        repo_id: HuggingFace repository ID.
        checkpoint_dir: Subdirectory in repo containing checkpoints.
        local_dir: Local directory to download to.
        include_ema: Whether to also download EMA weights.

    Returns:
        Tuple of (model_path, ema_path). ``ema_path`` is None when the EMA was
        not requested or could not be downloaded.

    Raises:
        Whatever ``hf_hub_download`` raises if the main model file cannot be
        fetched. A failed EMA download is not fatal: it emits a ``UserWarning``.
    """
    import warnings

    from huggingface_hub import hf_hub_download

    model_filename = f"{checkpoint_dir}/step_{step}.safetensors"
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=model_filename,
        local_dir=local_dir,
    )

    ema_path = None
    if include_ema:
        ema_filename = f"{checkpoint_dir}/step_{step}_ema.safetensors"
        try:
            ema_path = hf_hub_download(
                repo_id=repo_id,
                filename=ema_filename,
                local_dir=local_dir,
            )
        except Exception as exc:  # best-effort: EMA weights are optional
            warnings.warn(
                f"EMA weights {ema_filename!r} not downloaded from {repo_id}: {exc}",
                stacklevel=2,
            )
    return model_path, ema_path
def convert_checkpoint(
    step: Optional[int] = None,
    input_path: Optional[str] = None,
    ema_input_path: Optional[str] = None,
    output_dir: str = "checkpoint_runs/v4_init",
    model_name: str = "lailah",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    create_fresh_ema: bool = True,
    preserve_secondary_ema: bool = True,
    config: Optional[TinyFluxConfig] = None,
    verbose: bool = True,
) -> ConversionResult:
    """
    Convert a v3 checkpoint to v4.1 format and write the results to ``output_dir``.

    Either ``step`` (download from HF) or ``input_path`` (local file) must be
    provided; ``step`` takes precedence if both are given. For a local file,
    the step is parsed from a ``step_<n>`` substring of the path (0 if absent).

    Outputs, named ``{model_name}_{step}_v4_init*.safetensors``:
      * the converted model;
      * a "fresh" EMA, identical to the converted model (``create_fresh_ema``);
      * the old EMA converted to v4.1 (``preserve_secondary_ema``). Its newly
        added modules are copied from the converted model, so model and EMA
        start from the same values for them.

    Args:
        step: Step number to download from HuggingFace.
        input_path: Path to local v3 checkpoint.
        ema_input_path: Path to local v3 EMA checkpoint.
        output_dir: Directory to save converted checkpoints.
        model_name: Prefix for output filenames.
        repo_id: HuggingFace repository ID (if using step).
        checkpoint_dir: Subdirectory in repo (if using step).
        create_fresh_ema: Create a fresh EMA from converted weights.
        preserve_secondary_ema: Convert and preserve old EMA as secondary.
        config: TinyFluxConfig for model architecture (DEFAULT_CONFIG if None).
        verbose: Print progress messages.

    Returns:
        ConversionResult with paths and statistics. Errors are not raised:
        ``success`` is False and ``error`` holds the message. A failure to
        convert the old EMA only prints a warning (when verbose).
    """
    from safetensors.torch import load_file, save_file

    cfg = config or DEFAULT_CONFIG
    result = ConversionResult(success=False)

    try:
        # Get checkpoint paths
        if step is not None:
            if verbose:
                print(f"📥 Downloading step_{step} from {repo_id}...")
            model_path, ema_path = download_from_hf(
                step=step,
                repo_id=repo_id,
                checkpoint_dir=checkpoint_dir,
            )
            if verbose:
                print(f" ✓ Model: {model_path}")
                if ema_path:
                    print(f" ✓ EMA: {ema_path}")
        elif input_path is not None:
            model_path = input_path
            ema_path = ema_input_path
            match = re.search(r'step_(\d+)', model_path)
            step = int(match.group(1)) if match else 0
        else:
            result.error = "Must provide either step or input_path"
            return result

        # Load and convert
        if verbose:
            print("\n🔄 Converting to v4.1...")
        v3_state = load_file(model_path)
        v4_state, report = convert_state_dict(v3_state, cfg)

        result.source_version = report['source_version']
        # An already-v4 input is passed through unchanged, so report its real version.
        if report.get('status') == 'already_v4':
            result.target_version = report['source_version']
        else:
            result.target_version = "v4.1"
        result.source_params = report.get('source_params', 0)
        result.target_params = report.get('target_params', 0)
        result.params_added = report.get('params_added', 0)

        if verbose:
            print(f" Source: {result.source_version} ({result.source_params:,} params)")
            print(f" Target: {result.target_version} ({result.target_params:,} params)")
            print(f" Added: {result.params_added:,} params")

        # Save outputs
        os.makedirs(output_dir, exist_ok=True)

        # Main model
        model_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init.safetensors")
        save_file(v4_state, model_out)
        result.model_path = model_out
        if verbose:
            print(f"\n💾 Model: {model_out}")

        # Fresh EMA: starts as an exact copy of the converted model.
        if create_fresh_ema:
            ema_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema.safetensors")
            save_file(v4_state, ema_out)
            result.ema_path = ema_out
            if verbose:
                print(f"💾 EMA (fresh): {ema_out}")

        # Secondary EMA
        if preserve_secondary_ema and ema_path and os.path.exists(ema_path):
            if verbose:
                print("\n🔄 Converting old EMA...")
            try:
                old_ema_state = load_file(ema_path)
                old_ema_v4, ema_report = convert_state_dict(old_ema_state, cfg)
                # Each convert_state_dict call draws its own random init for new
                # modules. Copy the model's tensors so EMA and model agree on them.
                for key in ema_report.get('initialized', []):
                    src = v4_state.get(key)
                    dst = old_ema_v4.get(key)
                    if src is not None and dst is not None and src.shape == dst.shape:
                        old_ema_v4[key] = src.to(dtype=dst.dtype).clone()
                ema_secondary_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors")
                save_file(old_ema_v4, ema_secondary_out)
                result.ema_secondary_path = ema_secondary_out
                if verbose:
                    print(f"💾 EMA (secondary): {ema_secondary_out}")
            except Exception as e:
                if verbose:
                    print(f"⚠ Failed to convert old EMA: {e}")

        result.success = True

    except Exception as e:
        result.error = str(e)
        if verbose:
            print(f"❌ Error: {e}")

    return result
# =============================================================================
# CLI Interface
# =============================================================================
def create_parser():
    """
    Build the argparse parser for the converter CLI.

    Returns:
        argparse.ArgumentParser with these option groups:
          * input: ``--step``, ``--input``, ``--ema-input``
          * HuggingFace: ``--repo``, ``--checkpoint-dir``
          * output: ``--output-dir``, ``--name``
          * model: ``--config``
          * conversion: ``--no-fresh-ema``, ``--no-secondary-ema``,
            ``--analyze-only``, ``--quiet``
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert TinyFlux-Deep v3 checkpoints to v4 format',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_v3_to_v4.py --step 401434
  python convert_v3_to_v4.py --input model_v3.safetensors
  python convert_v3_to_v4.py --step 401434 --analyze-only
  python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
  python convert_v3_to_v4.py --step 401434 --config my_config.json
"""
    )

    # Input
    input_group = parser.add_argument_group('Input (one required)')
    input_group.add_argument('--step', type=int, help='Step number to download from HuggingFace')
    input_group.add_argument('--input', '-i', dest='input_path', help='Path to local v3 checkpoint')
    input_group.add_argument('--ema-input', dest='ema_input_path', help='Path to local v3 EMA checkpoint')

    # HuggingFace
    hf_group = parser.add_argument_group('HuggingFace options')
    hf_group.add_argument('--repo', default='AbstractPhil/tiny-flux-deep', help='HuggingFace repo ID')
    hf_group.add_argument('--checkpoint-dir', default='checkpoints', help='Subdirectory in repo')

    # Output
    output_group = parser.add_argument_group('Output options')
    output_group.add_argument('--output-dir', '-o', default='checkpoint_runs/v4_init', help='Output directory')
    output_group.add_argument('--name', default='lailah', help='Model name prefix')

    # Model architecture (documented in the module docstring as `--config my_config.json`)
    model_group = parser.add_argument_group('Model options')
    model_group.add_argument(
        '--config',
        dest='config',
        default=None,
        help='Path to a TinyFluxConfig JSON file (default: built-in config)',
    )

    # Conversion
    conv_group = parser.add_argument_group('Conversion options')
    conv_group.add_argument('--no-fresh-ema', action='store_true', help='Do not create fresh EMA')
    conv_group.add_argument('--no-secondary-ema', action='store_true', help='Do not preserve old EMA')
    conv_group.add_argument('--analyze-only', action='store_true', help='Only analyze, do not convert')
    conv_group.add_argument('--quiet', '-q', action='store_true', help='Suppress progress messages')

    return parser
def cli_main():
    """
    Command-line entry point.

    Parses arguments with create_parser(). With ``--analyze-only`` it prints a
    summary of the checkpoint; otherwise it converts the checkpoint via
    convert_checkpoint(), using ``--config`` when given. Exits with status 1
    when the conversion fails.
    """
    import sys

    parser = create_parser()
    args = parser.parse_args()

    # `is None` so that an explicit `--step 0` is not treated as missing.
    if args.step is None and not args.input_path:
        parser.error("Must specify either --step or --input")

    # Analyze only
    if args.analyze_only:
        from safetensors.torch import load_file

        if args.step is not None:
            model_path, _ = download_from_hf(
                step=args.step,
                repo_id=args.repo,
                checkpoint_dir=args.checkpoint_dir,
            )
        else:
            model_path = args.input_path

        state = load_file(model_path)
        info = analyze_checkpoint(state)

        print(f"\nCheckpoint: {model_path}")
        print(f" Version: {info.version}")
        print(f" Total params: {info.total_params:,}")
        print(f" Double blocks: {info.num_double_blocks}")
        print(f" Single blocks: {info.num_single_blocks}")
        print(f" Has expert_predictor: {info.has_expert_predictor}")
        print(f" Has lune_predictor: {info.has_lune_predictor}")
        print(f" Has sol_prior: {info.has_sol_prior}")
        print(f" Has t5_pool: {info.has_t5_pool}")
        print(f" Has spatial_to_mod: {info.has_spatial_to_mod}")
        return

    # Optional architecture config. getattr keeps this working with parsers
    # that do not define --config.
    config_path = getattr(args, "config", None)
    cfg = None
    if config_path:
        try:
            cfg = load_config(config_path)
        except (OSError, ValueError) as exc:
            parser.error(f"Could not load config {config_path}: {exc}")

    # Convert
    result = convert_checkpoint(
        step=args.step,
        input_path=args.input_path,
        ema_input_path=args.ema_input_path,
        output_dir=args.output_dir,
        model_name=args.name,
        repo_id=args.repo,
        checkpoint_dir=args.checkpoint_dir,
        create_fresh_ema=not args.no_fresh_ema,
        preserve_secondary_ema=not args.no_secondary_ema,
        config=cfg,
        verbose=not args.quiet,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        # sys.exit rather than the site-provided `exit`, which may be absent.
        sys.exit(1)

    if not args.quiet:
        print("\n" + "=" * 60)
        print("✅ Conversion complete!")
        print("=" * 60)
        print("\nOutput files:")
        if result.model_path:
            print(f" Model: {result.model_path}")
        if result.ema_path:
            print(f" EMA: {result.ema_path}")
        if result.ema_secondary_path:
            print(f" EMA (secondary): {result.ema_secondary_path}")
# Run the CLI only when executed as a script, not when imported (e.g. `from convert_v3_to_v4 import run`).
if __name__ == '__main__':
    cli_main()