Matt commited on
Commit
e8c1eff
1 Parent(s): 27cdeb1

Correctly mark z as a buffer

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +2 -2
modeling_hyena.py CHANGED
@@ -62,8 +62,8 @@ class HyenaPositionalEmbedding(nn.Module):
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
 
64
  z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
- # The original code sets z's LR to lr_pos_emb, which is 1e-5 by default
66
- self.z = nn.Parameter(z, requires_grad=True)
67
  self.register_buffer("t", t)
68
 
69
  def forward(self, L):
 
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
 
64
  z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
65
+
66
+ self.register_buffer("z", z)
67
  self.register_buffer("t", t)
68
 
69
  def forward(self, L):