from transformers import PretrainedConfig


class GEBConfig(PretrainedConfig):
    model_type = "geblm"

    def __init__(
        self,
        num_layers=24,
        padded_vocab_size=64896,
        hidden_size=2048,
        ffn_hidden_size=5632,
        kv_channels=128,
        num_attention_heads=16,
        torch_dtype="bfloat16",
        seq_length=4096,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        max_position_embeddings=4096,
        bias_dropout_fusion=True,
        use_cache=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        use_flash_attn=True,
        num_key_value_heads=4,
        apply_query_key_layer_scaling=False,
        attention_softmax_in_fp32=False,
        fp32_residual_connection=False,
        pre_seq_len=None,
        prefix_projection=False,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Model topology.
        self.num_layers = num_layers
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads

        # Dtype and sequence length.
        self.torch_dtype = torch_dtype
        self.seq_length = seq_length

        # Regularization and normalization.
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.max_position_embeddings = max_position_embeddings

        # Attention / numerics options.
        self.bias_dropout_fusion = bias_dropout_fusion
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.use_flash_attn = use_flash_attn
        self.num_key_value_heads = num_key_value_heads
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection

        # Prefix-tuning options.
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection

        # Pass tie_word_embeddings to the base class so it is not overwritten
        # by PretrainedConfig's default when super().__init__ runs.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
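
# A minimal usage sketch (not part of the config class itself; it only assumes
# the `transformers` dependency imported above): instantiate GEBConfig with an
# overridden field and round-trip it through the standard PretrainedConfig
# dict serialization, which is what save_pretrained()/from_pretrained() use.
if __name__ == "__main__":
    config = GEBConfig(seq_length=2048)

    # Overridden and default values are plain attributes on the config.
    print(config.seq_length, config.num_attention_heads, config.torch_dtype)

    # Round-trip through to_dict()/from_dict() provided by PretrainedConfig.
    restored = GEBConfig.from_dict(config.to_dict())
    assert restored.seq_length == 2048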