Hello! Thanks for trying out our model and fine-tuning/distilling on it! We plan to release our code once our technical paper is out. Could you tell me more about exactly what your loss strategy was and how the outputs turned out?
Also, nice idea with the cosine similarities! May I ask what you compared and how similar they were?

I attempted a merge based on the following formula.
NewModel = SSD-1B + w * ( FINETUNEDSDXL - SDXL)
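
For readers who want to try the formula, here is a minimal sketch of how it could be applied to state dicts. It is not my exact script: `merge_task_vector` and `key_map` are hypothetical names, and it assumes every mapped pair of weights has the same shape. The values can be torch tensors; plain floats are used below only to keep the example self-contained.

```python
def merge_task_vector(student, tuned, base, key_map=None, w=1.3):
    """Return student + w * (tuned - base), key by key.

    student: SSD-1B state dict
    tuned:   fine-tuned SDXL state dict
    base:    original SDXL state dict
    key_map: hypothetical mapping from a student key to the matching SDXL key;
             keys missing from the map are assumed to have the same name.
    """
    key_map = key_map or {}
    merged = dict(student)
    for s_key, s_value in student.items():
        b_key = key_map.get(s_key, s_key)
        # Skip weights that have no counterpart in SDXL and keep them unchanged.
        if b_key in tuned and b_key in base:
            merged[s_key] = s_value + w * (tuned[b_key] - base[b_key])
    return merged


# Toy usage with floats standing in for weight tensors.
student = {"a": 1.0, "b": 2.0}
tuned = {"x": 3.0}
base = {"x": 2.0}
merged = merge_task_vector(student, tuned, base, key_map={"a": "x"}, w=1.3)
```

Here "a" receives the scaled difference from "x", while "b" has no counterpart and is left as it was.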

However, the correspondence between SSD-1B layers and SDXL layers was not known, especially for blocks with reduced transformer_depth, so I matched the attention layers by the cosine similarity of their weights.

import torch
from torch.nn.functional import cosine_similarity
import numpy as np

sdxl = torch.load("") # diffusers unet
ssd = torch.load("")

sdxl_keys = [key.replace("to_q.weight", "") for key in sdxl.keys() if "to_q.weight" in key]
ssd_keys = [key.replace("to_q.weight", "") for key in ssd.keys()  if "to_q.weight" in key]

ssd2sdxl = {
    "to_q": {},
    "to_k": {},
    "to_v": {},
    "to_out.0": {}
}

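for to_x in ssd2sdxl.keys():
    for ssd_key in ssd_keys:
        target = ssd[ssd_key + to_x + ".weight"].float().reshape(1, -1)
        # Compare only against SDXL layers whose weight has the same shape.
        # Keep the candidate keys next to the scores so argmax indexes the right key.
        candidates = [k for k in sdxl_keys if sdxl[k + to_x + ".weight"].shape == ssd[ssd_key + to_x + ".weight"].shape]
        sims = [cosine_similarity(target, sdxl[k + to_x + ".weight"].float().reshape(1, -1)).item() for k in candidates]
        # Map each SSD-1B attention layer to its most similar SDXL layer.
        ssd2sdxl[to_x][ssd_key] = candidates[int(np.argmax(sims))]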

print(ssd2sdxl["to_q"] == ssd2sdxl["to_k"] == ssd2sdxl["to_v"] == ssd2sdxl["to_out.0"]) # True


The output is here
The results for up_blocks.0.attentions.2 were odd, so I changed them manually.

Since w=1 had little effect and w=1.5 resulted in a coarser image, w=1.3 was used.

To further improve accuracy, I then distilled from the original model into the merged model. The only loss is the squared error on the final output.
The dataset consists of 30,000 real images.
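
A rough sketch of that kind of output-only distillation step is below. It is an illustration under assumptions, not my training code: `distill_step` is a hypothetical helper, the teacher is assumed to be the frozen original model, and tiny `torch.nn.Linear` layers stand in for the UNets so that the snippet runs on its own.

```python
import torch
import torch.nn.functional as F


def distill_step(student, teacher, optimizer, batch):
    """One update that pulls the student's final output toward the teacher's."""
    teacher.eval()
    with torch.no_grad():
        target = teacher(batch)  # teacher output, no gradients
    pred = student(batch)
    loss = F.mse_loss(pred, target)  # squared error on the final output only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy stand-ins for the real models and data (assumption for illustration).
torch.manual_seed(0)
teacher = torch.nn.Linear(4, 4)
student = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
batch = torch.randn(8, 4)
losses = [distill_step(student, teacher, optimizer, batch) for _ in range(20)]
```

With a real UNet the forward call would also take the timestep and text conditioning, and the batch would come from noised latents of the training images.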

I have not done a detailed performance comparison, but I believe this method works better than distilling or fine-tuning from scratch.

Ah, nice observations. Arigato!

