Sengxian committed on
Commit c3b3141
1 Parent(s): e9b655e

Update implementation

Files changed (1)
  1. modeling_chatglm.py +15 -39
modeling_chatglm.py CHANGED
@@ -35,12 +35,12 @@ if sys.platform != 'darwin':
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
 _CONFIG_FOR_DOC = "ChatGLM6BConfig"
 
 CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/chatglm-6b",
-    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
+    "THUDM/chatglm2-6b",
+    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
 
@@ -92,7 +92,7 @@ class RotaryEmbedding(nn.Module):
         self.dim = dim
         self.original_impl = original_impl
 
-    def forward_original_impl(
+    def forward_impl(
             self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
     ):
         """Enhanced Transformer with Rotary Position Embedding.
@@ -118,14 +118,13 @@ class RotaryEmbedding(nn.Module):
         return cache
 
     def forward(self, max_seq_len, offset=0):
-        if self.original_impl:
-            return self.forward_original_impl(
-                max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
-            )
+        return self.forward_impl(
+            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
+        )
 
 
 @torch.jit.script
-def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
     rot_dim = rope_cache.shape[-2] * 2
@@ -313,8 +312,6 @@ class SelfAttention(torch.nn.Module):
             device=device, **_config_to_kwargs(config)
         )
 
-        self.interleaved_qkv = config.interleaved_qkv
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
         if self.multi_query_attention:
             num_attention_heads = self.num_multi_query_groups_per_partition
@@ -364,33 +361,18 @@ class SelfAttention(torch.nn.Module):
                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
             )
         else:
-            if self.interleaved_qkv:
-                new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                                   (self.num_attention_heads_per_partition,
-                                    3 * self.hidden_size_per_attention_head)
-                mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            new_tensor_shape = mixed_x_layer.size()[:-1] + \
+                               (self.num_attention_heads_per_partition,
+                                3 * self.hidden_size_per_attention_head)
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
-            if not self.interleaved_qkv:
-                query_layer = query_layer.view(
-                    query_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                key_layer = key_layer.view(
-                    key_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                value_layer = value_layer.view(
-                    value_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
-            query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb)
-            key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb)
+            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
         if use_cache:
@@ -713,13 +695,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        if config.rotary_percent < 1.0:
-            rotary_dim = int(rotary_dim * config.rotary_percent)
-
-        # partial rotary embeddings, which is better than full rotary
-        # Wang and Komatsuzaki et al
-        # https://github.com/kingoflolz/mesh-transformer-jax/
-        self.rotary_pos_emb = RotaryEmbedding(rotary_dim, original_impl=config.original_rope, device=device,
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
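
Below is a minimal, self-contained sketch (not the repository code) of the rotary-embedding path this commit converges on: a cache builder mirroring the renamed forward_impl that produces a [seq_len, dim/2, 2] cos/sin table, and an apply_rotary_pos_emb that rotates only the first rot_dim channels of a [sq, b, np, hn] query/key tensor, consistent with passing rotary_dim // 2 to RotaryEmbedding. The function names and the build_rope_cache helper are illustrative re-implementations under those assumptions, not imports from modeling_chatglm.py.

import torch


def build_rope_cache(seq_len: int, n_elem: int, base: int = 10000,
                     dtype: torch.dtype = torch.float32,
                     device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # theta_i = 1 / base^(2i / n_elem): one frequency per pair of channels
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
    position = torch.arange(seq_len, dtype=dtype, device=device)
    idx_theta = torch.outer(position, theta)                         # [seq_len, n_elem // 2]
    return torch.stack([idx_theta.cos(), idx_theta.sin()], dim=-1)   # [seq_len, n_elem // 2, 2]


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]; only the first rot_dim channels are rotated,
    # the rest pass through unchanged (partial rotary).
    sq, b, np_ = x.size(0), x.size(1), x.size(2)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    rope = rope_cache[:sq].view(sq, 1, 1, rot_dim // 2, 2)           # broadcast over b and np
    xshaped = x.reshape(sq, b, np_, rot_dim // 2, 2)
    x_out = torch.stack(
        [
            xshaped[..., 0] * rope[..., 0] - xshaped[..., 1] * rope[..., 1],
            xshaped[..., 1] * rope[..., 0] + xshaped[..., 0] * rope[..., 1],
        ],
        dim=-1,
    ).flatten(-2)
    return torch.cat([x_out, x_pass], dim=-1)


if __name__ == "__main__":
    head_dim = 128                                      # hypothetical kv_channels
    cache = build_rope_cache(seq_len=16, n_elem=head_dim // 2)   # rotary_dim // 2, as in the diff
    q = torch.randn(16, 1, 32, head_dim)                # [sq, b, np, hn]
    print(apply_rotary_pos_emb(q, cache).shape)         # torch.Size([16, 1, 32, 128])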