Rocketknight1 HF staff committed on
Commit
6d3bcdb
1 Parent(s): dd11615

Upload HyenaDNAForCausalLM

Browse files
Files changed (4) hide show
  1. config.json +35 -0
  2. configuration_hyena.py +88 -0
  3. model.safetensors +3 -0
  4. modeling_hyena.py +574 -0
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "hyenadna-tiny-1k-seqlen-hf",
3
+ "activation_freq": 10,
4
+ "architectures": [
5
+ "HyenaDNAForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hyena.HyenaConfig",
9
+ "AutoModel": "modeling_hyena.HyenaDNAModel",
10
+ "AutoModelForCausalLM": "modeling_hyena.HyenaDNAForCausalLM",
11
+ "AutoModelForSequenceClassification": "modeling_hyena.HyenaDNAForSequenceClassification"
12
+ },
13
+ "d_inner": 512,
14
+ "d_model": 128,
15
+ "emb_dim": 5,
16
+ "embed_dropout": 0.1,
17
+ "filter_order": 64,
18
+ "hyena_dropout": 0.0,
19
+ "hyena_filter_dropout": 0.0,
20
+ "hyena_order": 2,
21
+ "initializer_range": 0.02,
22
+ "layer_norm_epsilon": 1e-05,
23
+ "max_seq_len": 1026,
24
+ "model_type": "hyenadna",
25
+ "n_layer": 2,
26
+ "num_inner_mlps": 2,
27
+ "pad_vocab_size_multiple": 8,
28
+ "short_filter_order": 3,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "float32",
31
+ "train_freq": true,
32
+ "transformers_version": "4.35.0.dev0",
33
+ "use_bias": true,
34
+ "vocab_size": 12
35
+ }
configuration_hyena.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import json
3
+
4
+
5
class HyenaConfig(PretrainedConfig):
    """Hyperparameter container for the HyenaDNA models.

    Every constructor argument is stored verbatim as an attribute of the same
    name, except ``d_inner`` which falls back to ``4 * d_model`` when it is
    ``None``. Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "hyenadna"

    def __init__(
        self,
        vocab_size=12,
        d_model=256,
        d_inner=None,
        use_bias=True,
        train_freq=True,
        max_seq_len=1024,
        emb_dim=3,
        n_layer=12,
        num_inner_mlps=2,
        hyena_order=2,
        short_filter_order=3,
        filter_order=64,
        activation_freq=1,
        embed_dropout=0.1,
        hyena_dropout=0.0,
        hyena_filter_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        pad_vocab_size_multiple=8,
        **kwargs,
    ):
        # Model width and depth.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_inner = 4 * d_model if d_inner is None else d_inner
        self.n_layer = n_layer
        self.max_seq_len = max_seq_len
        self.pad_vocab_size_multiple = pad_vocab_size_multiple

        # Hyena operator and implicit-filter settings.
        self.hyena_order = hyena_order
        self.filter_order = filter_order
        self.short_filter_order = short_filter_order
        self.emb_dim = emb_dim
        self.num_inner_mlps = num_inner_mlps
        self.activation_freq = activation_freq
        self.train_freq = train_freq
        self.use_bias = use_bias

        # Regularisation, normalisation and initialisation.
        self.embed_dropout = embed_dropout
        self.hyena_dropout = hyena_dropout
        self.hyena_filter_dropout = hyena_filter_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        # The base class consumes the generic Hugging Face options last.
        super().__init__(**kwargs)

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a ``HyenaConfig`` from a JSON config file in the original layout.

        The file is expected to hold top-level ``vocab_size``, ``d_model``,
        ``d_inner``, ``n_layer``, ``embed_dropout`` and
        ``pad_vocab_size_multiple`` entries plus a nested ``layer`` mapping
        with ``l_max``, ``emb_dim``, ``filter_order`` and ``w``. The short
        filter order is read from ``layer["local_order"]`` if present, else
        from ``layer["short_filter_order"]``, else defaults to 3.

        Args:
            config_path: Path of the JSON file to read.
            **kwargs: Extra keyword arguments passed through to the constructor.

        Returns:
            A new config instance with ``tie_word_embeddings=False``.

        Raises:
            KeyError: If a required entry is missing from the file.
        """
        with open(config_path, "r") as handle:
            original = json.load(handle)

        hparams = {
            "vocab_size": original["vocab_size"],
            "d_model": original["d_model"],
            "d_inner": original["d_inner"],
        }

        layer_cfg = original["layer"]
        hparams["max_seq_len"] = layer_cfg["l_max"]
        hparams["emb_dim"] = layer_cfg["emb_dim"]
        hparams["filter_order"] = layer_cfg["filter_order"]
        # "local_order" wins over "short_filter_order"; 3 is the fallback.
        hparams["short_filter_order"] = layer_cfg.get(
            "local_order", layer_cfg.get("short_filter_order", 3)
        )

        hparams["n_layer"] = original["n_layer"]
        hparams["activation_freq"] = layer_cfg["w"]
        hparams["embed_dropout"] = original["embed_dropout"]
        hparams["pad_vocab_size_multiple"] = original["pad_vocab_size_multiple"]

        return cls(**hparams, tie_word_embeddings=False, **kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ce2146c21e9c4baa6bddc4998fd3d029903ae84a563bf80218644082194a12d
3
+ size 1809192
modeling_hyena.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HyenaDNA custom code port to Hugging Face Hub"""
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from .configuration_hyena import HyenaConfig
9
+ from transformers import PreTrainedModel
10
+ from typing import Optional, Tuple, Union
11
+ from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention
12
+
13
+
14
def fftconv(u, k, D):
    """Convolve ``u`` with the filter ``k`` through the Fourier domain and add a skip term.

    Both operands are transformed in float32 with an FFT of twice the length
    of ``u``'s last dimension, multiplied, transformed back and truncated to
    the original length. ``u * D`` (``D`` broadcast over the last dimension)
    is then added, and the result is cast back to ``u``'s dtype.

    Args:
        u: Input tensor; the convolution runs over its last dimension.
        k: Filter tensor whose last dimension is the filter length.
        D: Per-channel scale for the skip connection; it gains a trailing
            singleton dimension before being multiplied with ``u``.

    Returns:
        Tensor shaped like ``u`` with ``u``'s dtype.
    """
    length = u.shape[-1]
    padded = 2 * length

    # The filter spectrum is pre-divided by the FFT size; the inverse
    # transform below uses norm="forward", i.e. applies no scaling itself.
    filter_spectrum = torch.fft.rfft(k.to(torch.float32), n=padded) / padded
    signal_spectrum = torch.fft.rfft(u.to(dtype=torch.float32), n=padded)

    if u.dim() > 3:
        filter_spectrum = filter_spectrum.unsqueeze(1)

    product = signal_spectrum * filter_spectrum
    y = torch.fft.irfft(product, n=padded, norm="forward")[..., :length]

    result = y + u * D.unsqueeze(-1)
    return result.to(dtype=u.dtype)
30
+
31
+
32
@torch.jit.script
def mul_sum(q, y):
    """Multiply ``q`` and ``y`` elementwise and sum the product over dimension 1."""
    product = q * y
    return product.sum(dim=1)
35
+
36
+
37
class HyenaSin(nn.Module):
    """The Sin activation function for the Hyena Filter function.

    Computes ``sin(freq * x)`` where ``freq`` has shape ``(1, filter_order)``
    and is initialised to ``config.activation_freq``.

    Args:
        config: Object providing ``activation_freq``, ``filter_order`` and
            ``train_freq``.
    """

    def __init__(self, config):
        super().__init__()
        freq = config.activation_freq * torch.ones(1, config.filter_order)
        if config.train_freq:
            self.freq = nn.Parameter(freq)
        else:
            # A plain tensor attribute would be left behind by .to()/.cuda().
            # A non-persistent buffer follows the module across devices and
            # dtypes while staying out of the state dict, so checkpoint keys
            # are unchanged.
            self.register_buffer("freq", freq, persistent=False)

    def forward(self, x):
        """Return ``sin(freq * x)``; ``freq`` broadcasts against ``x``."""
        return torch.sin(self.freq * x)
45
+
46
+
47
class HyenaPositionalEmbedding(nn.Module):
    """Complex exponential positional embeddings for Hyena filters.

    Builds, for ``config.max_seq_len`` positions, a feature tensor ``z`` of
    shape ``(1, L, 1 + 2 * bands)`` made of the normalised time ``t`` followed
    by ``cos`` and ``sin`` features, with ``bands = (emb_dim - 1) // 2``.
    ``z`` is a trainable parameter; ``t`` (shape ``(1, L, 1)``) is a buffer.

    Args:
        config: Object providing ``max_seq_len`` and ``emb_dim``.
    """

    def __init__(self, config):
        super().__init__()

        self.seq_len = config.max_seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        # Previously `bands` was only defined for emb_dim > 1, so smaller
        # values crashed with a NameError; clamp to zero bands instead.
        bands = max((config.emb_dim - 1) // 2, 0)

        if bands > 0:
            # To compute the right embeddings we use the "proper" linspace
            t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
            w = 2 * math.pi * t_rescaled / self.seq_len  # 1, L, 1

            f = torch.linspace(1e-4, bands - 1, bands)[None, None]

            z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
        else:
            # Time-only embedding; clone so the parameter does not share
            # storage with the `t` buffer.
            z = t.clone()

        # The original code sets z's LR to lr_pos_emb, which is 1e-5 by default
        self.z = nn.Parameter(z, requires_grad=True)
        self.register_buffer("t", t)

    def forward(self, L):
        """Return ``(z, t)`` truncated to the first ``L`` positions."""
        return self.z[:, :L], self.t[:, :L]
71
+
72
+
73
class HyenaExponentialModulation(nn.Module):
    """Exponential-decay window applied to the output of the filter MLP.

    A buffer ``deltas`` of shape ``(1, 1, d_model)`` holds decay rates spaced
    linearly between ``log(target) / slow_decay_pct`` and
    ``log(target) / fast_decay_pct``.

    Args:
        d_model: Number of decay rates (one per channel).
        fast_decay_pct: Divisor producing the largest-magnitude decay rate.
        slow_decay_pct: Divisor producing the smallest-magnitude decay rate.
        target: Value whose logarithm is divided by the two percentages.
        modulate: When False, ``forward`` returns its input untouched.
        shift: Constant added to the decay window before multiplying.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulate: bool = True,
        shift: float = 0.05,
        **kwargs
    ):
        super().__init__()
        self.modulate = modulate
        self.shift = shift

        log_target = math.log(target)
        max_decay = log_target / fast_decay_pct
        min_decay = log_target / slow_decay_pct
        rates = torch.linspace(min_decay, max_decay, d_model)
        self.register_buffer("deltas", rates[None, None])

    def forward(self, t, x):
        """Return ``x * (exp(-t * |deltas|) + shift)``, or ``x`` when modulation is off."""
        if not self.modulate:
            return x
        window = torch.exp(-t * self.deltas.abs())
        return x * (window + self.shift)
98
+
99
+
100
class HyenaFilter(nn.Module):
    """Implicit long filter with modulation.

    A small MLP with sine activations maps positional embeddings to filter
    values, which are then windowed by ``HyenaExponentialModulation``.

    The filter width is ``config.d_model * (config.hyena_order - 1)``. The
    bias, the last linear layer and the modulation all use that width, so
    ``HyenaOperator`` can split the result into ``hyena_order - 1`` filters.
    For ``hyena_order == 2`` this equals ``config.d_model``.

    Args:
        config: Object providing ``d_model``, ``hyena_order``, ``use_bias``,
            ``hyena_filter_dropout``, ``emb_dim`` (odd, >= 3), ``max_seq_len``,
            ``filter_order`` (width of the MLP) and ``num_inner_mlps`` (number
            of inner linear layers), plus whatever ``HyenaSin`` and
            ``HyenaPositionalEmbedding`` read.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``emb_dim`` is even or smaller than 3.

    Note:
        filter_dropout is not implemented: ``self.dropout`` is created but
        never applied.
    """

    def __init__(
        self,
        config,
        **kwargs
    ):
        super().__init__()

        self.d_model = config.d_model * (config.hyena_order - 1)
        # NOTE(review): `use_bias` is stored but never consulted in forward();
        # confirm whether the bias should be disabled when it is False.
        self.use_bias = config.use_bias
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(config.hyena_filter_dropout)

        # One activation instance is reused at every position in the MLP, so
        # all sine layers share the same `freq` tensor.
        act = HyenaSin(config)
        self.emb_dim = config.emb_dim
        if self.emb_dim % 2 == 0 or self.emb_dim < 3:
            raise ValueError("emb_dim must be odd and greater or equal to 3 (time, sine and cosine)")
        self.seq_len = config.max_seq_len

        self.pos_emb = HyenaPositionalEmbedding(config)

        self.implicit_filter = nn.Sequential(
            nn.Linear(self.emb_dim, config.filter_order),
            act,
        )
        for _ in range(config.num_inner_mlps):
            self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
            self.implicit_filter.append(act)

        # Output width must match `self.bias` (d_model * (order - 1)); using
        # plain d_model here broke hyena_order > 2.
        self.implicit_filter.append(nn.Linear(config.filter_order, self.d_model, bias=False))

        self.modulation = HyenaExponentialModulation(self.d_model)

        self.normalized = False

    def filter(self, L, *args, **kwargs):
        """Evaluate the modulated implicit filter at the first ``L`` positions."""
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        """Convolve ``x`` with the filter ``k`` via ``fftconv``.

        Args:
            x: Input passed to ``fftconv``.
            L: Filter length, used only when ``k`` is not supplied.
            k: Pre-computed filter (or a tuple whose first item is the
                filter). Computed with ``self.filter(L)`` when ``None``.
                NOTE(review): the internally computed filter is passed on
                unchanged; confirm its layout matches what ``fftconv``
                expects before relying on this default.
            bias: Skip-connection scale; defaults to ``self.bias``.

        Returns:
            The result of ``fftconv(x, k, bias)``.
        """
        if k is None:
            k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        if isinstance(k, tuple):
            k = k[0]

        # The default used to reach fftconv as None and crash there.
        if bias is None:
            bias = self.bias

        return fftconv(x, k, bias)
160
+
161
+
162
class HyenaOperator(nn.Module):
    r"""
    Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

    The input is projected to ``hyena_order + 1`` streams of width
    ``d_model``, passed through a depthwise short convolution, then combined
    by alternating elementwise gating and long convolutions with the implicit
    filters of ``HyenaFilter``.

    Args:
        config: Object providing ``d_model`` (width of the layer),
            ``max_seq_len`` (maximum filter length), ``hyena_order`` (depth
            of the recurrence), ``hyena_dropout`` and ``short_filter_order``
            (kernel size of the short convolution), plus whatever
            ``HyenaFilter`` reads.
        **filter_args: Accepted and ignored.
    """

    def __init__(
        self,
        config,
        **filter_args,
    ):
        super().__init__()

        self.d_model = config.d_model
        self.l_max = config.max_seq_len
        self.order = config.hyena_order
        inner_width = config.d_model * (self.order + 1)
        self.dropout = nn.Dropout(config.hyena_dropout)
        self.in_proj = nn.Linear(self.d_model, inner_width)
        self.out_proj = nn.Linear(self.d_model, self.d_model)

        self.short_filter = nn.Conv1d(
            inner_width,
            inner_width,
            config.short_filter_order,
            # Pad by kernel_size - 1 so that truncating the output to the
            # input length in forward() keeps the convolution causal for any
            # kernel size (the hard-coded 2 was only right for size 3).
            padding=config.short_filter_order - 1,
            groups=inner_width
        )
        self.filter_fn = HyenaFilter(config)

    def forward(self, u):
        """Apply the operator to ``u``.

        Args:
            u: Input whose second-to-last dimension is the sequence length
                and whose last dimension is ``d_model`` (it is fed to
                ``in_proj`` and then transposed on dims 1 and 2).

        Returns:
            Output of ``out_proj``; the sequence is cut to at most
            ``max_seq_len`` positions.
        """
        seq_len = u.size(-2)
        l_filter = min(seq_len, self.l_max)
        u = self.in_proj(u).transpose(1, 2)

        # Drop the trailing positions introduced by the convolution padding.
        uc = self.short_filter(u)[..., :l_filter]
        *x, v = uc.split(self.d_model, dim=1)

        # Split the flat filter/bias into one slice per recurrence step.
        k = self.filter_fn.filter(l_filter)[0]
        k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
        bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)

        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        y = (v * x[0]).transpose(1, 2)

        y = self.out_proj(y)
        return y
217
+
218
class HyenaMlp(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU (tanh approximation) -> Linear.

    From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py

    Args:
        config: Object providing ``d_model`` (input and output width) and
            ``d_inner`` (hidden width).
    """

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_inner)
        self.fc2 = nn.Linear(config.d_inner, config.d_model)

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        hidden = F.gelu(self.fc1(x), approximate="tanh")
        return self.fc2(hidden)
235
+
236
class HyenaBlock(nn.Module):
    """Pre-norm residual block: LN -> HyenaOperator -> Add -> LN -> MLP -> Add.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
    [Ref: https://arxiv.org/abs/2002.04745]. Unlike that implementation, this
    block takes and returns only the hidden states and applies no dropout.

    Args:
        config: Object providing ``d_model`` and ``layer_norm_epsilon``, plus
            whatever ``HyenaOperator`` and ``HyenaMlp`` read.
    """

    def __init__(self, config):
        super().__init__()
        self.mixer = HyenaOperator(config)
        # Use the configured epsilon, consistent with the backbone's final
        # LayerNorm; previously the configured value was ignored here.
        self.norm1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.mlp = HyenaMlp(config)
        self.norm2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        r"""Pass the input through the block.

        Args:
            hidden_states: the input sequence (required).

        Returns:
            ``mlp(norm2(r)) + r`` where ``r = mixer(norm1(x)) + x`` and the
            residual stream ``x``/``r`` is held in float32.
        """
        # The residual stream is kept in float32; inputs to each LayerNorm
        # are cast to that norm's parameter dtype.
        residual = hidden_states
        residual = residual.to(torch.float32)
        hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        hidden_states = self.mixer(hyena_normed)

        residual = hidden_states + residual
        hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
        residual = residual.to(torch.float32)

        hidden_states = self.mlp(hidden_states)
        return hidden_states + residual
281
+
282
+
283
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
284
+
285
+
286
class HyenaEmbeddings(nn.Module):
    """Token embedding table whose size is rounded up to a multiple.

    The table has ``config.vocab_size`` rows rounded up to the next multiple
    of ``config.pad_vocab_size_multiple`` and ``config.d_model`` columns.

    Args:
        config: Object providing ``vocab_size``, ``pad_vocab_size_multiple``
            and ``d_model``.
        padding_idx: Forwarded to ``nn.Embedding``.
    """

    def __init__(self, config, padding_idx=None):
        super().__init__()
        multiple = config.pad_vocab_size_multiple
        padded_vocab = config.vocab_size
        remainder = padded_vocab % multiple
        if remainder != 0:
            padded_vocab += multiple - remainder
        self.word_embeddings = nn.Embedding(padded_vocab, config.d_model, padding_idx=padding_idx)

    def forward(self, input_ids):
        """Look up the embeddings of ``input_ids``.

        input_ids: (batch, seqlen)
        """
        return self.word_embeddings(input_ids)
306
+
307
class HyenaLMBackbone(nn.Module):
    """Embedding layer, a stack of ``HyenaBlock`` layers and a final LayerNorm.

    Args:
        config: Object providing ``embed_dropout``, ``n_layer``, ``d_model``
            and ``layer_norm_epsilon``, plus whatever ``HyenaEmbeddings`` and
            ``HyenaBlock`` read.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = HyenaEmbeddings(config)
        # NOTE(review): this dropout module is created but forward() never
        # applies it.
        self.dropout = nn.Dropout(config.embed_dropout)

        blocks = [HyenaBlock(config) for _ in range(config.n_layer)]
        self.layers = nn.ModuleList(blocks)

        self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        """Run the backbone.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is None.
            inputs_embeds: Optional pre-computed embeddings that replace the
                embedding lookup.
            output_hidden_states: When truthy, collect the embeddings, each
                layer's output and the final normalised output.

        Returns:
            A ``(hidden_states, all_hidden_states)`` pair; the second item is
            a list that stays empty unless ``output_hidden_states`` is set.
        """
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        all_hidden_states = [hidden_states] if output_hidden_states else []

        for block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states)
            else:
                hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states.append(hidden_states)

        normed_input = hidden_states.to(dtype=self.ln_f.weight.dtype)
        hidden_states = self.ln_f(normed_input)
        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        return hidden_states, all_hidden_states
342
+
343
+
344
class HyenaDNAPreTrainedModel(PreTrainedModel):
    """Base class wiring the HyenaDNA models into the ``PreTrainedModel`` machinery.

    It declares the config class, the base-model prefix and loading hints,
    and implements weight initialisation.
    """

    config_class = HyenaConfig
    base_model_prefix = "hyena"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HyenaBlock"]
    _skip_keys_device_placement = "past_key_values"
    _keys_to_ignore_on_load_missing = [r"freq"]  # Shared tensors that safetensors merges

    def _is_residual_projection(self, module):
        """Return True when ``module`` is registered in this model as ``out_proj`` or ``fc2``."""
        for name, candidate in self.named_modules():
            if candidate is module:
                return name.rsplit(".", 1)[-1] in ("out_proj", "fc2")
        return False

    def _init_weights(self, module, initializer_range=0.02):
        """Initialise the weights of a single ``module``.

        Linear and embedding weights are drawn from ``N(0, initializer_range)``
        and linear biases are zeroed. Linear layers registered as ``out_proj``
        or ``fc2`` use the GPT-2 scheme instead:

        > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        >   -- GPT-2 :: https://openai.com/blog/better-language-models/

        Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py

        Only parameters owned by ``module`` are touched. The previous version
        looped over every parameter of the whole model for every module and
        compared fully-qualified names with bare ``"out_proj.weight"`` /
        ``"fc2.weight"``, so the scaled init never ran (and would have read
        the non-existent ``config.num_layers``).

        Args:
            module: The sub-module to initialise.
            initializer_range: Standard deviation of the normal initialiser.
        """
        if isinstance(module, nn.Linear):
            std = initializer_range
            if self._is_residual_projection(module):
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                std = initializer_range / math.sqrt(2 * self.config.n_layer)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)
373
+
374
+
375
class HyenaDNAModel(HyenaDNAPreTrainedModel):
    """Bare HyenaDNA backbone returning hidden states without any head."""

    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.backbone = HyenaLMBackbone(config)
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids=None, inputs_embeds=None, output_hidden_states=None, return_dict=None):
        """Run the backbone.

        Args:
            input_ids: Token ids. Optional when ``inputs_embeds`` is given
                (it used to be a required positional argument, which made
                embeddings-only calls impossible).
            inputs_embeds: Optional embeddings that replace the lookup.
            output_hidden_states: Whether to also return the per-layer hidden
                states; defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``BaseModelOutputWithNoAttention``;
                defaults to ``config.use_return_dict``.

        Returns:
            With ``return_dict``: a ``BaseModelOutputWithNoAttention``.
            Otherwise ``(hidden_states, all_hidden_states)`` when hidden
            states are requested, else the bare ``hidden_states`` tensor
            (not wrapped in a tuple).

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states, all_hidden_states = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
        )
        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states if output_hidden_states else None,
            )
        elif output_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states
399
+
400
+
401
class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
    """HyenaDNA backbone with a linear language-modelling head.

    The head emits ``config.vocab_size`` rounded up to a multiple of
    ``config.pad_vocab_size_multiple`` logits per position; that padded size
    is stored in ``self.vocab_size``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.hyena = HyenaDNAModel(config)
        vocab_size = config.vocab_size
        if vocab_size % config.pad_vocab_size_multiple != 0:
            vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.hyena.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.hyena.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.hyena = decoder

    def get_decoder(self):
        return self.hyena

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids fed to the backbone.
            inputs_embeds: Optional embeddings used instead of ``input_ids``.
            labels: Optional target ids; logits at position ``i`` are scored
                against ``labels`` at position ``i + 1`` with cross-entropy.
            output_hidden_states: Whether to return the per-layer hidden
                states; defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``CausalLMOutput``; defaults to
                ``config.use_return_dict``.

        Returns:
            A ``CausalLMOutput``, or the tuple
            ``([loss,] logits[, hidden_states])`` when ``return_dict`` is False.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a structured output. With
        # return_dict=False it returns a bare tensor, so indexing with [0]
        # would select the first batch element instead of the hidden states.
        outputs = self.hyena(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. The class dimension is the head's padded
            # vocabulary, which may be larger than config.vocab_size.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,)
            if outputs.hidden_states is not None:
                output = output + (outputs.hidden_states,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
481
+
482
+
483
class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
    """HyenaDNA backbone with a linear classification/regression head.

    The head scores every position; the logits of one position per sequence
    are then selected (see ``forward``).
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.hyena = HyenaDNAModel(config)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.hyena.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.hyena.backbone.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        The pooled position is the last one, unless `config.pad_token_id` is set and `input_ids` is given, in which
        case it is the position just before the first pad token of each row (index `-1`, i.e. the last position, when
        the first token is a pad or the row has no pad).

        Returns a `SequenceClassifierOutput`, or the tuple `([loss,] pooled_logits[, hidden_states])` when
        `return_dict` is False.

        Raises `ValueError` when the batch size is not 1 and no padding token is defined.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a structured output. With
        # return_dict=False it returns a bare tensor, so indexing with [0]
        # would select the first batch element instead of the hidden states.
        transformer_outputs = self.hyena(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        # Default: pool the last position of every sequence.
        sequence_lengths = -1
        if self.config.pad_token_id is not None and input_ids is not None:
            sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                logits.device
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,)
            if transformer_outputs.hidden_states is not None:
                output = output + (transformer_outputs.hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
        )