Commit
•
1a60346
1
Parent(s):
96d43a8
Upload HyenaDNAForCausalLM
Browse files- modeling_hyena.py +2 -2
modeling_hyena.py
CHANGED
@@ -19,8 +19,8 @@ def fftconv(u, k, D):
|
|
19 |
seqlen = u.shape[-1]
|
20 |
fft_size = 2 * seqlen
|
21 |
|
22 |
-
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
|
23 |
-
u_f = torch.fft.rfft(u.to(dtype=
|
24 |
|
25 |
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
|
26 |
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
|
|
|
19 |
seqlen = u.shape[-1]
|
20 |
fft_size = 2 * seqlen
|
21 |
|
22 |
+
k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
|
23 |
+
u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
|
24 |
|
25 |
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
|
26 |
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
|