fix rmsnorm init weight bug.
Use torch.ones to initialize the RMSNorm weight. torch.empty returns an uninitialized tensor whose values are arbitrary and may fall outside the representable float range, so the scale parameter should default to ones.
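For reference, a minimal sketch (not part of the change) contrasting the two initializers; the shape 4096 is an arbitrary example:

import torch

# torch.empty returns uninitialized memory: whatever bytes happen to be there,
# which reinterpreted as floats can be huge, inf, or NaN.
w_uninitialized = torch.empty(4096)

# torch.ones is the neutral initialization for an RMSNorm scale parameter:
# with weight == 1 the layer normalizes without rescaling until it is trained.
w_ones = torch.ones(4096)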
modeling_chatglm.py (+1 -1)
@@ -181,7 +181,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 class RMSNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
         self.eps = eps

     def forward(self, hidden_states: torch.Tensor):
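For context, a sketch of the corrected module with a standard RMSNorm forward; the forward body below is an assumption for illustration, since the hunk only shows its signature:

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Start from ones so an untrained layer is a pure normalization.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        # Assumed standard RMSNorm: x / sqrt(mean(x^2) + eps), then per-feature rescale.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)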