Zymrael commited on
Commit
4b7049f
1 Parent(s): b18b5f1

fix: force correct mixed dtype after HF load

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +2 -1
modeling_hyena.py CHANGED
@@ -45,8 +45,9 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
45
  )
46
  self.vocab_size = vocab_size
47
  self.post_init()
 
48
 
49
- def post_init(self):
50
  self.backbone.to_bfloat16_except_poles_residues()
51
 
52
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
 
45
  )
46
  self.vocab_size = vocab_size
47
  self.post_init()
48
+ self.force_dtype()
49
 
50
+ def force_dtype(self):
51
  self.backbone.to_bfloat16_except_poles_residues()
52
 
53
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):