Zymrael commited on
Commit
60eb4c7
1 Parent(s): 4a59285

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -2
model.py CHANGED
@@ -339,8 +339,10 @@ class StripedHyena(nn.Module):
339
  self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
340
 
341
  if config.get("use_flashfft", "False"):
342
- from flashfftconv import FlashFFTConv
343
-
 
 
344
  self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
345
  else:
346
  self.flash_fft = None
 
339
  self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
340
 
341
  if config.get("use_flashfft", "False"):
342
+ try:
343
+ from flashfftconv import FlashFFTConv
344
+ except:
345
+ raise ImportError
346
  self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
347
  else:
348
  self.flash_fft = None