duzx16 committed
Commit: 5fe53eb
Parent: 74d61a6

Fix checkpointing

Files changed (1):
  1. modeling_chatglm.py +10 -6
modeling_chatglm.py CHANGED
@@ -63,7 +63,7 @@ class PrefixEncoder(torch.nn.Module):
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
 
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__()
         self.prefix_projection = config.prefix_projection
         if self.prefix_projection:
@@ -75,7 +75,8 @@ class PrefixEncoder(torch.nn.Module):
                 torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
-            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
 
     def forward(self, prefix: torch.Tensor):
         if self.prefix_projection:
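
Note: the embedding resize above is the substance of the PrefixEncoder change. With ChatGLM2's multi-query attention, each layer keeps only multi_query_group_num key/value heads of kv_channels dimensions, so the prefix table no longer needs a full hidden_size worth of features per key/value projection. A minimal sketch of the arithmetic, using illustrative ChatGLM2-6B-style numbers (the config values below are assumptions, not read from this repository's config):

    # Prefix-embedding width per prefix token, before and after this commit.
    # Config values are illustrative (28 layers, hidden size 4096, 128-dim KV
    # heads, 2 multi-query groups); substitute your own ChatGLMConfig values.
    num_layers = 28
    hidden_size = 4096
    kv_channels = 128             # per-head key/value dimension
    multi_query_group_num = 2     # number of shared KV heads (multi-query attention)

    old_width = num_layers * hidden_size * 2                          # 229376
    new_width = num_layers * kv_channels * multi_query_group_num * 2  # 14336

    print(old_width, new_width)  # with these numbers, the old table was 16x wider than needed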
@@ -629,8 +630,8 @@ class GLMTransformer(torch.nn.Module):
                     hidden_states,
                     attention_mask,
                     rotary_pos_emb,
-                    kv_cache=kv_caches[index],
-                    use_cache=use_cache
+                    kv_caches[index],
+                    use_cache
                 )
             else:
                 layer_ret = layer(
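
Note: this hunk is in GLMTransformer.forward, and the context lines suggest (the enclosing call itself is not shown in the diff) that it is the torch.utils.checkpoint.checkpoint(...) branch used when gradient checkpointing is enabled, which would explain the commit message: the then-default reentrant checkpoint forwards only positional arguments to the wrapped callable and rejects extra keywords, so kv_cache=... and use_cache=... fail there, while the plain layer(...) call in the else branch can keep them. A minimal sketch of the pattern, with a hypothetical helper name (run_layer is not part of the model code):

    from torch.utils.checkpoint import checkpoint

    def run_layer(layer, hidden_states, attention_mask, rotary_pos_emb,
                  kv_cache, use_cache, gradient_checkpointing):
        if gradient_checkpointing:
            # checkpoint() re-invokes `layer` with *positional* arguments only,
            # mirroring the fixed call in the hunk above
            return checkpoint(layer, hidden_states, attention_mask,
                              rotary_pos_emb, kv_cache, use_cache)
        # outside checkpointing, keyword arguments work as before
        return layer(hidden_states, attention_mask, rotary_pos_emb,
                     kv_cache=kv_cache, use_cache=use_cache)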
@@ -737,6 +738,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -768,8 +772,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             batch_size,
             self.pre_seq_len,
             self.num_layers * 2,
-            self.num_attention_heads,
-            self.hidden_size // self.num_attention_heads
+            self.multi_query_group_num,
+            self.kv_channels
         )
         # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
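
Note: the last two hunks go together. ChatGLMModel.__init__ now caches num_layers, multi_query_group_num and kv_channels, and the past_key_values view in the prefix-tuning path uses them, so the prefix cache reshapes into the multi-query KV layout instead of one sized by num_attention_heads. This also keeps it consistent with the new PrefixEncoder width from the first hunks. A small shape check, again with illustrative config numbers (batch size and pre_seq_len here are arbitrary):

    import torch

    batch_size, pre_seq_len = 2, 64
    num_layers, kv_channels, multi_query_group_num = 28, 128, 2

    # width emitted by the (non-projection) PrefixEncoder per prefix token
    prefix_width = num_layers * kv_channels * multi_query_group_num * 2
    past_key_values = torch.zeros(batch_size, pre_seq_len, prefix_width)

    # the view used after this commit
    past_key_values = past_key_values.view(
        batch_size,
        pre_seq_len,
        num_layers * 2,          # one key and one value slice per layer
        multi_query_group_num,   # KV heads under multi-query attention
        kv_channels,             # per-head dimension
    )
    print(past_key_values.shape)  # torch.Size([2, 64, 56, 2, 128])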
 