|
from transformers import PretrainedConfig
|
|
from typing import List
|
|
|
|
|
|
class LMConfig(PretrainedConfig):
    """Configuration object registered under the ``"babylm"`` model type.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are
    handed to ``PretrainedConfig.__init__``.
    """

    model_type = "babylm"

    def __init__(
        self,
        dim: int = 512,
        n_layers: int = 8,
        n_heads: int = 8,
        n_kv_heads: int = None,
        vocab_size: int = 64000,
        hidden_dim: int = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 512,
        dropout: float = 0.0,
        **kwargs,
    ):
        """Record the hyperparameters, then initialise the base config.

        NOTE(review): the meanings below are inferred from the parameter
        names only; confirm them against the model code that reads this
        config.

        Args:
            dim: Presumably the model/embedding width.
            n_layers: Presumably the number of transformer blocks.
            n_heads: Presumably the number of attention heads.
            n_kv_heads: Presumably the number of key/value heads. ``None``
                is stored as-is; no fallback is resolved here.
            vocab_size: Presumably the tokenizer vocabulary size.
            hidden_dim: Presumably the feed-forward width. ``None`` is
                stored as-is; no fallback is resolved here.
            multiple_of: Presumably a rounding granularity used when a
                hidden size is derived elsewhere.
            norm_eps: Presumably the epsilon of the normalisation layers.
            max_seq_len: Presumably the maximum sequence length.
            dropout: Presumably the dropout probability.
            **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
        """
        # Dicts preserve insertion order, so attributes are assigned in the
        # same sequence as the parameter list.
        own_fields = {
            "dim": dim,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "n_kv_heads": n_kv_heads,
            "vocab_size": vocab_size,
            "hidden_dim": hidden_dim,
            "multiple_of": multiple_of,
            "norm_eps": norm_eps,
            "max_seq_len": max_seq_len,
            "dropout": dropout,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        # The base initialiser runs last, after the model-specific
        # attributes are already in place.
        super().__init__(**kwargs)
|
|
|