Implementation of diff. merge

#2
by alfredplpl - opened

Great work!

Following what you said, I implemented the diff. merge:

import torch
from diffusers import StableDiffusionXLPipeline
from safetensors.torch import load_file, save_file
from huggingface_hub import hf_hub_download

# Scale factor applied to the (emi - sdxl) difference in the merge loop below.
w=1.3

# Mapping from SSD-1B key prefixes to the corresponding SDXL key prefixes
# (iterated as ssd_key -> sdxl_key in the merge loop). Left empty here as a
# placeholder; the actual entries come from the referenced gist.
ssd2sdxl={
# your gist
}

# Download the UNet weight file of each model from the Hugging Face Hub and
# load it with safetensors. Each result is used below as a mapping from key
# to tensor. `file_name` is reused for every download.

# Emi UNet weights.
file_name=hf_hub_download(repo_id="alfredplpl/emi", subfolder="unet", filename="diffusion_pytorch_model.safetensors")
emi = load_file(file_name)

# SDXL base 1.0 UNet weights.
file_name=hf_hub_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", filename="diffusion_pytorch_model.safetensors")
sdxl = load_file(file_name)

# SSD-1B UNet weights; note this one is the ".fp16" file, unlike the two above.
file_name=hf_hub_download(repo_id="segmind/SSD-1B", subfolder="unet", filename="diffusion_pytorch_model.fp16.safetensors")
ssd = load_file(file_name)

# `diff` is the same dict object as `sdxl` (no copy is made), so the SDXL
# tensors are replaced entry by entry with the difference emi - sdxl.
diff = sdxl
for name in diff:
    diff[name] = emi[name] - diff[name]

# For each mapped prefix, overwrite the difference of the .to_q / .to_k /
# .to_v / .to_out.0 weights with: SSD weight + w * difference.
# Every other entry of `diff` keeps the plain difference computed above.
for ssd_prefix, sdxl_prefix in ssd2sdxl.items():
    for suffix in (".to_q", ".to_k", ".to_v", ".to_out.0"):
        source_key = ssd_prefix + suffix + ".weight"
        target_key = sdxl_prefix + suffix + ".weight"
        diff[target_key] = ssd[source_key] + w * diff[target_key]


# Load the Emi pipeline in float32 and write it out to the target directory,
# presumably so that the directory contains all pipeline files (TODO confirm).
emi_mid=StableDiffusionXLPipeline.from_pretrained("alfredplpl/emi",torch_dtype=torch.float32)
emi_mid.save_pretrained('/mnt/NVM/demi_mid' )
# Then overwrite the UNet weight file just written with the tensors in `diff`.
# This must run after save_pretrained, otherwise it would be overwritten.
save_file(diff, '/mnt/NVM/demi_mid/unet/diffusion_pytorch_model.safetensors' )

Is it right?
Thanks in advance.

The merge code is below.
Note that a transformer block contains more than just the q, k, v, and o weights.

import json
# Load the SSD-1B -> SDXL key replacement table. The encoding is passed
# explicitly because JSON text is UTF-8, while open() would otherwise decode
# with the platform's default encoding.
with open("ssd2sdxl.json", "r", encoding="utf-8") as f:
    ssd2sdxl = json.load(f)

# Receives the merged tensors, one entry per key of `ssd`.
new_dic = {}
def replace_string_using_dict(s, replacement_dict):
    """Return `s` with each key of `replacement_dict` replaced by its value.

    The replacements are applied one after another, in the dict's iteration
    order, each on the result of the previous one. Text produced by an
    earlier replacement can therefore be matched again by a later key.
    """
    result = s
    for old in replacement_dict:
        result = result.replace(old, replacement_dict[old])
    return result

# For every SSD tensor, translate its key to the SDXL naming and add 1.3x the
# difference (neko - sdxl) found under that translated key.
# NOTE(review): `neko` is not defined in this snippet; presumably it holds the
# weights of the SDXL-based model being merged -- confirm.
for ssd_name in ssd:
    sdxl_name = replace_string_using_dict(ssd_name, ssd2sdxl)
    delta = neko[sdxl_name] - sdxl[sdxl_name]
    new_dic[ssd_name] = ssd[ssd_name] + 1.3 * delta

May I ask how good the merges are without any training?

Amazing!

I could merge the models!

# For every SSD tensor, translate its key to the SDXL naming and add 2.0x the
# difference (emi - sdxl) found under that translated key.
for ssd_name, ssd_weight in ssd.items():
    sdxl_name = replace_string_using_dict(ssd_name, ssd2sdxl)
    new_dic[ssd_name] = ssd_weight + 2.0 * (emi[sdxl_name] - sdxl[sdxl_name])

# Load the SSD-1B pipeline (fp16 variant, safetensors) and write it out to the
# target directory, presumably so that the directory contains all pipeline
# files in the SSD-1B layout (TODO confirm).
emi_mid=StableDiffusionXLPipeline.from_pretrained("segmind/SSD-1B", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
emi_mid.save_pretrained('/mnt/NVM/demi_mid' )
# Then overwrite the UNet weight file just written with the merged tensors.
# This must run after save_pretrained, otherwise it would be overwritten.
save_file(new_dic, '/mnt/NVM/demi_mid/unet/diffusion_pytorch_model.safetensors'  )

It is very fast.

girl.png

I will continue by distilling it or fine-tuning it.

@Icar Your turn

It is not all that bad apparently with just a direct merge. Finetuning / distillation should be faster with this model.

FYI: I succeeded in fine-tuning it.

girl.png

I will close this discussion.

alfredplpl changed discussion status to closed

Sign up or log in to comment