Zymrael commited on
Commit
b18b5f1
1 Parent(s): 86e226f

fix: force correct dtype in HF load

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +3 -0
modeling_hyena.py CHANGED
@@ -46,6 +46,9 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
46
  self.vocab_size = vocab_size
47
  self.post_init()
48
 
 
 
 
49
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
50
  self.backbone.gradient_checkpointing = enable
51
 
 
46
  self.vocab_size = vocab_size
47
  self.post_init()
48
 
49
+ def post_init(self):
50
+ self.backbone.to_bfloat16_except_poles_residues()
51
+
52
  def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
53
  self.backbone.gradient_checkpointing = enable
54