from typing import Optional

from transformers import PretrainedConfig


class GraniteConfig(PretrainedConfig):
    """Configuration class storing the architecture hyperparameters of a Granite model."""

    model_type = "granite"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: int = 50257,
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        num_key_value_heads: Optional[int] = None,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_pytorch_tanh",
        attention_head_type: str = "mqa",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        normalization_function: str = "layernorm",
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        attention_multiplier: Optional[float] = None,
        use_cache: bool = True,
        bos_token_id: int = 50256,
        eos_token_id: int = 50256,
        pad_token_id: int = 50256,
        attention_softmax_in_fp32: bool = True,
        scale_attention_softmax_in_fp32: bool = True,
        add_bias: bool = True,
        position_embedding_type: str = "learned_absolute",
        rope_theta: int = 10000,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.num_key_value_heads = num_key_value_heads
        # Default the MLP inner dimension to 4x the hidden size, as in GPT-2.
        self.n_inner = 4 * n_embd if n_inner is None else n_inner
        self.activation_function = activation_function
        self.attention_head_type = attention_head_type
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.normalization_function = normalization_function
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.attention_multiplier = attention_multiplier
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.position_embedding_type = position_embedding_type
        self.add_bias = add_bias
        self.rope_theta = rope_theta

        # A custom attention multiplier only makes sense when attention weights are scaled.
        if self.attention_multiplier is not None:
            assert self.scale_attn_weights

        # for compatibility with some features
        self.multi_query = attention_head_type == "mqa"

        # Validate the (attention_head_type, num_key_value_heads) combination and fill in defaults.
        if attention_head_type == "mha":
            if self.num_key_value_heads is None:
                self.num_key_value_heads = self.n_head

            assert (
                self.n_head == self.num_key_value_heads
            ), "MultiHeadAttention should have same number of heads for query, keys and values"
        elif attention_head_type == "mqa":
            if self.num_key_value_heads is None:
                self.num_key_value_heads = 1

            assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values"
        elif attention_head_type == "gqa":
            assert (
                self.num_key_value_heads is not None
            ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
            assert (
                self.n_head % self.num_key_value_heads == 0
            ), "GroupedQueryAttention needs `n_head` to be divisible by `num_key_value_heads`"
        else:
            raise ValueError(f"unexpected attention_head_type ({attention_head_type})")

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
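

# Minimal usage sketch (illustrative, not part of the library API): constructing a
# GraniteConfig with grouped-query attention and inspecting the attributes derived in
# __init__ above. The hyperparameter values below are arbitrary examples.
if __name__ == "__main__":
    config = GraniteConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        attention_head_type="gqa",
        num_key_value_heads=4,
    )

    # `attribute_map` exposes the GPT-2-style field names under the standard HF aliases.
    print(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)  # 1024 16 24
    print(config.num_key_value_heads)  # 4
    print(config.multi_query)  # False (only attention_head_type="mqa" sets this flag)
    print(config.n_inner)  # 4096 (defaults to 4 * n_embd)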