|
from transformers import PretrainedConfig |
|
import json |
|
|
|
|
|
class StripedHyenaConfig(PretrainedConfig):
    """Configuration container for the ``stripedhyena`` model type.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "stripedhyena"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_filters=4096,
        inner_mlp_size=14336,
        attn_layer_idxs=None,
        hyena_layer_idxs=None,
        num_layers=32,
        tie_embeddings=False,
        short_filter_length=3,
        num_attention_heads=32,
        proj_groups=4,
        hyena_filter_groups=1,
        split_k0=True,
        column_split_hyena=True,
        column_split=False,
        model_parallel_size=1,
        pipe_parallel_size=1,
        short_filter_bias=True,
        mha_out_proj_bias=False,
        qkv_proj_bias=False,
        final_norm=True,
        use_cache=True,
        use_flash_attention_2=True,
        use_flash_rmsnorm=True,
        use_flash_depthwise=False,
        use_flashfft=False,
        inference_mode=False,
        prefill_style="fft",
        max_seqlen=32768,
        eps=1e-5,
        state_size=2,
        rotary_emb_base=500000,
        smeared_gqa=False,
        make_vocab_size_divisible_by=8,
        log_intermediate_values=False,
        **kwargs,
    ):
        """Store all hyperparameters on the instance.

        Args:
            attn_layer_idxs: List stored as ``self.attn_layer_idxs``.
                ``None`` (the default) yields a fresh empty list.
            hyena_layer_idxs: List stored as ``self.hyena_layer_idxs``.
                ``None`` (the default) yields a fresh empty list.
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` after the
                attributes below have been assigned.

        All other parameters are assigned verbatim to the attribute of the
        same name.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.inner_mlp_size = inner_mlp_size
        # A literal ``[]`` default would be shared by every instance created
        # without these arguments, so mutating one config's list would leak
        # into all others. Build a new list per instance instead.
        self.attn_layer_idxs = [] if attn_layer_idxs is None else attn_layer_idxs
        self.hyena_layer_idxs = [] if hyena_layer_idxs is None else hyena_layer_idxs
        self.num_layers = num_layers
        self.tie_embeddings = tie_embeddings
        self.short_filter_length = short_filter_length
        self.num_attention_heads = num_attention_heads
        self.proj_groups = proj_groups
        self.hyena_filter_groups = hyena_filter_groups
        self.split_k0 = split_k0
        self.column_split_hyena = column_split_hyena
        self.column_split = column_split
        self.model_parallel_size = model_parallel_size
        self.pipe_parallel_size = pipe_parallel_size
        self.short_filter_bias = short_filter_bias
        self.mha_out_proj_bias = mha_out_proj_bias
        self.qkv_proj_bias = qkv_proj_bias
        self.final_norm = final_norm
        self.use_cache = use_cache
        self.use_flash_attention_2 = use_flash_attention_2
        self.use_flash_rmsnorm = use_flash_rmsnorm
        self.use_flash_depthwise = use_flash_depthwise
        self.use_flashfft = use_flashfft
        self.inference_mode = inference_mode
        self.prefill_style = prefill_style
        self.max_seqlen = max_seqlen
        self.eps = eps
        self.state_size = state_size
        self.rotary_emb_base = rotary_emb_base
        self.smeared_gqa = smeared_gqa
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
        self.log_intermediate_values = log_intermediate_values
        super().__init__(**kwargs)

    def to_dict(self):
        """Return a dict mapping each instance attribute name to its value.

        Only names present in ``self.__dict__`` are included, so class-level
        attributes such as ``model_type`` do not appear unless they have also
        been set on the instance. Values are not copied.
        """
        return {attr: getattr(self, attr) for attr in self.__dict__}

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a config from a JSON file.

        Args:
            config_path: Path to a JSON file whose top-level object holds
                keyword arguments for the constructor.
            **kwargs: Extra constructor arguments. When a key is present in
                both the file and ``kwargs``, the ``kwargs`` value is used.

        Returns:
            A new instance of ``cls``.
        """
        # Decode as UTF-8 explicitly rather than relying on the platform's
        # locale-dependent default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Merge into one dict first: ``cls(**config, **kwargs)`` raises
        # TypeError on a duplicate key instead of applying the override.
        return cls(**{**config, **kwargs})
|
|