Update checkpoints according to diffusers integration

#10
by guiyrt - opened

As part of the integration into diffusers (see PR), we can simplify things by changing the IP-Adapter state dict keys here instead of adding extra conversion code. The "ip-adapter.bin" file is updated to conform to the integration, and I also uploaded the equivalent parameters in safetensors format. The PR is nearly finished, so it shouldn't be long until the integration is merged. The new checkpoints were obtained with the following code:

import torch
from safetensors.torch import save_file
from diffusers.utils import _get_model_file
from diffusers.models.modeling_utils import load_state_dict

# Renames for the image-projection sub-modules, expressed relative to the
# ``layers.<idx>.`` prefix as (original path, diffusers path). Matching is on
# whole dotted components, so e.g. "1.1" does not match the start of "1.10".
IMAGE_PROJ_LAYER_RENAMES = (
    ("0.norm1", "ln0"),
    ("0.norm2", "ln1"),
    ("0.to_q", "attn.to_q"),
    ("0.to_kv", "attn.to_kv"),
    ("0.to_out", "attn.to_out.0"),
    ("1.0", "adaln_norm"),
    ("1.1", "ff.net.0.proj"),
    ("1.3", "ff.net.2"),
    ("2.1", "adaln_proj"),
)


def convert_image_proj_key(key: str) -> str:
    """Return the diffusers name for a single ``image_proj`` state-dict key.

    Keys of the form ``layers.<idx>.<path>`` have the leading part of
    ``<path>`` rewritten according to ``IMAGE_PROJ_LAYER_RENAMES`` (first
    match wins). Every other key is returned unchanged.

    The rename is anchored right after ``layers.<idx>.`` and only applies
    when the matched path ends at a ``.`` boundary (or the end of the key).
    A plain substring replace would mangle e.g. ``layers.1.1.10.weight``
    into ``layers.1.ff.net.0.proj0.weight``.
    """
    if not key.startswith("layers."):
        return key
    parts = key.split(".", 2)
    if len(parts) < 3:
        # "layers.<idx>" with nothing after it: nothing to rename.
        return key
    _, idx, rest = parts
    for old, new in IMAGE_PROJ_LAYER_RENAMES:
        if rest == old or rest.startswith(old + "."):
            return f"layers.{idx}.{new}{rest[len(old):]}"
    return key


def convert_state_dict(state_dict):
    """Convert an IP-Adapter checkpoint dict to the diffusers key layout.

    Args:
        state_dict: Mapping with an ``"ip_adapter"`` group, which is kept
            as is, and an ``"image_proj"`` group, whose keys are renamed
            with ``convert_image_proj_key``.

    Returns:
        A new ``{"ip_adapter": ..., "image_proj": ...}`` dict. The
        ``"ip_adapter"`` entry is the same object as in the input, and
        tensor values are not copied.

    Raises:
        ValueError: If two ``image_proj`` keys would map to the same name.
            Without this check, one tensor would silently overwrite the
            other.
    """
    image_proj = {}
    for key, value in state_dict["image_proj"].items():
        new_key = convert_image_proj_key(key)
        if new_key in image_proj:
            raise ValueError(
                f"Duplicate converted image_proj key {new_key!r} (from {key!r})"
            )
        image_proj[new_key] = value
    return {"ip_adapter": state_dict["ip_adapter"], "image_proj": image_proj}


def flatten_for_safetensors(converted):
    """Flatten ``{group: {key: tensor}}`` into ``{"group.key": tensor}``.

    ``save_file`` takes a single flat name->tensor mapping, so the group name
    becomes a dotted prefix (``ip_adapter.*``, ``image_proj.*``). The group
    order of ``converted`` is preserved.
    """
    return {
        f"{group}.{key}": value
        for group, params in converted.items()
        for key, value in params.items()
    }


def main():
    """Download the original checkpoint and write both converted formats.

    Writes ``ip-adapter.safetensors`` (flat, prefixed keys) and
    ``ip-adapter.bin`` (nested dict saved with ``torch.save``) to the
    current working directory.
    """
    # Get weights from the hub
    model_file = _get_model_file(
        "InstantX/SD3.5-Large-IP-Adapter",
        weights_name="ip-adapter.bin",
    )
    state_dict = load_state_dict(model_file)

    # ip_adapter stays the same; image_proj keys are renamed for diffusers.
    # Both output formats share this one conversion, so they cannot diverge.
    converted = convert_state_dict(state_dict)

    # Save safetensors
    save_file(flatten_for_safetensors(converted), "ip-adapter.safetensors")

    # Save torch pickle
    torch.save(converted, "ip-adapter.bin")


if __name__ == "__main__":
    main()
guiyrt changed pull request title from "Upload 2 files" to "Update checkpoints according to diffusers integration"
Ready to merge
This branch is ready to get merged automatically.

Sign up or log in to comment