|
import math |
|
from typing import Tuple, Union, Optional |
|
|
|
|
|
def make_unet_conversion_map(): |
|
unet_conversion_map_layer = [] |
|
|
|
for i in range(3): |
|
|
|
for j in range(2): |
|
|
|
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." |
|
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." |
|
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) |
|
|
|
if i < 3: |
|
|
|
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." |
|
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." |
|
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) |
|
|
|
for j in range(3): |
|
|
|
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." |
|
sd_up_res_prefix = f"output_blocks.{3*i + j}.0." |
|
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) |
|
|
|
|
|
|
|
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." |
|
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." |
|
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) |
|
|
|
if i < 3: |
|
|
|
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." |
|
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." |
|
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) |
|
|
|
|
|
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." |
|
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." |
|
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) |
|
|
|
hf_mid_atn_prefix = "mid_block.attentions.0." |
|
sd_mid_atn_prefix = "middle_block.1." |
|
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) |
|
|
|
for j in range(2): |
|
hf_mid_res_prefix = f"mid_block.resnets.{j}." |
|
sd_mid_res_prefix = f"middle_block.{2*j}." |
|
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) |
|
|
|
unet_conversion_map_resnet = [ |
|
|
|
("in_layers.0.", "norm1."), |
|
("in_layers.2.", "conv1."), |
|
("out_layers.0.", "norm2."), |
|
("out_layers.3.", "conv2."), |
|
("emb_layers.1.", "time_emb_proj."), |
|
("skip_connection.", "conv_shortcut."), |
|
] |
|
|
|
unet_conversion_map = [] |
|
for sd, hf in unet_conversion_map_layer: |
|
if "resnets" in hf: |
|
for sd_res, hf_res in unet_conversion_map_resnet: |
|
unet_conversion_map.append((sd + sd_res, hf + hf_res)) |
|
else: |
|
unet_conversion_map.append((sd, hf)) |
|
|
|
for j in range(2): |
|
hf_time_embed_prefix = f"time_embedding.linear_{j+1}." |
|
sd_time_embed_prefix = f"time_embed.{j*2}." |
|
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) |
|
|
|
for j in range(2): |
|
hf_label_embed_prefix = f"add_embedding.linear_{j+1}." |
|
sd_label_embed_prefix = f"label_emb.0.{j*2}." |
|
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) |
|
|
|
unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) |
|
unet_conversion_map.append(("out.0.", "conv_norm_out.")) |
|
unet_conversion_map.append(("out.2.", "conv_out.")) |
|
|
|
return unet_conversion_map |
|
|
|
|
|
def convert_unet_state_dict(src_sd, conversion_map):
    """Rename every key in *src_sd* according to *conversion_map*.

    For each source key, the longest matching "."-terminated prefix present
    in *conversion_map* (a dict of source-prefix -> target-prefix) is
    replaced; the remainder of the key is appended verbatim.

    Args:
        src_sd: state dict whose keys use the source naming scheme.
        conversion_map: mapping from source key prefix to target key prefix.

    Returns:
        A new dict with converted keys and the original values.

    Raises:
        KeyError: if no prefix of a key is found in *conversion_map*.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        # Try successively shorter prefixes; the last fragment (usually the
        # parameter name such as "weight"/"bias") is dropped up front.
        src_key_fragments = src_key.split(".")[:-1]
        while len(src_key_fragments) > 0:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_key = converted_prefix + src_key[len(src_key_prefix):]
                converted_sd[converted_key] = value
                break
            src_key_fragments.pop(-1)
        else:
            # An `assert` here would be stripped under `python -O` and the key
            # would be silently dropped; fail loudly instead.
            raise KeyError(f"key {src_key} not found in conversion map")

    return converted_sd
|
|
|
|
|
def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Convert an SD-format SDXL UNet state dict *sd* to diffusers key naming."""
    # dict() over the (sd_prefix, hf_prefix) pairs gives the prefix lookup table;
    # this also avoids shadowing the `sd` parameter inside a comprehension.
    prefix_lookup = dict(make_unet_conversion_map())
    return convert_unet_state_dict(sd, prefix_lookup)
|
|
|
|
|
def extract_unet_state_dict(state_dict):
    """Return the UNet sub-dict of a full SD checkpoint *state_dict*,
    with the "model.diffusion_model." prefix stripped from each key.
    Keys without that prefix are ignored.
    """
    UNET_KEY_PREFIX = "model.diffusion_model."
    prefix_len = len(UNET_KEY_PREFIX)
    return {
        key[prefix_len:]: value
        for key, value in state_dict.items()
        if key.startswith(UNET_KEY_PREFIX)
    }
|
|
|
|
|
def log_model_info(model, name):
    """Print a short summary (parameter count and dtype) of *model*.

    *model* may be an object exposing ``state_dict()`` or a plain mapping of
    name -> tensor; the dtype shown is that of the first entry.
    """
    sd = model.state_dict() if hasattr(model, "state_dict") else model
    total_params = sum(p.numel() for p in sd.values())
    first_entry = sd[next(iter(sd))]
    summary = [
        f"{name}:",
        f"  number of parameters: {total_params}",
        f"  dtype: {first_entry.dtype}",
    ]
    print("\n".join(summary))
|
|
|
|
|
def around_reso(img_w: int, img_h: int, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width: Optional[int] = None, max_height: Optional[int] = None) -> Tuple[int, int]:
    r"""Return a (width, height) whose area is close to the target *reso*
    while keeping the aspect ratio of (img_w, img_h), optionally floored to a
    multiple of *divisible* and clamped to *max_width* / *max_height*.

    Derivation (with img_ar = img_w / img_h and target area reso_w * reso_h):

    w*h = reso_w*reso_h
    w/h = img_w/img_h
    => w = img_ar*h
    => img_ar*h^2 = reso_w*reso_h
    => h = sqrt(reso_w*reso_h / img_ar)
    """
    # A bare int means a square target: (reso, reso).
    reso = reso if isinstance(reso, tuple) else (reso, reso)
    divisible = divisible or 1
    # Fast path: the image already fits the target area, both max bounds, and
    # the divisibility constraint — return it unchanged.
    if img_w * img_h <= reso[0] * reso[1] and (not max_width or img_w <= max_width) and (not max_height or img_h <= max_height) and img_w % divisible == 0 and img_h % divisible == 0:
        return (img_w, img_h)
    img_ar = img_w / img_h
    around_h = math.sqrt(reso[0]*reso[1] / img_ar)
    # Width is floored to a multiple of `divisible` here; height stays
    # fractional until the final rounding at the bottom.
    around_w = img_ar * around_h // divisible * divisible
    if max_width and around_w > max_width:
        # Width cap hit: shrink height proportionally, then pin the width.
        around_h = around_h * max_width // around_w
        around_w = max_width
    elif max_height and around_h > max_height:
        # Height cap hit: shrink width proportionally, then pin the height.
        around_w = around_w * max_height // around_h
        around_h = max_height
    # Final clamp — the proportional shrink above only handled one axis.
    around_h = min(around_h, max_height) if max_height else around_h
    around_w = min(around_w, max_width) if max_width else around_w
    # Floor both axes to multiples of `divisible` and convert to int.
    around_h = int(around_h // divisible * divisible)
    around_w = int(around_w // divisible * divisible)
    return (around_w, around_h)
|
|