zxdu20 committed on
Commit 0cfae21
1 Parent(s): aea6cef

Fix backward for quantization

Files changed (2):
  1. modeling_chatglm.py +6 -8
  2. quantization.py +2 -2
modeling_chatglm.py CHANGED
@@ -134,11 +134,11 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
 
 
 class PrefixEncoder(torch.nn.Module):
-    r'''
+    """
     The torch.nn model to encode the prefix
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
-    '''
+    """
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -148,7 +148,7 @@ class PrefixEncoder(torch.nn.Module):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config.hidden_size, config.hidden_size),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -814,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.num_attention_heads,
             self.hidden_size // self.num_attention_heads
         )
-        #seq_len, b, nh, hidden_size
+        # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
         past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
         # past_key_values = [(v[0], v[1]) for v in past_key_values]
@@ -909,7 +909,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
            )
 
        if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
@@ -942,9 +943,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
        else:
            attention_mask = attention_mask.to(input_ids.device)
 
-        if self.training:
-            hidden_states = hidden_states.requires_grad_(True)
-
        for i, layer in enumerate(self.layers):
 
            if output_hidden_states:
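
The PrefixEncoder docstring above states the shape contract: prefix ids of shape (batch-size, prefix-length) map to past key/value states of shape (batch-size, prefix-length, 2*layers*hidden). As a minimal sketch of that contract only, the standalone module below mirrors the structure shown in the hunk; TinyPrefixEncoder and its flattened constructor arguments are illustrative stand-ins for the repository's config-driven class, and the sizes in the usage lines are arbitrary.

import torch

# Sketch only: a PrefixEncoder-style module with the config fields flattened
# into constructor arguments. Not the repository class.
class TinyPrefixEncoder(torch.nn.Module):
    def __init__(self, pre_seq_len, hidden_size, num_layers, prefix_projection=True):
        super().__init__()
        self.prefix_projection = prefix_projection
        if self.prefix_projection:
            # Two-layer MLP: prefix ids -> hidden -> 2 * layers * hidden
            self.embedding = torch.nn.Embedding(pre_seq_len, hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden_size, num_layers * hidden_size * 2),
            )
        else:
            # Direct lookup table of the projected prefix states
            self.embedding = torch.nn.Embedding(pre_seq_len, num_layers * hidden_size * 2)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        # prefix: (batch_size, prefix_length) integer ids
        if self.prefix_projection:
            return self.trans(self.embedding(prefix))
        return self.embedding(prefix)

# Shape check with arbitrary sizes: (1, 8) ids -> (1, 8, 2 * 2 * 16) = (1, 8, 64)
enc = TinyPrefixEncoder(pre_seq_len=8, hidden_size=16, num_layers=2)
prefix = torch.arange(8).unsqueeze(0)
print(enc(prefix).shape)  # torch.Size([1, 8, 64])
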
quantization.py CHANGED
@@ -14,11 +14,11 @@ class W8A16Linear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
         ctx.inp_shape = inp.size()
-        ctx.weight_shape = quant_w.size()
         ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
         inp = inp.contiguous().view(-1, inp.size(-1))
         weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+        ctx.weight_shape = weight.size()
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -30,7 +30,7 @@ class W8A16Linear(torch.autograd.Function):
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)
-        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
 class Kernel:
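
On the quantization.py change: W8A16Linear.forward takes four arguments (inp, quant_w, scale_w, weight_bit_width), and torch.autograd requires backward to return exactly one gradient, or None, per forward argument, so the added trailing None is what lets the backward pass run. Recording ctx.weight_shape from the dequantized weight rather than from quant_w is presumably needed because low-bit weights are stored packed, leaving quant_w with fewer columns than the grad_weight computed in backward. The sketch below shows the fixed pattern under one loud assumption: ToyQuantLinear replaces the extract_weight_to_half kernel with a plain per-output-channel dequantization so the example runs on CPU; it is not the repository implementation.

import torch

# Sketch of the fixed autograd.Function pattern. Assumption: plain per-channel
# dequantization (quant_w * scale_w) stands in for extract_weight_to_half.
class ToyQuantLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, quant_w, scale_w, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Dequantize to a float (out_features, in_features) weight
        weight = quant_w.to(inp.dtype) * scale_w[:, None]
        # Record the dequantized weight's shape; with packed low-bit storage,
        # quant_w.size() would not match grad_weight in backward.
        ctx.weight_shape = weight.size()
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = quant_w.to(grad_output.dtype) * scale_w[:, None]
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        # One return value per forward argument (inp, quant_w, scale_w,
        # weight_bit_width); the trailing Nones mark the last two as
        # non-differentiable.
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None

# Usage sketch: only the activation requires grad, as in prefix tuning.
inp = torch.randn(2, 3, 8, requires_grad=True)
quant_w = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
scale_w = torch.rand(4)
out = ToyQuantLinear.apply(inp, quant_w, scale_w, 8)
out.sum().backward()
print(inp.grad.shape)  # torch.Size([2, 3, 8])

Since quant_w does not require gradients in this setting, the non-None value returned in its slot is simply discarded by autograd; only grad_input is consumed.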