Update implementation
modeling_chatglm.py (CHANGED): +15 -39
@@ -35,12 +35,12 @@ if sys.platform != 'darwin':
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
 _CONFIG_FOR_DOC = "ChatGLM6BConfig"
 
 CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/
-    # See all ChatGLM
+    "THUDM/chatglm2-6b",
+    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
 
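For context, a minimal loading sketch for the renamed checkpoint (an assumption based on the usual trust_remote_code path for THUDM checkpoints, not part of this diff; adjust dtype/device to your hardware):

import torch
from transformers import AutoModel, AutoTokenizer

# Loading sketch; assumes the standard remote-code path for THUDM checkpoints.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = model.half().cuda().eval()  # or .float().cpu() without a GPU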
@@ -92,7 +92,7 @@ class RotaryEmbedding(nn.Module):
         self.dim = dim
         self.original_impl = original_impl
 
-    def
+    def forward_impl(
             self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
     ):
         """Enhanced Transformer with Rotary Position Embedding.
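The forward_impl signature above (seq_len, n_elem, base=10000) matches the standard rotary-embedding cache construction. A self-contained sketch of that construction, not necessarily the exact body of the method:

import torch

def build_rope_cache(seq_len: int, n_elem: int, base: int = 10000,
                     dtype: torch.dtype = torch.float32,
                     device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # Hypothetical helper mirroring the forward_impl signature above.
    # Frequencies theta_i = base ** (-2i / n_elem) over even channel indices.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    # Position-frequency outer product: [seq_len, n_elem // 2]
    idx_theta = torch.outer(torch.arange(seq_len, device=device).float(), theta)
    # (cos, sin) pairs cached per position and frequency: [seq_len, n_elem // 2, 2]
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1).to(dtype)

print(build_rope_cache(8, 64).shape)  # torch.Size([8, 32, 2])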
@@ -118,14 +118,13 @@ class RotaryEmbedding(nn.Module):
         return cache
 
     def forward(self, max_seq_len, offset=0):
-
-
-
-        )
+        return self.forward_impl(
+            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
+        )
 
 
 @torch.jit.script
-def
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
     rot_dim = rope_cache.shape[-2] * 2
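apply_rotary_pos_emb rotates only the first rot_dim = rope_cache.shape[-2] * 2 channels of x ([sq, b, np, hn]) and passes the rest through. A standalone sketch of that rotation under those shape assumptions (illustrative sizes; not claimed to be the function's exact body):

import torch

def apply_rope_sketch(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]; rope_cache: [sq, rot_dim // 2, 2] of (cos, sin) pairs.
    sq = x.size(0)
    rot_dim = rope_cache.shape[-2] * 2
    # Rotate only the first rot_dim channels; the remainder passes through unchanged.
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    xshaped = x_rot.reshape(sq, -1, x.size(2), rot_dim // 2, 2)
    rc = rope_cache[:sq].view(sq, -1, 1, rot_dim // 2, 2)
    # Multiply each (real, imag) pair by (cos + i*sin): a 2-D rotation per channel pair.
    rotated = torch.stack(
        [xshaped[..., 0] * rc[..., 0] - xshaped[..., 1] * rc[..., 1],
         xshaped[..., 1] * rc[..., 0] + xshaped[..., 0] * rc[..., 1]], dim=-1)
    return torch.cat((rotated.flatten(3), x_pass), dim=-1)

# Illustrative shapes: sq=8, b=2, np=4, hn=128, rotary over the first 64 channels.
ang = torch.outer(torch.arange(8.0), 1.0 / 10000 ** (torch.arange(0, 64, 2).float() / 64))
rope_cache = torch.stack([torch.cos(ang), torch.sin(ang)], dim=-1)     # [8, 32, 2]
print(apply_rope_sketch(torch.randn(8, 2, 4, 128), rope_cache).shape)  # [8, 2, 4, 128]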
@@ -313,8 +312,6 @@ class SelfAttention(torch.nn.Module):
                                device=device, **_config_to_kwargs(config)
                                )
 
-        self.interleaved_qkv = config.interleaved_qkv
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
         if self.multi_query_attention:
             num_attention_heads = self.num_multi_query_groups_per_partition
@@ -364,33 +361,18 @@ class SelfAttention(torch.nn.Module):
                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
             )
         else:
-
-
-
-
-            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            new_tensor_shape = mixed_x_layer.size()[:-1] + \
+                               (self.num_attention_heads_per_partition,
+                                3 * self.hidden_size_per_attention_head)
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
-        if not self.interleaved_qkv:
-            query_layer = query_layer.view(
-                query_layer.size()[:-1] + (
-                    self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-            ).contiguous()
-            key_layer = key_layer.view(
-                key_layer.size()[:-1] + (
-                    self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-            ).contiguous()
-            value_layer = value_layer.view(
-                value_layer.size()[:-1] + (
-                    self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-            ).contiguous()
-
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
-            query_layer =
-            key_layer =
+            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
         if use_cache:
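In the non-multi-query branch above, the fused QKV projection is viewed as [sq, b, np, 3 * hn] and then split along the last dimension. A shape sketch with illustrative sizes, using torch.chunk to stand in for split_tensor_along_last_dim:

import torch

# Illustrative sizes only (not the model's real config):
sq, b, np_, hn = 8, 2, 4, 16

# Fused QKV output viewed as [sq, b, np, 3 * hn], as in the else branch above.
mixed_x_layer = torch.randn(sq, b, np_, 3 * hn)

# split_tensor_along_last_dim(mixed_x_layer, 3) is, in effect, a 3-way split on dim=-1.
query_layer, key_layer, value_layer = mixed_x_layer.chunk(3, dim=-1)
print(query_layer.shape, key_layer.shape, value_layer.shape)
# torch.Size([8, 2, 4, 16]) torch.Size([8, 2, 4, 16]) torch.Size([8, 2, 4, 16])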
@@ -713,13 +695,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-
-        rotary_dim = int(rotary_dim * config.rotary_percent)
-
-        # partial rotary embeddings, which is better than full rotary
-        # Wang and Komatsuzaki et al
-        # https://github.com/kingoflolz/mesh-transformer-jax/
-        self.rotary_pos_emb = RotaryEmbedding(rotary_dim, original_impl=config.original_rope, device=device,
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
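The constructor now hands rotary_dim // 2 to RotaryEmbedding, so rotary position encoding covers half of each head's channels instead of the old rotary_percent scaling. Quick arithmetic with example values (illustrative; the released config may differ):

# Example values only; the released config may differ.
hidden_size = 4096
num_attention_heads = 32
kv_channels = None

rotary_dim = hidden_size // num_attention_heads if kv_channels is None else kv_channels
print(rotary_dim)       # 128 channels per attention head
print(rotary_dim // 2)  # 64: the dim now passed to RotaryEmbedding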