import json

from transformers import PretrainedConfig


class HyenaConfig(PretrainedConfig):
    """Hugging Face configuration class for HyenaDNA models."""

    model_type = "hyenadna"

    def __init__(
        self,
        vocab_size=12,
        d_model=256,
        d_inner=None,
        use_bias=True,
        train_freq=True,
        max_seq_len=1024,
        emb_dim=3,
        n_layer=12,
        num_inner_mlps=2,
        hyena_order=2,
        short_filter_order=3,
        filter_order=64,
        activation_freq=1,
        embed_dropout=0.1,
        hyena_dropout=0.0,
        hyena_filter_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        pad_vocab_size_multiple=8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Default the inner MLP width to 4x the model width.
        if d_inner is None:
            self.d_inner = 4 * d_model
        else:
            self.d_inner = d_inner
        self.use_bias = use_bias
        self.train_freq = train_freq
        self.max_seq_len = max_seq_len
        self.emb_dim = emb_dim
        self.n_layer = n_layer
        self.hyena_order = hyena_order
        self.filter_order = filter_order
        self.short_filter_order = short_filter_order
        self.activation_freq = activation_freq
        self.num_inner_mlps = num_inner_mlps
        self.embed_dropout = embed_dropout
        self.hyena_dropout = hyena_dropout
        self.hyena_filter_dropout = hyena_filter_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        super().__init__(**kwargs)

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a HyenaConfig from an original (pre-Hugging Face) JSON config."""
        with open(config_path, "r") as f:
            config = json.load(f)

        vocab_size = config["vocab_size"]
        d_model = config["d_model"]
        d_inner = config["d_inner"]
        max_seq_len = config["layer"]["l_max"]
        emb_dim = config["layer"]["emb_dim"]
        filter_order = config["layer"]["filter_order"]
        # Some original configs name the short convolution order "local_order",
        # others "short_filter_order"; fall back to the default of 3 otherwise.
        if "local_order" in config["layer"]:
            short_filter_order = config["layer"]["local_order"]
        elif "short_filter_order" in config["layer"]:
            short_filter_order = config["layer"]["short_filter_order"]
        else:
            short_filter_order = 3
        n_layer = config["n_layer"]
        activation_freq = config["layer"]["w"]
        embed_dropout = config["embed_dropout"]
        pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
        return cls(
            vocab_size=vocab_size,
            d_model=d_model,
            d_inner=d_inner,
            max_seq_len=max_seq_len,
            emb_dim=emb_dim,
            filter_order=filter_order,
            short_filter_order=short_filter_order,
            n_layer=n_layer,
            activation_freq=activation_freq,
            embed_dropout=embed_dropout,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
            tie_word_embeddings=False,
            **kwargs,
        )
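
# A minimal usage sketch (assumption: not part of the original module).
# "original_hyena_config.json" is a hypothetical path to an original HyenaDNA
# JSON config with the nested "layer" keys read by from_original_config above;
# all keyword arguments shown are parameters of HyenaConfig itself.
if __name__ == "__main__":
    # Build a config directly, overriding a few defaults.
    config = HyenaConfig(d_model=128, n_layer=4, max_seq_len=32768)
    assert config.d_inner == 4 * 128  # d_inner defaults to 4 * d_model

    # Or convert an original JSON config; extra kwargs (here, a dropout
    # override) are forwarded to the constructor.
    config = HyenaConfig.from_original_config(
        "original_hyena_config.json",
        hyena_dropout=0.1,
    )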