from transformers import PretrainedConfig
import json


class StripedHyenaConfig(PretrainedConfig):
    """Configuration class for StripedHyena models, which interleave attention and Hyena layers."""

    model_type = "stripedhyena"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_filters=4096,
        inner_mlp_size=14336,
        attn_layer_idxs=[],
        hyena_layer_idxs=[],
        num_layers=32,
        tie_embeddings=False,
        short_filter_length=3,
        num_attention_heads=32,
        proj_groups=4,
        hyena_filter_groups=1,
        split_k0=True,
        column_split_hyena=True,
        column_split=False,
        model_parallel_size=1,
        pipe_parallel_size=1,
        short_filter_bias=True,
        mha_out_proj_bias=False,
        qkv_proj_bias=False,
        final_norm=True,
        use_cache=True,
        use_flash_attention_2=True,
        use_flash_rmsnorm=True,
        use_flash_depthwise=False,
        use_flashfft=False,
        inference_mode=False,
        prefill_style="fft",
        max_seqlen=32768,
        eps=1e-5,
        state_size=2,
        rotary_emb_base=500000,
        smeared_gqa=False,
        make_vocab_size_divisible_by=8,
        log_intermediate_values=False,
        **kwargs,
    ):
        # Model dimensions and layer layout.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.inner_mlp_size = inner_mlp_size
        self.attn_layer_idxs = attn_layer_idxs
        self.hyena_layer_idxs = hyena_layer_idxs
        self.num_layers = num_layers
        self.tie_embeddings = tie_embeddings
        # Attention and Hyena block hyperparameters.
        self.short_filter_length = short_filter_length
        self.num_attention_heads = num_attention_heads
        self.proj_groups = proj_groups
        self.hyena_filter_groups = hyena_filter_groups
        self.split_k0 = split_k0
        self.column_split_hyena = column_split_hyena
        self.column_split = column_split
        # Parallelism settings.
        self.model_parallel_size = model_parallel_size
        self.pipe_parallel_size = pipe_parallel_size
        # Bias and normalization options.
        self.short_filter_bias = short_filter_bias
        self.mha_out_proj_bias = mha_out_proj_bias
        self.qkv_proj_bias = qkv_proj_bias
        self.final_norm = final_norm
        # Inference, caching, and fused-kernel options.
        self.use_cache = use_cache
        self.use_flash_attention_2 = use_flash_attention_2
        self.use_flash_rmsnorm = use_flash_rmsnorm
        self.use_flash_depthwise = use_flash_depthwise
        self.use_flashfft = use_flashfft
        self.inference_mode = inference_mode
        self.prefill_style = prefill_style
        self.max_seqlen = max_seqlen
        self.eps = eps
        self.state_size = state_size
        self.rotary_emb_base = rotary_emb_base
        self.smeared_gqa = smeared_gqa
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
        self.log_intermediate_values = log_intermediate_values
        super().__init__(**kwargs)

    def to_dict(self):
        # Serialize every attribute set on the instance, including those added by PretrainedConfig.
        return {attr: getattr(self, attr) for attr in self.__dict__}

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        # Build a config from a JSON file in the original (non-Hugging Face) StripedHyena format.
        with open(config_path, "r") as f:
            config = json.load(f)
        return cls(**config, **kwargs)
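

# Usage sketch (illustrative, not part of the original file): constructing the config
# directly and loading one from a serialized original-format JSON. The layer indices
# and the "original_config.json" path below are hypothetical examples.
if __name__ == "__main__":
    cfg = StripedHyenaConfig(
        num_layers=4,
        attn_layer_idxs=[1, 3],
        hyena_layer_idxs=[0, 2],
    )
    print(cfg.to_dict()["num_layers"])  # -> 4

    # Assuming an original-format JSON file is available:
    # cfg = StripedHyenaConfig.from_original_config("original_config.json")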