Rocketknight1 HF staff commited on
Commit
b776073
1 Parent(s): 9bd693a

Upload HyenaDNAForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +22 -17
modeling_hyena.py CHANGED
@@ -19,8 +19,8 @@ def fftconv(u, k, D):
19
  seqlen = u.shape[-1]
20
  fft_size = 2 * seqlen
21
 
22
- k_f = torch.fft.rfft(k, n=fft_size) / fft_size
23
- u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
24
 
25
  if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
  y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
@@ -60,11 +60,9 @@ class HyenaPositionalEmbedding(nn.Module):
60
  w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
 
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
- # Matt: This is just Euler's formula, so if complex64 is a problem it can be replaced
64
- # by separate sin() and cos() calls.
65
- z = torch.exp(-1j * f * w)
66
- z = torch.cat([t, z.real, z.imag], dim=-1)
67
- # TODO Set z's LR to lr_pos_emb
68
  self.z = nn.Parameter(z, requires_grad=True)
69
  self.register_buffer("t", t)
70
 
@@ -147,7 +145,7 @@ class HyenaFilter(nn.Module):
147
 
148
  def filter(self, L, *args, **kwargs):
149
  z, t = self.pos_emb(L)
150
- h = self.implicit_filter(z)
151
  h = self.modulation(t, h)
152
  return h
153
 
@@ -349,8 +347,15 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
349
  supports_gradient_checkpointing = True
350
  _no_split_modules = ["HyenaBlock"]
351
  _skip_keys_device_placement = "past_key_values"
352
-
353
- def _init_weights(self, initializer_range=0.02):
 
 
 
 
 
 
 
354
  # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
355
  # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
356
  # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
@@ -368,8 +373,8 @@ class HyenaDNAPreTrainedModel(PreTrainedModel):
368
 
369
 
370
  class HyenaDNAModel(HyenaDNAPreTrainedModel):
371
- def __init__(self, config) -> None:
372
- super().__init__(config)
373
 
374
  self.backbone = HyenaLMBackbone(config)
375
  self.config = config
@@ -395,8 +400,8 @@ class HyenaDNAModel(HyenaDNAPreTrainedModel):
395
 
396
  class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
397
 
398
- def __init__(self, config):
399
- super().__init__(config)
400
  self.hyena = HyenaDNAModel(config)
401
  vocab_size = config.vocab_size
402
  if vocab_size % config.pad_vocab_size_multiple != 0:
@@ -476,9 +481,9 @@ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
476
 
477
 
478
  class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
479
- def __init__(self, config):
480
- super().__init__(config)
481
- self.num_labels = config.num_labels
482
  self.hyena = HyenaDNAModel(config)
483
  self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
484
 
 
19
  seqlen = u.shape[-1]
20
  fft_size = 2 * seqlen
21
 
22
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
24
 
25
  if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
  y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
 
60
  w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
 
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+
64
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
+ # The original code sets z's LR to lr_pos_emb, which is 1e-5 by default
 
 
66
  self.z = nn.Parameter(z, requires_grad=True)
67
  self.register_buffer("t", t)
68
 
 
145
 
146
  def filter(self, L, *args, **kwargs):
147
  z, t = self.pos_emb(L)
148
+ h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
149
  h = self.modulation(t, h)
150
  return h
151
 
 
347
  supports_gradient_checkpointing = True
348
  _no_split_modules = ["HyenaBlock"]
349
  _skip_keys_device_placement = "past_key_values"
350
+ _keys_to_ignore_on_load_missing = [r"freq"] # Shared tensors that safetensors merges
351
+
352
+ def _init_weights(self, module, initializer_range=0.02):
353
+ if isinstance(module, nn.Linear):
354
+ nn.init.normal_(module.weight, std=initializer_range)
355
+ if module.bias is not None:
356
+ nn.init.zeros_(module.bias)
357
+ elif isinstance(module, nn.Embedding):
358
+ nn.init.normal_(module.weight, std=initializer_range)
359
  # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
360
  # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
361
  # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 
373
 
374
 
375
  class HyenaDNAModel(HyenaDNAPreTrainedModel):
376
+ def __init__(self, config, **kwargs) -> None:
377
+ super().__init__(config, **kwargs)
378
 
379
  self.backbone = HyenaLMBackbone(config)
380
  self.config = config
 
400
 
401
  class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
402
 
403
+ def __init__(self, config, **kwargs):
404
+ super().__init__(config, **kwargs)
405
  self.hyena = HyenaDNAModel(config)
406
  vocab_size = config.vocab_size
407
  if vocab_size % config.pad_vocab_size_multiple != 0:
 
481
 
482
 
483
  class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
484
+ def __init__(self, config, **kwargs):
485
+ super().__init__(config, **kwargs)
486
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
487
  self.hyena = HyenaDNAModel(config)
488
  self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
489