# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion from collections import defaultdict from random import shuffle from typing import NamedTuple import torch from scipy.optimize import linear_sum_assignment from modules.shared import log SPECIAL_KEYS = [ "first_stage_model.decoder.norm_out.weight", "first_stage_model.decoder.norm_out.bias", "first_stage_model.encoder.norm_out.weight", "first_stage_model.encoder.norm_out.bias", "model.diffusion_model.out.0.weight", "model.diffusion_model.out.0.bias", ] class PermutationSpec(NamedTuple): perm_to_axes: dict axes_to_perm: dict def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec: perm_to_axes = defaultdict(list) for wk, axis_perms in axes_to_perm.items(): for axis, perm in enumerate(axis_perms): if perm is not None: perm_to_axes[perm].append((wk, axis)) return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm) def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None): """Get parameter `k` from `params`, with the permutations applied.""" w = params[k] for axis, p in enumerate(ps.axes_to_perm[k]): # Skip the axis we're trying to permute. if axis == except_axis: continue # None indicates that there is no permutation relevant to that axis. if p: w = torch.index_select(w, axis, perm[p].int()) return w def apply_permutation(ps: PermutationSpec, perm, params): """Apply a `perm` to `params`.""" return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()} def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): for k in model_a: try: perm_params = get_permuted_param( ps, perm, k, model_a ) model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params except RuntimeError: # dealing with pix2pix and inpainting models continue return model_a def inner_matching( n, ps, p, params_a, params_b, usefp16, progress, number, linear_sum, perm, device, ): A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n)) A = A.to(device) for wk, axis in ps.perm_to_axes[p]: w_a = params_a[wk] w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device) w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device) if usefp16: w_a = w_a.half().to(device) w_b = w_b.half().to(device) try: A += torch.matmul(w_a, w_b) except RuntimeError: A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b)) A = A.cpu() ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True) A = A.to(device) assert (torch.tensor(ri) == torch.arange(len(ri))).all() eye_tensor = torch.eye(n).to(device) oldL = torch.vdot( torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()]) ) newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :])) if usefp16: oldL = oldL.half() newL = newL.half() if newL - oldL != 0: linear_sum += abs((newL - oldL).item()) number += 1 log.debug(f"Merge Rebasin permutation: {p}={newL-oldL}") progress = progress or newL > oldL + 1e-12 perm[p] = torch.Tensor(ci).to(device) return linear_sum, number, perm, progress def weight_matching( ps: PermutationSpec, params_a, params_b, max_iter=1, init_perm=None, usefp16=False, device="cpu", ): perm_sizes = { p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items() if axes[0][0] in params_a.keys() } perm = {} perm = ( {p: torch.arange(n).to(device) for p, n in perm_sizes.items()} if init_perm is None else init_perm ) linear_sum = 0 number = 0 special_layers = ["P_bg324"] for _i in range(max_iter): progress = False shuffle(special_layers) for p in special_layers: n = perm_sizes[p] linear_sum, number, perm, progress = inner_matching( n, ps, p, params_a, params_b, usefp16, progress, number, linear_sum, perm, device, ) progress = True if not progress: break average = linear_sum / number if number > 0 else 0 return perm, average