|
from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class MUDDPythiaConfig(PretrainedConfig):
    '''
    MUDDPythiaConfig is the configuration class for MUDDPythia, adapted from
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L21
    '''

    model_type = "muddpythia"
|
    def __init__(
        self,
        block_size: int = 2048,
        vocab_size: int = 32000,
        n_layer: int = 32,
        n_head: int = 32,
        dim: int = 2560,
        intermediate_size: Optional[int] = None,
        n_local_heads: int = -1,
        head_dim: int = 64,
        rope_base: float = 10000,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
        is_training: bool = False,
        use_qk_norm: bool = False,
        pad_token_id: Optional[int] = None,
        use_parallel_residual: bool = True,
        use_linear_bias: bool = True,
        rotary_pct: float = 0.25,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        use_layer_cache: bool = True,
        stack_hidden: bool = False,
        dense: bool = True,
        dynamic_dense: bool = True,
        sepln: bool = True,
        dense_type: str = 'qkvr',
        expand_last: bool = False,
        round64: bool = False,
        **kwargs,
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dim = dim
        self.intermediate_size = intermediate_size
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.rope_base = rope_base
        self.norm_eps = norm_eps
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.is_training = is_training
        self.use_qk_norm = use_qk_norm
        self.use_parallel_residual = use_parallel_residual
        self.use_linear_bias = use_linear_bias
        self.rotary_pct = rotary_pct

        self.use_layer_cache = use_layer_cache
        self.stack_hidden = stack_hidden
        self.dense = dense
        self.dynamic_dense = dynamic_dense
        self.sepln = sepln
        self.dense_type = dense_type
        self.expand_last = expand_last
        self.round64 = round64
|
        # Derive dependent values that were not set explicitly.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.dim
        # head_dim is always recomputed from the model width and the head count.
        self.head_dim = self.dim // self.n_head
|
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
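

# Minimal usage sketch (illustrative only): the overrides below are example
# values, not the released MUDDPythia settings. Derived fields such as
# head_dim and intermediate_size are filled in by __init__ when not given.
if __name__ == "__main__":
    config = MUDDPythiaConfig(
        vocab_size=50304,  # Pythia-style vocabulary size, assumed for illustration
        n_layer=24,
        n_head=16,
        dim=2048,
        dense_type='qkvr',
    )
    print(config.head_dim)           # 2048 // 16 = 128
    print(config.intermediate_size)  # 4 * 2048 = 8192
    # The config serializes like any PretrainedConfig, e.g.:
    # config.save_pretrained("./muddpythia-config")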
|
|