Fix hidden_states dtype matching
Browse files — modeling_chatglm.py: +1 −1
modeling_chatglm.py
CHANGED
@@ -181,7 +181,7 @@ class RMSNorm(torch.nn.Module):
 181         self.eps = eps
 182
 183     def forward(self, hidden_states: torch.Tensor):
 184 -       if hidden_states == torch.bfloat16:
 184 +       if hidden_states.dtype == torch.bfloat16:
 185             norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
 186             x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
 187             return self.weight * x_normed