Crystalcareai commited on
Commit
cd900ce
·
verified ·
1 Parent(s): 695a26f

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +3 -0
modeling_quiet.py CHANGED
@@ -1662,6 +1662,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
1662
  prev_rm_logits = rm_logits # for policy gradient
1663
  prev_rm_tokens = cur_rm_tokens # for policy gradient
1664
 
 
 
 
1665
  if ahead_idx == 0:
1666
  hidden_states_lm = hidden_states
1667
  logits = self.lm_head(hidden_states_lm)
 
1662
  prev_rm_logits = rm_logits # for policy gradient
1663
  prev_rm_tokens = cur_rm_tokens # for policy gradient
1664
 
1665
+ hidden_states_lm = hidden_states_lm.to(self.lm_head.weight.dtype)
1666
+ logits = self.lm_head(hidden_states_lm)
1667
+
1668
  if ahead_idx == 0:
1669
  hidden_states_lm = hidden_states
1670
  logits = self.lm_head(hidden_states_lm)