Spaces:
Running
Running
| import argparse | |
| from pathlib import Path | |
| from typing import Dict | |
| import safetensors.torch | |
| import torch | |
| import json | |
| import shutil | |
def load_text_encoder(index_path: Path) -> Dict:
    """Load and merge every tensor file referenced by a JSON index.

    The index is a JSON document whose optional ``"weight_map"`` entry maps
    names to part-file names. Each distinct part file is resolved relative to
    the directory that contains the index and loaded onto the CPU.

    Args:
        index_path: Path to the JSON index file.

    Returns:
        One dict holding the tensors of all referenced part files. An index
        without a ``"weight_map"`` entry yields an empty dict.
    """
    with open(index_path, "r") as index_file:
        index = json.load(index_file)

    parts_dir = index_path.parent
    # Several entries can point at the same part file, so load each file once.
    part_names = set(index.get("weight_map", {}).values())

    merged = {}
    for part_name in part_names:
        part_tensors = safetensors.torch.load_file(parts_dir / part_name, device="cpu")
        merged.update(part_tensors)
    return merged
def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace UNet keys under ``model.diffusion_model.``.

    Args:
        unet: Mapping of parameter names to values.
        add_prefix: When true, a new dict is built with every key prefixed;
            otherwise ``unet`` itself is handed back untouched.

    Returns:
        The prefixed copy, or the original mapping object when no prefix is
        requested.
    """
    if not add_prefix:
        return unet

    prefix = "model.diffusion_model."
    prefixed = {}
    for name, tensor in unet.items():
        prefixed[prefix + name] = tensor
    return prefixed
def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Build a flat VAE state dict from a checkpoint directory.

    Reads ``autoencoder.pth`` from ``vae_path`` and, if the directory also
    contains ``per_channel_statistics.json``, adds one tensor per statistics
    column under the name ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory holding ``autoencoder.pth`` and, optionally,
            ``per_channel_statistics.json`` (an object with ``"columns"`` and
            row-wise ``"data"`` entries).
        add_prefix: When true, every output key is prefixed with ``vae.``.

    Returns:
        Dict of the checkpoint entries followed by the statistics tensors.
        If a statistics key collides with a checkpoint key, the statistics
        tensor wins.
    """
    prefix = "vae." if add_prefix else ""

    # Deserialize onto the CPU: without map_location, a checkpoint saved from
    # a CUDA device cannot be loaded on a machine that has no GPU.
    state_dict = torch.load(
        vae_path / "autoencoder.pth", map_location="cpu", weights_only=True
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, "r") as f:
            stats = json.load(f)
        # "data" is a list of rows; transpose it to get one sequence per column.
        per_column_values = zip(*stats["data"])
        for column, values in zip(stats["columns"], per_column_values):
            result[f"{prefix}per_channel_statistics.{column}"] = torch.tensor(values)

    return result
def convert_encoder(encoder: Dict) -> Dict:
    """Return a copy of ``encoder`` with every key namespaced for the text encoder.

    Args:
        encoder: Mapping of parameter names to values.

    Returns:
        A new dict whose keys are the original keys prefixed with
        ``text_encoders.t5xxl.transformer.``; values are passed through as-is.
    """
    prefix = "text_encoders.t5xxl.transformer."
    converted = {}
    for name, tensor in encoder.items():
        converted[prefix + name] = tensor
    return converted
def save_config(config_src: str, config_dst: str):
    """Copy the config file at ``config_src`` to ``config_dst``.

    This is a thin wrapper around ``shutil.copy``; nothing is returned.

    Args:
        config_src: Path of the existing config file.
        config_dst: Path the file is copied to.
    """
    shutil.copy(src=config_src, dst=config_dst)
def load_vae_config(vae_path: Path) -> str:
    """Locate the ``config.json`` that sits inside a VAE directory.

    Despite the name, the file is not read; only its path is returned.

    Args:
        vae_path: Directory expected to contain ``config.json``.

    Returns:
        The config file's path as a string.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``vae_path``.
    """
    config_path = vae_path / "config.json"
    if config_path.exists():
        return str(config_path)
    raise FileNotFoundError(f"VAE config file {config_path} not found.")
def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: str = None,
    scheduler_config_path: str = None,
) -> None:
    """Convert a UNet checkpoint and a VAE directory to safetensors output.

    Args:
        unet_path: Path to the UNet checkpoint readable by ``torch.load``.
        vae_path: Directory with the VAE checkpoint. In ``"separate"`` mode it
            must also contain ``config.json``.
        out_path: Output file in ``"single"`` mode; output directory in
            ``"separate"`` mode.
        mode: ``"single"`` writes one safetensors file containing the
            prefixed UNet and VAE weights. ``"separate"`` writes unprefixed
            weights to ``unet/`` and ``vae/`` sub-directories of ``out_path``
            (a ``scheduler/`` sub-directory is created as well) and copies
            the available config files next to them.
        unet_config_path: Optional UNet config, copied in ``"separate"`` mode.
        scheduler_config_path: Optional scheduler config, copied in
            ``"separate"`` mode.

    Raises:
        ValueError: If ``mode`` is neither ``"single"`` nor ``"separate"``.
        FileNotFoundError: In ``"separate"`` mode, if the VAE directory has no
            ``config.json``.
    """
    # Reject unknown modes up front rather than loading every weight and then
    # silently writing nothing.
    if mode not in ("single", "separate"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'single' or 'separate'."
        )

    single = mode == "single"
    vae_src_dir = Path(vae_path)

    # The VAE config is only copied in "separate" mode, so only require it
    # there -- and check before the (slow) weight loading so a missing file
    # fails fast.
    vae_config_path = None if single else load_vae_config(vae_src_dir)

    # Deserialize onto the CPU so checkpoints saved from a CUDA device can be
    # converted on machines without a GPU.
    unet = convert_unet(
        torch.load(unet_path, map_location="cpu", weights_only=True),
        add_prefix=single,
    )
    vae = convert_vae(vae_src_dir, add_prefix=single)

    if single:
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # "separate" mode: one sub-directory each for unet, vae and scheduler.
    out_root = Path(out_path)
    unet_dir = out_root / "unet"
    vae_dir = out_root / "vae"
    scheduler_dir = out_root / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Weight files are named <component>_diffusion_pytorch_model.safetensors.
    safetensors.torch.save_file(
        unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
    )
    safetensors.torch.save_file(
        vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
    )

    # Copy whichever config files are available next to the weights.
    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a keyword
    # argument of main().
    cli = argparse.ArgumentParser()

    # Input checkpoint locations and the output destination.
    cli.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
    cli.add_argument("--vae_path", "-v", type=str, default="vae/")
    cli.add_argument("--out_path", "-o", type=str, default="xora.safetensors")

    # Output layout.
    mode_help = "Choose 'single' for the original behavior, 'separate' to save unet and vae separately."
    cli.add_argument(
        "--mode", "-m", type=str, choices=["single", "separate"], default="single", help=mode_help
    )

    # Optional config files, only consulted in 'separate' mode.
    unet_config_help = "Path to the UNet config file (for separate mode)"
    cli.add_argument("--unet_config_path", type=str, help=unet_config_help)
    scheduler_config_help = "Path to the Scheduler config file (for separate mode)"
    cli.add_argument("--scheduler_config_path", type=str, help=scheduler_config_help)

    cli_args = cli.parse_args()
    main(**vars(cli_args))