Update modeling_quiet.py
Browse files- modeling_quiet.py +3 -0
modeling_quiet.py
CHANGED
@@ -1662,6 +1662,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1662 |
prev_rm_logits = rm_logits # for policy gradient
|
1663 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
1664 |
|
|
|
|
|
|
|
1665 |
if ahead_idx == 0:
|
1666 |
hidden_states_lm = hidden_states
|
1667 |
logits = self.lm_head(hidden_states_lm)
|
|
|
1662 |
prev_rm_logits = rm_logits # for policy gradient
|
1663 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
1664 |
|
1665 |
+
hidden_states_lm = hidden_states_lm.to(self.lm_head.weight.dtype)
|
1666 |
+
logits = self.lm_head(hidden_states_lm)
|
1667 |
+
|
1668 |
if ahead_idx == 0:
|
1669 |
hidden_states_lm = hidden_states
|
1670 |
logits = self.lm_head(hidden_states_lm)
|