Stanislas committed on
Commit 13193de
1 Parent(s): 5fcf835

Fix precision error

Files changed (1)
  1. modeling_chatglm.py +4 -20
modeling_chatglm.py CHANGED
@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -177,14 +175,13 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
-        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        if not self.quantized:
+        if hidden_states.dtype == torch.bfloat16:
             norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
             x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
             return self.weight * x_normed
@@ -521,14 +518,7 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        if config.rmsnorm:
-            if config.quantization_bit != 0:
-                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-            else:
-                LayerNormFunc = RMSNorm
-        else:
-            LayerNormFunc = LayerNorm
-
+        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                              dtype=config.torch_dtype)
@@ -606,13 +596,7 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            if config.rmsnorm:
-                if config.quantization_bit != 0:
-                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
-                else:
-                    LayerNormFunc = RMSNorm
-            else:
-                LayerNormFunc = LayerNorm
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
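For context (an editor's illustration, not part of the commit): the precision error comes from taking the mean of squared activations directly in half precision. float16 saturates at about 65504, so squaring even moderately large values overflows to inf, while bfloat16 keeps float32's exponent range and only loses mantissa bits, which is why the in-dtype RMSNorm path above is gated on torch.bfloat16. A minimal sketch of the failure mode, with made-up tensor values:

import torch

# Illustration only: naive RMSNorm statistics overflow in float16 but not in float32.
x = torch.full((4,), 300.0, dtype=torch.float16)      # hypothetical activations
norm_fp16 = torch.mean(x * x, dim=-1, keepdim=True)   # 300**2 = 90000 > 65504 -> inf
norm_fp32 = torch.mean(x.float() * x.float(), dim=-1, keepdim=True)
print(norm_fp16)   # tensor([inf], dtype=torch.float16)
print(norm_fp32)   # tensor([90000.])

x_bf16 = x.to(torch.bfloat16)
print(torch.mean(x_bf16 * x_bf16, dim=-1, keepdim=True))  # finite in bfloat16 (coarser mantissa, but no overflow)

With the quantized flag gone from RMSNorm, GLMBlock and GLMTransformer no longer need the functools.partial wiring, so import functools goes away and the nested selection collapses to the one-line LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm; import re is dropped in the same cleanup.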