from typing import Dict

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion

MAX_SEQ_LENGTH = 77


def _copy_layer_norm(
    state_dict: Dict[str, torch.Tensor], prefix: str, layer: tf.keras.layers.Layer
) -> None:
    """Store a Keras LayerNormalization's gamma/beta as ``{prefix}.weight``/``.bias``."""
    # Fetch once: every ``get_weights()`` call copies all of the layer's variables.
    gamma, beta = layer.get_weights()
    state_dict[f"{prefix}.weight"] = torch.from_numpy(gamma)
    state_dict[f"{prefix}.bias"] = torch.from_numpy(beta)


def _copy_dense(
    state_dict: Dict[str, torch.Tensor], prefix: str, layer: tf.keras.layers.Layer
) -> None:
    """Store a Keras Dense layer as a ``torch.nn.Linear``-style ``{prefix}.weight``/``.bias``."""
    kernel, bias = layer.get_weights()
    # Keras Dense kernels are (in_features, out_features) while torch.nn.Linear
    # weights are (out_features, in_features), hence the transpose. The transpose
    # is a strided view; ``.contiguous()`` materialises it so the resulting tensor
    # can also be written by serializers that reject non-contiguous tensors.
    state_dict[f"{prefix}.weight"] = torch.from_numpy(kernel.transpose()).contiguous()
    state_dict[f"{prefix}.bias"] = torch.from_numpy(bias)


def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates the state dict from the provided TensorFlow model (applicable only for the text encoder).

    Walks ``tf_text_encoder.layers`` in order and copies weights into PyTorch
    tensors keyed with ``text_model.``-prefixed names. The names look like
    Hugging Face ``CLIPTextModel`` parameter names (presumably the target of
    ``load_state_dict``; verify against the caller).

    Args:
        tf_text_encoder: KerasCV Stable Diffusion text encoder whose top-level
            layers include a ``CLIPEmbedding``, ``CLIPEncoderLayer`` blocks and
            a final ``LayerNormalization``.

    Returns:
        A dict mapping parameter names to ``torch.Tensor`` values. Encoder
        blocks are numbered 0, 1, ... in the order they appear in
        ``tf_text_encoder.layers``.
    """
    text_state_dict: Dict[str, torch.Tensor] = {}

    # Position ids: a (1, MAX_SEQ_LENGTH) int64 tensor holding 0..MAX_SEQ_LENGTH-1.
    text_state_dict["text_model.embeddings.position_ids"] = torch.arange(
        MAX_SEQ_LENGTH
    ).unsqueeze(0)

    num_encoder_layers = 0
    for layer in tf_text_encoder.layers:
        if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
            # Embedding tables share the (num_embeddings, dim) layout in Keras
            # and PyTorch, so no transpose is needed.
            text_state_dict[
                "text_model.embeddings.token_embedding.weight"
            ] = torch.from_numpy(layer.token_embedding.get_weights()[0])
            text_state_dict[
                "text_model.embeddings.position_embedding.weight"
            ] = torch.from_numpy(layer.position_embedding.get_weights()[0])

        elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
            prefix = f"text_model.encoder.layers.{num_encoder_layers}"

            # LayerNorms.
            _copy_layer_norm(text_state_dict, f"{prefix}.layer_norm1", layer.layer_norm1)
            _copy_layer_norm(text_state_dict, f"{prefix}.layer_norm2", layer.layer_norm2)

            # Attention projections.
            attn = layer.clip_attn
            for proj_name, proj in (
                ("q_proj", attn.q_proj),
                ("k_proj", attn.k_proj),
                ("v_proj", attn.v_proj),
                ("out_proj", attn.out_proj),
            ):
                _copy_dense(text_state_dict, f"{prefix}.self_attn.{proj_name}", proj)

            # MLPs.
            _copy_dense(text_state_dict, f"{prefix}.mlp.fc1", layer.fc1)
            _copy_dense(text_state_dict, f"{prefix}.mlp.fc2", layer.fc2)

            num_encoder_layers += 1

        elif isinstance(layer, tf.keras.layers.LayerNormalization):
            # Final LayerNorm. Any top-level LayerNormalization maps here; if
            # there were several, the last one visited would win.
            _copy_layer_norm(text_state_dict, "text_model.final_layer_norm", layer)

    return text_state_dict