Fix hidden_states dtype matching

#4
by li-plus - opened
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -181,7 +181,7 @@ class RMSNorm(torch.nn.Module):
181
  self.eps = eps
182
 
183
  def forward(self, hidden_states: torch.Tensor):
184
- if hidden_states == torch.bfloat16:
185
  norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
186
  x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
187
  return self.weight * x_normed
 
181
  self.eps = eps
182
 
183
  def forward(self, hidden_states: torch.Tensor):
184
+ if hidden_states.dtype == torch.bfloat16:
185
  norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
186
  x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
187
  return self.weight * x_normed