import argparse

import huggingface_hub
import k_diffusion as K
import torch

from diffusers import UNet2DConditionModel


UPSCALER_REPO = "pcuenq/k-upscaler"


def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
            }
        )

    return rv
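

# Illustrative note (an inference from the key names above, not from this script): the original
# k-diffusion resnet block appears to store its layers in a sequential `main` container, where
# indices 0 and 4 are adaptive group norms (whose conditioning `mapper` becomes the diffusers
# `norm*.linear`) and indices 2 and 6 are convolutions. The hypothetical helper below, which is
# never called by this script, shows how such a mapping could be sanity-checked.
def _check_resnet_keys_exist(checkpoint, resnet_prefix):
    expected = [
        f"{resnet_prefix}.main.0.mapper.weight",
        f"{resnet_prefix}.main.2.weight",
        f"{resnet_prefix}.main.4.mapper.weight",
        f"{resnet_prefix}.main.6.weight",
    ]
    # Return any keys the checkpoint is missing for this resnet block.
    return [key for key in expected if key not in checkpoint]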


def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)

    rv = {
        # norm
        f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
        f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
        # to_q
        f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
        # to_k
        f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
    }

    return rv
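

# Illustrative sketch (not used by the conversion): the original checkpoint stores q/k/v as one
# fused `qkv_proj`; `.chunk(3, dim=0)` splits it into three equal slices along the output
# dimension, and `.squeeze(-1).squeeze(-1)` drops the trailing 1x1 spatial dims so the weights fit
# diffusers' linear projections. The shapes below are made up for illustration and this helper is
# never called here.
def _demo_split_fused_qkv():
    fused = torch.randn(3 * 64, 64, 1, 1)  # hypothetical fused qkv weight: (3 * dim, dim, 1, 1)
    weight_q, weight_k, weight_v = fused.chunk(3, dim=0)
    return weight_q.squeeze(-1).squeeze(-1).shape  # torch.Size([64, 64])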


def cross_attn_to_diffusers_checkpoint(
    checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
    weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
    bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)

    rv = {
        # norm2 (ada groupnorm)
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.weight"
        ],
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.bias"
        ],
        # layernorm on encoder_hidden_state
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
            f"{attention_prefix}.norm_enc.weight"
        ],
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
            f"{attention_prefix}.norm_enc.bias"
        ],
        # to_q
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
            f"{attention_prefix}.q_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
            f"{attention_prefix}.q_proj.bias"
        ],
        # to_k
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
            f"{attention_prefix}.out_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
            f"{attention_prefix}.out_proj.bias"
        ],
    }

    return rv


def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
    block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
    block_prefix = f"{block_prefix}.{block_idx}"

    diffusers_checkpoint = {}

    # n is the number of modules per layer in the original block
    if not hasattr(block, "attentions"):
        n = 1  # resnet only
    elif not block.attentions[0].add_self_attention:
        n = 2  # resnet -> cross-attention
    else:
        n = 3  # resnet -> self-attention -> cross-attention

    for resnet_idx, resnet in enumerate(block.resnets):
        diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
        idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
        resnet_prefix = f"{block_prefix}.{idx}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    if hasattr(block, "attentions"):
        for attention_idx, attention in enumerate(block.attentions):
            diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"

            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
            self_attention_prefix = f"{block_prefix}.{idx}"

            cross_attention_index = 1 if not attention.add_self_attention else 2
            idx = (
                n * attention_idx + cross_attention_index
                if block_type == "up"
                else n * attention_idx + cross_attention_index + 1
            )
            cross_attention_prefix = f"{block_prefix}.{idx}"

            diffusers_checkpoint.update(
                cross_attn_to_diffusers_checkpoint(
                    checkpoint,
                    diffusers_attention_prefix=diffusers_attention_prefix,
                    diffusers_attention_index=2,
                    attention_prefix=cross_attention_prefix,
                )
            )

            if attention.add_self_attention is True:
                diffusers_checkpoint.update(
                    self_attn_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=self_attention_prefix,
                    )
                )

    return diffusers_checkpoint
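

# Illustrative note (inferred from the index arithmetic above, not stated in this script): inside a
# k-diffusion block the modules are laid out sequentially, e.g. with n == 3 an up block reads
# [resnet, self_attn, cross_attn, resnet, self_attn, cross_attn, ...], while a down block has one
# extra leading module (presumably a downsampling layer), hence the "+ 1" / "+ 2" offsets on the
# down path. A worked example for a down block with n == 3:
#   resnet_idx 0 -> original index 1, attention_idx 0 -> self-attn at 2, cross-attn at 3
#   resnet_idx 1 -> original index 4, attention_idx 1 -> self-attn at 5, cross-attn at 6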


def unet_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # pre-processing
    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
            "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
        }
    )

    # timestep and class embedding
    diffusers_checkpoint.update(
        {
            "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
            "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
            "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
            "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
            "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
            "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.down_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))

    # up_blocks
    for up_block_idx, up_block in enumerate(model.up_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))

    # post-processing
    diffusers_checkpoint.update(
        {
            "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
            "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
        }
    )

    return diffusers_checkpoint
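

# Minimal sketch (illustration only, never called here): when the conversion is tweaked, a quick
# way to see where the remapped keys diverge from what UNet2DConditionModel expects is to diff the
# two key sets before calling `load_state_dict(..., strict=True)`.
def _diff_state_dict_keys(model, converted_checkpoint):
    expected = set(model.state_dict().keys())
    converted = set(converted_checkpoint.keys())
    # Keys the model expects but the conversion did not produce, and vice versa.
    return sorted(expected - converted), sorted(converted - expected)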


def unet_model_from_original_config(original_config):
    in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
    out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)

    block_out_channels = original_config["channels"]

    assert (
        len(set(original_config["depths"])) == 1
    ), "UNet2DConditionModel currently does not support blocks with different numbers of layers"
    layers_per_block = original_config["depths"][0]

    class_labels_dim = original_config["mapping_cond_dim"]
    cross_attention_dim = original_config["cross_cond_dim"]

    # Attention layout per resolution, derived from the original config for reference; the block
    # types passed to UNet2DConditionModel below are fixed for this particular upscaler.
    attn1_types = []
    attn2_types = []
    for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
        if s:
            a1 = "self"
            a2 = "cross" if c else None
        elif c:
            a1 = "cross"
            a2 = None
        else:
            a1 = None
            a2 = None
        attn1_types.append(a1)
        attn2_types.append(a2)

    unet = UNet2DConditionModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
        mid_block_type=None,
        up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        act_fn="gelu",
        norm_num_groups=None,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=64,
        time_cond_proj_dim=class_labels_dim,
        resnet_time_scale_shift="scale_shift",
        time_embedding_type="fourier",
        timestep_post_act="gelu",
        conv_in_kernel=1,
        conv_out_kernel=1,
    )

    return unet
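

# Worked example of the attention-type loop above, with hypothetical values not read from the real
# config file: for self_attn_depths = [False, False, True, True] and
# cross_attn_depths = [False, True, True, True], the loop yields
# attn1_types = [None, "cross", "self", "self"] and attn2_types = [None, None, "cross", "cross"].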


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
    orig_weights_path = huggingface_hub.hf_hub_download(
        UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
    )

    print(f"loading original model configuration from {orig_config_path}")
    print(f"loading original model checkpoint from {orig_weights_path}")

    print("converting to diffusers unet")
    orig_config = K.config.load_config(open(orig_config_path))["model"]
    model = unet_model_from_original_config(orig_config)

    orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
    converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(args.dump_path)
    print(f"saved converted unet model in {args.dump_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    main(args)
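

# Example usage (the script filename and output directory below are arbitrary assumptions):
#
#   python convert_k_upscaler_to_diffusers.py --dump_path ./k-upscaler-unet
#
# The converted weights can then be reloaded with diffusers:
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained("./k-upscaler-unet")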