# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
from collections import defaultdict
from random import shuffle
from typing import NamedTuple
import torch
from scipy.optimize import linear_sum_assignment
from modules.shared import log

SPECIAL_KEYS = [
"first_stage_model.decoder.norm_out.weight",
"first_stage_model.decoder.norm_out.bias",
"first_stage_model.encoder.norm_out.weight",
"first_stage_model.encoder.norm_out.bias",
"model.diffusion_model.out.0.weight",
"model.diffusion_model.out.0.bias",
]

class PermutationSpec(NamedTuple):
perm_to_axes: dict
axes_to_perm: dict

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
perm_to_axes = defaultdict(list)
for wk, axis_perms in axes_to_perm.items():
for axis, perm in enumerate(axis_perms):
if perm is not None:
perm_to_axes[perm].append((wk, axis))
return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)
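
# Illustration (hypothetical keys, not the real SD spec): an axes_to_perm map
#   {"w": ("P_a", "P_b"), "b": ("P_a",)}
# inverts to
#   {"P_a": [("w", 0), ("b", 0)], "P_b": [("w", 1)]}
# so each permutation group lists every (weight key, axis) pair it acts on.
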
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
"""Get parameter `k` from `params`, with the permutations applied."""
w = params[k]
for axis, p in enumerate(ps.axes_to_perm[k]):
# Skip the axis we're trying to permute.
if axis == except_axis:
continue
        # None indicates that there is no permutation relevant to that axis.
        if p is not None:
w = torch.index_select(w, axis, perm[p].int())
return w

def apply_permutation(ps: PermutationSpec, perm, params):
"""Apply a `perm` to `params`."""
return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()}
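
# `perm` maps each permutation group name to an index tensor, so applying it to
# a full state dict re-indexes every registered axis at once; this is how
# re-basin aligns one model with the other before interpolation.
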
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
for k in model_a:
try:
            perm_params = get_permuted_param(ps, perm, k, model_a)
model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params
except RuntimeError: # dealing with pix2pix and inpainting models
continue
return model_a
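
# The actual merge step: each tensor becomes a (1 - new_alpha) / new_alpha blend
# of itself and its permuted counterpart, so new_alpha=0.5 gives an even mix.
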
def inner_matching(
n,
ps,
p,
params_a,
params_b,
usefp16,
progress,
number,
linear_sum,
perm,
device,
):
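    """Accumulate the cross-correlation matrix for permutation group `p` and
    solve a linear assignment problem over it, updating `perm[p]` to the unit
    matching that maximizes total correlation between the two models."""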
    A = torch.zeros((n, n), dtype=torch.float16 if usefp16 else torch.float32, device=device)
for wk, axis in ps.perm_to_axes[p]:
w_a = params_a[wk]
w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)
        if usefp16:
            w_a = w_a.half()
            w_b = w_b.half()
try:
A += torch.matmul(w_a, w_b)
except RuntimeError:
A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b))
A = A.cpu()
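    # Solve the assignment problem on CPU (scipy needs numpy arrays); ci[i] is
    # the unit of model B matched to unit i of model A.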
ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
A = A.to(device)
assert (torch.tensor(ri) == torch.arange(len(ri))).all()
eye_tensor = torch.eye(n).to(device)
oldL = torch.vdot(
torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()])
)
newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :]))
if usefp16:
oldL = oldL.half()
newL = newL.half()
if newL - oldL != 0:
linear_sum += abs((newL - oldL).item())
number += 1
log.debug(f"Merge Rebasin permutation: {p}={newL-oldL}")
progress = progress or newL > oldL + 1e-12
    perm[p] = torch.tensor(ci, device=device)
return linear_sum, number, perm, progress

def weight_matching(
ps: PermutationSpec,
params_a,
params_b,
max_iter=1,
init_perm=None,
usefp16=False,
device="cpu",
):
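    """Search for the permutation of `params_b` units that best aligns them
    with `params_a`; returns the permutation and the average change in the
    assignment objective across updates."""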
perm_sizes = {
p: params_a[axes[0][0]].shape[axes[0][1]]
for p, axes in ps.perm_to_axes.items()
if axes[0][0] in params_a.keys()
}
    perm = (
        {p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
        if init_perm is None
        else init_perm
    )
linear_sum = 0
number = 0
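    # Only this single hard-coded permutation group is matched below; the
    # shuffle() call is a no-op on a one-element list.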
special_layers = ["P_bg324"]
for _i in range(max_iter):
progress = False
shuffle(special_layers)
for p in special_layers:
n = perm_sizes[p]
linear_sum, number, perm, progress = inner_matching(
n,
ps,
p,
params_a,
params_b,
usefp16,
progress,
number,
linear_sum,
perm,
device,
)
        progress = True  # unconditionally set, so the early-exit check below never fires
if not progress:
break
average = linear_sum / number if number > 0 else 0
return perm, average
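
if __name__ == "__main__":
    # Minimal self-test sketch, not part of the upstream module. It assumes the
    # imports above resolve (i.e. it runs inside the repo, where modules.shared
    # exists). A toy one-hidden-layer spec is built with hypothetical keys; the
    # group name "P_bg324" matches the hard-coded list in weight_matching.
    torch.manual_seed(0)
    n_hidden = 8
    ps = permutation_spec_from_axes_to_perm({
        "fc1.weight": ("P_bg324", None),  # output axis is permutable
        "fc1.bias": ("P_bg324",),
        "fc2.weight": (None, "P_bg324"),  # input axis follows the same group
    })
    params_a = {
        "fc1.weight": torch.randn(n_hidden, 4),
        "fc1.bias": torch.randn(n_hidden),
        "fc2.weight": torch.randn(3, n_hidden),
    }
    # Model B is model A with its hidden units scrambled.
    scramble = torch.randperm(n_hidden)
    params_b = {
        "fc1.weight": params_a["fc1.weight"][scramble],
        "fc1.bias": params_a["fc1.bias"][scramble],
        "fc2.weight": params_a["fc2.weight"][:, scramble],
    }
    perm, average = weight_matching(ps, params_a, params_b)
    aligned = apply_permutation(ps, perm, params_b)
    assert torch.allclose(aligned["fc1.weight"], params_a["fc1.weight"])
    print(f"recovered hidden-unit permutation, average gain {average:.4f}")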