"""Ports a KerasCV Stable Diffusion UNet into a Diffusers-style state dict.

Keras stores Dense kernels as (in, out) and Conv2D kernels as (H, W, in, out);
PyTorch expects (out, in) and (out, in, H, W) respectively, so every kernel is
transposed on the way through. Target key names follow the Hugging Face
Diffusers ``UNet2DConditionModel`` state-dict layout.
"""

from itertools import product
from typing import Dict, Optional

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion


def _dense_weight(layer) -> torch.Tensor:
    """Keras Dense kernel (in, out) -> PyTorch Linear weight (out, in)."""
    return torch.from_numpy(layer.get_weights()[0].transpose())


def _conv_weight(layer) -> torch.Tensor:
    """Keras Conv2D kernel (H, W, in, out) -> PyTorch Conv2d weight (out, in, H, W).

    Transposition axes taken from
    https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
    """
    return torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))


def _vector(layer, index: int = 0) -> torch.Tensor:
    """Return weight ``index`` of ``layer`` untransposed (bias, or norm gamma/beta)."""
    return torch.from_numpy(layer.get_weights()[index])


def port_transformer_block(
    transformer_block: tf.keras.Model,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> Dict[str, torch.Tensor]:
    """Populates a Transformer block.

    Args:
        transformer_block: the KerasCV BasicTransformerBlock to port.
        up_down: ``"up"``, ``"down"``, or ``"mid"`` — selects the key prefix.
        block_id: index of the up/down block, or ``None`` for the mid block.
        attention_id: index of the attention within its block.

    Returns:
        A dict mapping Diffusers parameter names to PyTorch tensors.
    """
    transformer_dict: Dict[str, torch.Tensor] = {}
    if block_id is not None:
        prefix = f"{up_down}_blocks.{block_id}"
    else:
        prefix = "mid_block"
    base = f"{prefix}.attentions.{attention_id}.transformer_blocks.0"

    # Layer norms.
    norms = (transformer_block.norm1, transformer_block.norm2, transformer_block.norm3)
    for i, norm in enumerate(norms, start=1):
        transformer_dict[f"{base}.norm{i}.weight"] = _vector(norm, 0)
        transformer_dict[f"{base}.norm{i}.bias"] = _vector(norm, 1)

    # Attentions (attn1 and attn2).
    attns = (transformer_block.attn1, transformer_block.attn2)
    for i, attn in enumerate(attns, start=1):
        attn_base = f"{base}.attn{i}"
        transformer_dict[f"{attn_base}.to_q.weight"] = _dense_weight(attn.to_q)
        transformer_dict[f"{attn_base}.to_k.weight"] = _dense_weight(attn.to_k)
        transformer_dict[f"{attn_base}.to_v.weight"] = _dense_weight(attn.to_v)
        transformer_dict[f"{attn_base}.to_out.0.weight"] = _dense_weight(attn.out_proj)
        transformer_dict[f"{attn_base}.to_out.0.bias"] = _vector(attn.out_proj, 1)

    # Feed-forward: GEGLU projection maps to ff.net.0.proj, the output Dense
    # to ff.net.2 (net.1 is the activation and carries no weights).
    geglu = transformer_block.geglu.dense
    transformer_dict[f"{base}.ff.net.0.proj.weight"] = _dense_weight(geglu)
    transformer_dict[f"{base}.ff.net.0.proj.bias"] = _vector(geglu, 1)
    dense = transformer_block.dense
    transformer_dict[f"{base}.ff.net.2.weight"] = _dense_weight(dense)
    transformer_dict[f"{base}.ff.net.2.bias"] = _vector(dense, 1)

    return transformer_dict


def _port_resnet(layer, prefix: str) -> Dict[str, torch.Tensor]:
    """Port one ResBlock (convs, optional shortcut, timestep proj, norms) under ``prefix``."""
    state: Dict[str, torch.Tensor] = {}
    entry_flow = layer.entry_flow
    exit_flow = layer.exit_flow

    # Conv blocks: the last layer of each flow is the conv, the first the norm.
    state[f"{prefix}.conv1.weight"] = _conv_weight(entry_flow[-1])
    state[f"{prefix}.conv1.bias"] = _vector(entry_flow[-1], 1)
    state[f"{prefix}.conv2.weight"] = _conv_weight(exit_flow[-1])
    state[f"{prefix}.conv2.bias"] = _vector(exit_flow[-1], 1)

    # Residual shortcut (only present when in/out channel counts differ, and
    # only ported when it is a real conv rather than an identity).
    if hasattr(layer, "residual_projection") and isinstance(
        layer.residual_projection,
        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
    ):
        residual = layer.residual_projection
        state[f"{prefix}.conv_shortcut.weight"] = _conv_weight(residual)
        state[f"{prefix}.conv_shortcut.bias"] = _vector(residual, 1)

    # Timestep embedding projection.
    embedding_proj = layer.embedding_flow[-1]
    state[f"{prefix}.time_emb_proj.weight"] = _dense_weight(embedding_proj)
    state[f"{prefix}.time_emb_proj.bias"] = _vector(embedding_proj, 1)

    # Group norms.
    state[f"{prefix}.norm1.weight"] = _vector(entry_flow[0], 0)
    state[f"{prefix}.norm1.bias"] = _vector(entry_flow[0], 1)
    state[f"{prefix}.norm2.weight"] = _vector(exit_flow[0], 0)
    state[f"{prefix}.norm2.bias"] = _vector(exit_flow[0], 1)
    return state


def _port_spatial_transformer(
    layer,
    prefix: str,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> Dict[str, torch.Tensor]:
    """Port one SpatialTransformer (proj convs, inner transformer, norm) under ``prefix``."""
    state: Dict[str, torch.Tensor] = {}

    # Input/output 1x1 projection convs.
    state[f"{prefix}.proj_in.weight"] = _conv_weight(layer.proj1)
    state[f"{prefix}.proj_in.bias"] = _vector(layer.proj1, 1)
    state[f"{prefix}.proj_out.weight"] = _conv_weight(layer.proj2)
    state[f"{prefix}.proj_out.bias"] = _vector(layer.proj2, 1)

    # Inner transformer block.
    state.update(
        port_transformer_block(layer.transformer_block, up_down, block_id, attention_id)
    )

    # Group norm.
    state[f"{prefix}.norm.weight"] = _vector(layer.norm, 0)
    state[f"{prefix}.norm.bias"] = _vector(layer.norm, 1)
    return state


def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates the state dict from the provided TensorFlow model
    (applicable only for the UNet).

    Layers are classified by type; within a type, position in the network is
    recovered from the Keras auto-generated layer-name suffix (``res_block_7``
    -> 7) or, for the up path, from running counters paired with the
    ``(block, index)`` products below.

    Args:
        tf_unet: the KerasCV diffusion UNet.

    Returns:
        A Diffusers-compatible state dict of PyTorch tensors.
    """
    unet_state_dict: Dict[str, torch.Tensor] = {}
    timestep_emb = 1  # 1-based index of the next time-embedding Dense layer.
    padded_conv = 1  # 1-based index of the next standalone PaddedConv2D.
    up_block = 0  # Index of the next upsampler's up block.
    up_res_blocks = list(product([0, 1, 2, 3], [0, 1, 2]))
    up_res_block_flag = 0
    up_spatial_transformer_blocks = list(product([1, 2, 3], [0, 1, 2]))
    up_spatial_transformer_flag = 0

    for layer in tf_unet.layers:
        # Timestep embedding MLP (the bare Dense layers in the model).
        if isinstance(layer, tf.keras.layers.Dense):
            unet_state_dict[
                f"time_embedding.linear_{timestep_emb}.weight"
            ] = _dense_weight(layer)
            unet_state_dict[
                f"time_embedding.linear_{timestep_emb}.bias"
            ] = _vector(layer, 1)
            timestep_emb += 1

        # Standalone padded convs, in encounter order: conv_in (1st), the
        # three downsamplers (2nd-4th), conv_out (5th).
        elif isinstance(
            layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D
        ):
            if padded_conv == 1:
                unet_state_dict["conv_in.weight"] = _conv_weight(layer)
                unet_state_dict["conv_in.bias"] = _vector(layer, 1)
            elif padded_conv in (2, 3, 4):
                conv_prefix = f"down_blocks.{padded_conv - 2}.downsamplers.0.conv"
                unet_state_dict[f"{conv_prefix}.weight"] = _conv_weight(layer)
                unet_state_dict[f"{conv_prefix}.bias"] = _vector(layer, 1)
            elif padded_conv == 5:
                unet_state_dict["conv_out.weight"] = _conv_weight(layer)
                unet_state_dict["conv_out.bias"] = _vector(layer, 1)
            padded_conv += 1

        # Upsamplers.
        elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
            conv_prefix = f"up_blocks.{up_block}.upsamplers.0.conv"
            unet_state_dict[f"{conv_prefix}.weight"] = _conv_weight(layer.conv)
            unet_state_dict[f"{conv_prefix}.bias"] = _vector(layer.conv, 1)
            up_block += 1

        # Final output norm.
        elif isinstance(
            layer,
            stable_diffusion.__internal__.layers.group_normalization.GroupNormalization,
        ):
            unet_state_dict["conv_norm_out.weight"] = _vector(layer, 0)
            unet_state_dict["conv_norm_out.bias"] = _vector(layer, 1)

        # All ResBlocks: suffix < 8 -> down path, 8/9 -> mid, > 9 -> up path.
        elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
            parts = layer.name.split("_")
            if len(parts) == 2 or int(parts[-1]) < 8:
                # The first ResBlock has no numeric suffix (len(parts) == 2).
                block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                resnet_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
                unet_state_dict.update(
                    _port_resnet(layer, f"down_blocks.{block_id}.resnets.{resnet_id}")
                )
            elif int(parts[-1]) in (8, 9):
                unet_state_dict.update(
                    _port_resnet(layer, f"mid_block.resnets.{int(parts[-1]) % 2}")
                )
            elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
                block_id, resnet_id = up_res_blocks[up_res_block_flag]
                unet_state_dict.update(
                    _port_resnet(layer, f"up_blocks.{block_id}.resnets.{resnet_id}")
                )
                up_res_block_flag += 1

        # All SpatialTransformer blocks: suffix < 6 -> down, 6 -> mid, > 6 -> up.
        elif isinstance(layer, stable_diffusion.diffusion_model.SpatialTransformer):
            parts = layer.name.split("_")
            if len(parts) == 2 or int(parts[-1]) < 6:
                block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
                unet_state_dict.update(
                    _port_spatial_transformer(
                        layer,
                        f"down_blocks.{block_id}.attentions.{attention_id}",
                        "down",
                        block_id,
                        attention_id,
                    )
                )
            elif int(parts[-1]) == 6:
                # Single mid attention (6 % 2 == 0). NOTE: the previous version
                # keyed proj_out.weight with a leftover ResBlock index
                # (mid_resnet_id), emitting mid_block.attentions.1.proj_out.weight
                # instead of attention 0 — fixed here.
                attention_id = 0
                unet_state_dict.update(
                    _port_spatial_transformer(
                        layer,
                        f"mid_block.attentions.{attention_id}",
                        "mid",
                        None,
                        attention_id,
                    )
                )
            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(
                up_spatial_transformer_blocks
            ):
                block_id, attention_id = up_spatial_transformer_blocks[
                    up_spatial_transformer_flag
                ]
                unet_state_dict.update(
                    _port_spatial_transformer(
                        layer,
                        f"up_blocks.{block_id}.attentions.{attention_id}",
                        "up",
                        block_id,
                        attention_id,
                    )
                )
                up_spatial_transformer_flag += 1

    return unet_state_dict