efederici commited on
Commit
48db954
1 Parent(s): 31eb985

Update norm.py

Browse files
Files changed (1) hide show
  1. norm.py +1 -1
norm.py CHANGED
@@ -25,7 +25,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
- output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
  if weight is not None:
30
  return output * weight
31
  return output
 
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
  if weight is not None:
30
  return output * weight
31
  return output