Stanislas li-plus commited on
Commit
3cb3f8f
1 Parent(s): 18f9ab1

Fix hidden_states dtype matching (#4)

Browse files

- Fix hidden_states dtype matching (2caddb7092dfd83a6ea9db7046651661f4927999)


Co-authored-by: Jiahao Li <li-plus@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -182,7 +182,7 @@ class RMSNorm(torch.nn.Module):
182
  self.eps = eps
183
 
184
  def forward(self, hidden_states: torch.Tensor):
185
- if hidden_states == torch.bfloat16:
186
  norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
187
  x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
188
  return self.weight * x_normed
 
182
  self.eps = eps
183
 
184
  def forward(self, hidden_states: torch.Tensor):
185
+ if hidden_states.dtype == torch.bfloat16:
186
  norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
187
  x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
188
  return self.weight * x_normed