# MUDDPythia-2.8B / configuration_muddpythia.py
from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class MUDDPythiaConfig(PretrainedConfig):
    """
    MUDDPythiaConfig is the configuration class for MUDDPythia, adapted from
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
    """

    model_type = "muddpythia"

    def __init__(
        self,
        block_size: int = 2048,
        vocab_size: int = 32000,
        n_layer: int = 32,
        n_head: int = 32,
        dim: int = 2560,
        intermediate_size: Optional[int] = None,
        n_local_heads: int = -1,
        head_dim: int = 64,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
        is_training: bool = False,
        use_qk_norm: bool = False,
        pad_token_id: Optional[int] = None,
        use_parallel_residual: bool = True,
        use_linear_bias: bool = True,
        rotary_pct: float = 0.25,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        use_layer_cache: bool = True,
        stack_hidden: bool = False,
        dense: bool = True,
        dynamic_dense: bool = True,
        sepln: bool = True,
        dense_type: str = 'qkvr',
        expand_last: bool = False,
        round64: bool = False,
        **kwargs,
    ):
        # Model geometry.
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dim = dim
        self.intermediate_size = intermediate_size
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        # Attention / training options.
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.is_training = is_training
        self.use_qk_norm = use_qk_norm
        self.use_parallel_residual = use_parallel_residual
        self.use_linear_bias = use_linear_bias
        self.rotary_pct = rotary_pct
        self.use_layer_cache = use_layer_cache
        self.stack_hidden = stack_hidden
        # Dense / dynamic dense connection options.
        self.dense = dense
        self.dynamic_dense = dynamic_dense
        self.sepln = sepln
        self.dense_type = dense_type
        self.expand_last = expand_last
        self.round64 = round64
        # Post init: fill in derived defaults.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.dim
        self.head_dim = self.dim // self.n_head  # head_dim is always re-derived from dim and n_head
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
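

# A minimal usage sketch: instantiate the config with its class defaults and inspect
# the fields derived in post init. The printed values follow from the defaults above
# and are illustrative only; the released checkpoint ships its own config.json with
# the actual hyperparameters.
if __name__ == "__main__":
    config = MUDDPythiaConfig()
    # intermediate_size defaults to 4 * dim, head_dim is re-derived as dim // n_head,
    # and n_local_heads falls back to n_head when left at -1.
    print(config.intermediate_size)  # 10240 (= 4 * 2560)
    print(config.head_dim)           # 80    (= 2560 // 32)
    print(config.n_local_heads)      # 32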