Update model.py
Browse files
model.py
CHANGED
@@ -22,7 +22,7 @@ try:
|
|
22 |
except ImportError:
|
23 |
"could not import swap_mha_rope from positional_embeddings.py"
|
24 |
|
25 |
-
from flashfftconv import
|
26 |
|
27 |
# dummy import to force huggingface to bundle the tokenizer
|
28 |
from .tokenizer import ByteTokenizer
|
@@ -122,7 +122,7 @@ class ParallelHyenaFilter(nn.Module):
|
|
122 |
self.data_dtype = None
|
123 |
|
124 |
if self.use_flash_depthwise:
|
125 |
-
self.fir_fn =
|
126 |
channels=3 * self.hidden_size,
|
127 |
kernel_size=self.short_filter_length,
|
128 |
padding=self.short_filter_length - 1,
|
|
|
22 |
except ImportError:
|
23 |
"could not import swap_mha_rope from positional_embeddings.py"
|
24 |
|
25 |
+
from flashfftconv import FlashDepthWiseConv1d
|
26 |
|
27 |
# dummy import to force huggingface to bundle the tokenizer
|
28 |
from .tokenizer import ByteTokenizer
|
|
|
122 |
self.data_dtype = None
|
123 |
|
124 |
if self.use_flash_depthwise:
|
125 |
+
self.fir_fn = FlashDepthWiseConv1d(
|
126 |
channels=3 * self.hidden_size,
|
127 |
kernel_size=self.short_filter_length,
|
128 |
padding=self.short_filter_length - 1,
|