from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch


class HunyuanDiTCLIPTextEncoder(BertModel):
    # BERT-based CLIP-style text encoder used by HunyuanDiT (24 layers, hidden size 1024).
    def __init__(self):
        config = BertConfig(
            _name_or_path="",
            architectures=["BertModel"],
            attention_probs_dropout_prob=0.1,
            bos_token_id=0,
            classifier_dropout=None,
            directionality="bidi",
            eos_token_id=2,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            hidden_size=1024,
            initializer_range=0.02,
            intermediate_size=4096,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type="bert",
            num_attention_heads=16,
            num_hidden_layers=24,
            output_past=True,
            pad_token_id=0,
            pooler_fc_size=768,
            pooler_num_attention_heads=12,
            pooler_num_fc_layers=3,
            pooler_size_per_head=128,
            pooler_type="first_token_transform",
            position_embedding_type="absolute",
            torch_dtype="float32",
            transformers_version="4.37.2",
            type_vocab_size=2,
            use_cache=True,
            vocab_size=47020,
        )
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device
        past_key_values_length = 0
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        all_hidden_states = encoder_outputs.hidden_states
        # Take the hidden states `clip_skip` layers from the end.
        prompt_emb = all_hidden_states[-clip_skip]
        if clip_skip > 1:
            # Rescale the skipped-layer embedding to match the mean/std of the final layer.
            mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTCLIPTextEncoderStateDictConverter()


class HunyuanDiTT5TextEncoder(T5EncoderModel):
    # mT5 encoder (24 layers, d_model 2048) used as the second text encoder in HunyuanDiT.
    def __init__(self):
        config = T5Config(
            _name_or_path="../HunyuanDiT/t2i/mt5",
            architectures=["MT5ForConditionalGeneration"],
            classifier_dropout=0.0,
            d_ff=5120,
            d_kv=64,
            d_model=2048,
            decoder_start_token_id=0,
            dense_act_fn="gelu_new",
            dropout_rate=0.1,
            eos_token_id=1,
            feed_forward_proj="gated-gelu",
            initializer_factor=1.0,
            is_encoder_decoder=True,
            is_gated_act=True,
            layer_norm_epsilon=1e-06,
            model_type="t5",
            num_decoder_layers=24,
            num_heads=32,
            num_layers=24,
            output_past=True,
            pad_token_id=0,
            relative_attention_max_distance=128,
            relative_attention_num_buckets=32,
            tie_word_embeddings=False,
            tokenizer_class="T5Tokenizer",
            transformers_version="4.37.2",
            use_cache=True,
            vocab_size=250112,
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        prompt_emb = outputs.hidden_states[-clip_skip]
        if clip_skip > 1:
            # Same rescaling trick as the CLIP-style encoder above.
            mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        return HunyuanDiTT5TextEncoderStateDictConverter()
class HunyuanDiTCLIPTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Strip the "bert." prefix so the weights load directly into BertModel.
        state_dict_ = {name[len("bert."):]: param for name, param in state_dict.items() if name.startswith("bert.")}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class HunyuanDiTT5TextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Keep only the encoder weights, plus the shared token embedding table.
        state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
        state_dict_["shared.weight"] = state_dict["shared.weight"]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
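

# Usage sketch: a minimal, hedged example of running the CLIP-style encoder with
# randomly initialized weights and dummy token ids (both are illustration-only
# assumptions, not part of this module). In real use, converted weights would be
# loaded first, e.g.:
#   encoder.load_state_dict(encoder.state_dict_converter().from_diffusers(state_dict))
if __name__ == "__main__":
    encoder = HunyuanDiTCLIPTextEncoder()
    input_ids = torch.zeros((1, 77), dtype=torch.long)      # dummy 77-token prompt
    attention_mask = torch.ones((1, 77), dtype=torch.long)
    with torch.no_grad():
        # clip_skip=2 takes the second-to-last layer and exercises the rescaling branch.
        prompt_emb = encoder(input_ids, attention_mask, clip_skip=2)
    print(prompt_emb.shape)  # torch.Size([1, 77, 1024])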