# Source: levanter-backpack-1b-100k / backpack_config.py
# Uploaded by ivanzhouyq, commit de682e6 ("Upload model"); file size 668 bytes.
from transformers import GPT2Config
class BackpackGPT2Config(GPT2Config):
    """GPT-2 configuration extended with two Backpack-specific fields.

    The two extra values are recorded on the instance; every other option is
    forwarded to ``GPT2Config``, using the defaults declared in this
    signature.

    Args:
        num_senses: Stored as ``self.num_senses``. Presumably the number of
            sense vectors per token -- confirm against the model code.
        sense_intermediate_scale: Stored as
            ``self.sense_intermediate_scale``. Presumably a width multiplier
            for the sense network -- confirm against the model code.
        vocab_size: Forwarded to ``GPT2Config``.
        n_positions: Forwarded to ``GPT2Config``.
        scale_attn_by_inverse_layer_idx: Forwarded to ``GPT2Config``.
        **kwargs: Any remaining options, passed through to ``GPT2Config``
            unchanged.
    """

    model_type = "backpack-gpt2"

    def __init__(
        self,
        num_senses: int = 16,
        sense_intermediate_scale: int = 4,
        vocab_size: int = 50264,
        n_positions: int = 512,
        scale_attn_by_inverse_layer_idx: bool = True,
        **kwargs,
    ) -> None:
        # The Backpack-only fields are set before delegating to the parent
        # initializer; keep this ordering.
        self.num_senses = num_senses
        self.sense_intermediate_scale = sense_intermediate_scale

        # GPT-2 options whose defaults are declared by this class.
        gpt2_options = dict(
            vocab_size=vocab_size,
            n_positions=n_positions,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
        )
        super().__init__(**gpt2_options, **kwargs)