# babylm/LMConfig.py
# Source: Hugging Face Hub upload by gongjingyao ("Upload Transformer", commit f26cdb7, verified)
from typing import List, Optional

from transformers import PretrainedConfig
class LMConfig(PretrainedConfig):
    """Configuration class for the ``babylm`` Transformer model.

    Holds the architectural hyperparameters of the model and plugs into the
    Hugging Face ``PretrainedConfig`` machinery (serialization via
    ``save_pretrained`` / ``from_pretrained``, ``model_type`` registration).

    Args:
        dim: Hidden size of the transformer (embedding / residual width).
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads.
        n_kv_heads: Number of key/value heads; ``None`` presumably means
            "same as ``n_heads``" (i.e. no grouped-query attention) —
            resolved by the model code, not here.
        vocab_size: Size of the token vocabulary.
        hidden_dim: Feed-forward inner dimension; ``None`` lets the model
            derive it (typically from ``dim`` and ``multiple_of``).
        multiple_of: Round the derived feed-forward dimension up to a
            multiple of this value (hardware-friendly sizing).
        norm_eps: Epsilon used by the normalization layers.
        max_seq_len: Maximum sequence length the model supports.
        dropout: Dropout probability.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``bos_token_id``).
    """

    model_type = "babylm"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 8,
        n_kv_heads: Optional[int] = None,  # was annotated `int`; default is None
        vocab_size: int = 64000,
        hidden_dim: Optional[int] = None,  # was annotated `int`; default is None
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 512,
        dropout: float = 0.0,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        # HF convention: pass remaining kwargs to PretrainedConfig last so it
        # can consume generic config fields (token ids, etc.).
        super().__init__(**kwargs)