Stanislas committed on
Commit
5fcf835
1 Parent(s): 5a26c4f

Fix precision error

Files changed (1)
modeling_chatglm.py +27 -8
modeling_chatglm.py CHANGED
@@ -5,7 +5,7 @@ import copy
 import warnings
 import re
 import sys
-
+import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -177,15 +177,21 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
+        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if not self.quantized:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 
         return (self.weight * hidden_states).to(input_dtype)
 
@@ -515,10 +521,17 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        if config.rmsnorm:
+            if config.quantization_bit != 0:
+                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+            else:
+                LayerNormFunc = RMSNorm
+        else:
+            LayerNormFunc = LayerNorm
+
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=config.torch_dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -593,7 +606,13 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            if config.rmsnorm:
+                if config.quantization_bit != 0:
+                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+                else:
+                    LayerNormFunc = RMSNorm
+            else:
+                LayerNormFunc = LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
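
For reference, a minimal usage sketch of the patched class and of the functools.partial dispatch the commit adds. It assumes the patched modeling_chatglm.py above is importable; hidden_size, x, norm and qnorm are illustrative names, not part of the commit.

import functools

import torch

from modeling_chatglm import RMSNorm  # the class patched above

hidden_size = 8
x = torch.randn(2, 4, hidden_size, dtype=torch.float16)

# Non-quantized path: the reduction stays in the input dtype (here fp16).
norm = RMSNorm(hidden_size, dtype=torch.float16)
torch.nn.init.ones_(norm.weight)  # weight is created with torch.empty, so initialize it
print(norm(x).dtype)  # torch.float16

# Quantized path, selected the same way the patch does in GLMBlock/GLMTransformer:
# the variance is accumulated in float32, then the result is cast back to the input dtype.
QuantizedRMSNorm = functools.partial(RMSNorm, quantized=True)
qnorm = QuantizedRMSNorm(hidden_size, dtype=torch.float16)
torch.nn.init.ones_(qnorm.weight)
print(qnorm(x).dtype)  # torch.float16

Using functools.partial keeps the call sites unchanged: both branches bind a callable with the same (normalized_shape, eps, device, dtype) signature, so the existing LayerNormFunc(...) constructor calls need no edits.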
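The commit message does not spell out the precision error; as a hypothetical illustration of why the quantized branch keeps the reduction in float32, squaring moderately large half-precision activations already overflows fp16 (largest finite value is about 65504). The values below are made up for the demonstration.

import torch

x = torch.full((4, 4096), 300.0, dtype=torch.float16)

# 300 ** 2 = 90000 exceeds the fp16 maximum (~65504), so the fp16 reduction
# saturates to inf, while the float32 reduction recovers the true mean square.
print(x.pow(2).mean(-1))                    # inf in fp16
print(x.to(torch.float32).pow(2).mean(-1))  # 90000.0 in fp32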