import math
from typing import Tuple, Union, Optional


def make_unet_conversion_map():
    """Build the key-prefix mapping from stable-diffusion UNet names to HF Diffusers names.

    Returns:
        A list of ``(sd_prefix, hf_prefix)`` tuples. Every prefix ends with a
        ``"."``. Resnet layers are expanded into one entry per sub-module
        (norms, convs, time-embedding projection, skip connection); all other
        layers are mapped at the layer level.
    """
    unet_conversion_map_layer = []

    # The original guarded some appends with `if i < 3`, which is always true
    # inside `range(3)`; those dead guards are dropped, entry order is unchanged.
    for i in range(3):  # num_blocks is 3 in sdxl
        # resnets/attentions for downblocks
        for j in range(2):
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        # resnets/attentions for upblocks
        for j in range(3):
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            # Resnet sub-modules are named differently on each side, so every
            # resnet layer is expanded into one entry per sub-module.
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


def convert_unet_state_dict(src_sd, conversion_map):
    """Rename the keys of ``src_sd`` using the longest matching prefix in ``conversion_map``.

    Args:
        src_sd: mapping of dotted parameter names to values.
        conversion_map: mapping of source prefix (ending in ``"."``) to
            target prefix.

    Returns:
        A new dict with converted keys and the original values.

    Raises:
        AssertionError: if no prefix of a key is present in ``conversion_map``.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias

        # Try the longest prefix first and shorten it one fragment at a time.
        while src_key_fragments:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_sd[converted_prefix + src_key[len(src_key_prefix):]] = value
                break
            src_key_fragments.pop()

        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`, which would silently drop the key.
        if not src_key_fragments:
            raise AssertionError(f"key {src_key} not found in conversion map")

    return converted_sd


def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Convert a UNet state dict keyed with stable-diffusion names to HF Diffusers names.

    Raises:
        AssertionError: if a key of ``sd`` has no prefix in the conversion map.
    """
    conversion_dict = dict(make_unet_conversion_map())
    return convert_unet_state_dict(sd, conversion_dict)


def extract_unet_state_dict(state_dict):
    """Return the entries under ``model.diffusion_model.`` with that prefix stripped."""
    UNET_KEY_PREFIX = "model.diffusion_model."
    return {
        k[len(UNET_KEY_PREFIX):]: v
        for k, v in state_dict.items()
        if k.startswith(UNET_KEY_PREFIX)
    }


def log_model_info(model, name):
    """Print the total parameter count and the dtype of the first entry.

    Args:
        model: an object exposing ``state_dict()``, or a state-dict-like
            mapping whose values provide ``numel()`` and ``dtype``.
        name: label printed on the first line.
    """
    sd = model.state_dict() if hasattr(model, "state_dict") else model
    # An empty state dict has no first entry; report dtype None instead of
    # letting StopIteration escape.
    first = next(iter(sd.values()), None)
    dtype = first.dtype if first is not None else None
    print(
        f"{name}:",
        f" number of parameters: {sum(p.numel() for p in sd.values())}",
        f" dtype: {dtype}",
        sep='\n'
    )


def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]:
    r"""Pick a size with roughly ``reso`` area that keeps the image aspect ratio.

    Derivation (with img_ar = img_w / img_h)::

        w * h = reso[0] * reso[1]
        w / h = img_ar  =>  w = img_ar * h
        =>  img_ar * h^2 = reso[0] * reso[1]
        =>  h = sqrt(reso[0] * reso[1] / img_ar)

    The image size is returned unchanged when it already fits the area
    budget, respects ``max_width`` / ``max_height`` and is a multiple of
    ``divisible``. Otherwise the size is recomputed from the target area,
    even for an image smaller than the budget that is not divisible.

    Args:
        img_w: source image width.
        img_h: source image height.
        reso: target resolution; an int ``r`` means ``(r, r)``. A two-element
            list is accepted like a tuple.
        divisible: both returned sides are floored to a multiple of this
            value; ``None`` (or 0) means 1.
        max_width: optional upper bound on the returned width.
        max_height: optional upper bound on the returned height.

    Returns:
        ``(width, height)``.
    """
    # Lists are treated like tuples; previously a list fell into the scalar
    # branch and failed on `list * list`.
    reso = tuple(reso) if isinstance(reso, (tuple, list)) else (reso, reso)
    divisible = divisible or 1

    fits_area = img_w * img_h <= reso[0] * reso[1]
    fits_width = not max_width or img_w <= max_width
    fits_height = not max_height or img_h <= max_height
    is_aligned = img_w % divisible == 0 and img_h % divisible == 0
    if fits_area and fits_width and fits_height and is_aligned:
        return (img_w, img_h)

    img_ar = img_w / img_h
    around_h = math.sqrt(reso[0] * reso[1] / img_ar)
    around_w = img_ar * around_h // divisible * divisible

    # Shrink proportionally when one side exceeds its limit.
    if max_width and around_w > max_width:
        around_h = around_h * max_width // around_w
        around_w = max_width
    elif max_height and around_h > max_height:
        around_w = around_w * max_height // around_h
        around_h = max_height

    # Final clamp: the proportional shrink above handles only one side.
    around_h = min(around_h, max_height) if max_height else around_h
    around_w = min(around_w, max_width) if max_width else around_w

    around_h = int(around_h // divisible * divisible)
    around_w = int(around_w // divisible * divisible)
    return (around_w, around_h)