# Spaces: Running on A100
import argparse
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Optional

import safetensors.torch
import torch
def load_text_encoder(index_path: Path) -> Dict:
    """Load every tensor referenced by a sharded safetensors JSON index.

    The index's ``"weight_map"`` maps tensor names to shard file names; shard
    paths are resolved relative to the directory containing the index file.

    Args:
        index_path: Path to the JSON index file.

    Returns:
        A dict mapping tensor name to tensor, loaded on CPU. Empty when the
        index has no ``"weight_map"`` entry (or an empty one).
    """
    with open(index_path, 'r', encoding='utf-8') as f:
        index: Dict = json.load(f)
    loaded_tensors = {}
    # Many tensors share one shard, so de-duplicate the shard names. Sorting
    # makes the load order (and thus dict order and which shard wins on a
    # duplicate tensor name) reproducible: str set order depends on hash seed.
    for part_file in sorted(set(index.get("weight_map", {}).values())):
        loaded_tensors.update(
            safetensors.torch.load_file(index_path.parent / part_file, device='cpu')
        )
    return loaded_tensors
def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace UNet weights under ``model.diffusion_model.``.

    Args:
        unet: The UNet state dict.
        add_prefix: When true, build a new dict whose keys carry the
            ``model.diffusion_model.`` prefix. When false, hand back ``unet``
            itself (the same object, not a copy).

    Returns:
        The prefixed copy, or the untouched input dict.
    """
    if not add_prefix:
        return unet
    namespace = "model.diffusion_model."
    return {namespace + name: tensor for name, tensor in unet.items()}
def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load a VAE checkpoint directory into a flat, optionally prefixed state dict.

    Reads ``autoencoder.pth`` from ``vae_path``. If
    ``per_channel_statistics.json`` is also present, each of its columns is
    added as a tensor named ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing ``autoencoder.pth`` and optionally
            ``per_channel_statistics.json``.
        add_prefix: When true, every resulting key is prefixed with ``vae.``.

    Returns:
        Dict of tensor name -> tensor, with checkpoint tensors mapped to CPU.
        Statistics entries are inserted after the checkpoint entries and win
        on a name clash.

    Raises:
        FileNotFoundError: If ``autoencoder.pth`` is missing.
    """
    prefix = "vae." if add_prefix else ""
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts.
    state_dict = torch.load(vae_path / "autoencoder.pth", map_location='cpu', weights_only=True)
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # "data" is a list of rows and "columns" names each position in a row.
        # Transpose the rows so that each column becomes one tensor.
        column_values = zip(*data["data"])
        for col, vals in zip(data["columns"], column_values):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)
    return result
def convert_encoder(encoder: Dict) -> Dict:
    """Re-key text-encoder weights under ``text_encoders.t5xxl.transformer.``.

    Args:
        encoder: The text-encoder state dict.

    Returns:
        A new dict with every key prefixed. Values are shared with the input,
        not copied.
    """
    namespace = "text_encoders.t5xxl.transformer."
    renamed = {namespace + name: tensor for name, tensor in encoder.items()}
    return renamed
def save_config(config_src: str, config_dst: str):
    """Copy a config file verbatim from ``config_src`` to ``config_dst``.

    Delegates to ``shutil.copy``, whose return value is discarded, so this
    function returns None.
    """
    shutil.copy(src=config_src, dst=config_dst)
def load_vae_config(vae_path: Path) -> str:
    """Return the path of ``config.json`` inside ``vae_path`` as a string.

    Raises:
        FileNotFoundError: If ``vae_path / "config.json"`` does not exist.
    """
    candidate = vae_path / "config.json"
    if candidate.exists():
        return str(candidate)
    raise FileNotFoundError(f"VAE config file {candidate} not found.")
def main(unet_path: str, vae_path: str, out_path: str, mode: str,
         unet_config_path: Optional[str] = None,
         scheduler_config_path: Optional[str] = None) -> None:
    """Convert UNet and VAE checkpoints into safetensors output.

    Args:
        unet_path: UNet checkpoint file, loaded with ``torch.load``.
        vae_path: Directory containing ``autoencoder.pth``, plus ``config.json``
            for ``'separate'`` mode.
        out_path: In ``'single'`` mode, the output ``.safetensors`` file. In
            ``'separate'`` mode, the root directory that receives ``unet/``,
            ``vae/`` and ``scheduler/`` subdirectories.
        mode: ``'single'`` merges UNet and VAE into one file, with keys
            prefixed ``model.diffusion_model.`` / ``vae.``. ``'separate'``
            writes un-prefixed weights to ``diffusion_pytorch_model.safetensors``
            in each component directory.
        unet_config_path: Optional UNet config, copied to ``unet/config.json``
            (separate mode only; ignored in single mode).
        scheduler_config_path: Optional scheduler config, copied to
            ``scheduler/scheduler_config.json`` (separate mode only).

    Raises:
        ValueError: If ``mode`` is neither ``'single'`` nor ``'separate'``.
        FileNotFoundError: In separate mode, if ``vae_path/config.json`` is missing.
    """
    # Validate before loading any (potentially large) checkpoint.
    if mode not in ('single', 'separate'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'separate'.")
    # Key prefixes are only needed when UNet and VAE share a single file.
    add_prefix = mode == 'single'
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts.
    unet = convert_unet(torch.load(unet_path, map_location='cpu', weights_only=True),
                        add_prefix=add_prefix)
    vae = convert_vae(Path(vae_path), add_prefix=add_prefix)

    if mode == 'single':
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # Only separate mode copies the VAE config, so only it requires one.
    # Resolve it before creating any output so a missing file fails fast.
    vae_config_path = load_vae_config(Path(vae_path))

    root = Path(out_path)
    unet_dir = root / 'unet'
    vae_dir = root / 'vae'
    scheduler_dir = root / 'scheduler'
    # The scheduler directory is created even when no scheduler config is given.
    for component_dir in (unet_dir, vae_dir, scheduler_dir):
        component_dir.mkdir(parents=True, exist_ok=True)

    # Each component's weights use the file name diffusion_pytorch_model.safetensors.
    safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
    safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')

    if unet_config_path:
        save_config(unet_config_path, unet_dir / 'config.json')
    # load_vae_config either returns a path or raises, so no truthiness check is needed.
    save_config(vae_config_path, vae_dir / 'config.json')
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')
if __name__ == '__main__':
    # Command-line entry point: each option maps 1:1 onto a main() parameter.
    parser = argparse.ArgumentParser(
        description="Convert UNet and VAE checkpoints to safetensors.")
    parser.add_argument('--unet_path', '-u', type=str, default='unet/ema-002.pt',
                        help="Path to the UNet checkpoint (loaded with torch.load)")
    parser.add_argument('--vae_path', '-v', type=str, default='vae/',
                        help="VAE directory (autoencoder.pth, config.json, "
                             "optional per_channel_statistics.json)")
    parser.add_argument('--out_path', '-o', type=str, default='xora.safetensors',
                        help="Output .safetensors file in single mode; "
                             "output root directory in separate mode")
    parser.add_argument('--mode', '-m', type=str, choices=['single', 'separate'], default='single',
                        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.")
    parser.add_argument('--unet_config_path', type=str, help="Path to the UNet config file (for separate mode)")
    parser.add_argument('--scheduler_config_path', type=str,
                        help="Path to the Scheduler config file (for separate mode)")
    args = parser.parse_args()
    main(**vars(args))