zR commited on
Commit
bd16ff6
·
1 Parent(s): 7025474

add flash-attn

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_chatglm.py +224 -88
config.json CHANGED
@@ -17,6 +17,7 @@
17
  "apply_residual_connection_post_layernorm": false,
18
  "attention_dropout": 0.0,
19
  "attention_softmax_in_fp32": true,
 
20
  "bias_dropout_fusion": true,
21
  "ffn_hidden_size": 13696,
22
  "fp32_residual_connection": false,
 
17
  "apply_residual_connection_post_layernorm": false,
18
  "attention_dropout": 0.0,
19
  "attention_softmax_in_fp32": true,
20
+ "attn_implementation": "sdpa",
21
  "bias_dropout_fusion": true,
22
  "ffn_hidden_size": 13696,
23
  "fp32_residual_connection": false,
modeling_chatglm.py CHANGED
@@ -21,15 +21,20 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging
 
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
28
  from .configuration_chatglm import ChatGLMConfig
29
 
 
 
 
 
30
  # flags required to enable jit fusion kernels
31
 
32
- if sys.platform != 'darwin':
33
  torch._C._jit_set_profiling_mode(False)
34
  torch._C._jit_set_profiling_executor(False)
35
  torch._C._jit_override_can_fuse_on_cpu(True)
@@ -40,11 +45,6 @@ logger = logging.get_logger(__name__)
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
43
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
- "THUDM/chatglm3-6b",
45
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
- ]
47
-
48
 
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
@@ -165,12 +165,13 @@ class RMSNorm(torch.nn.Module):
165
  class CoreAttention(torch.nn.Module):
166
  def __init__(self, config: ChatGLMConfig, layer_number):
167
  super(CoreAttention, self).__init__()
168
-
169
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
170
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
171
  if self.apply_query_key_layer_scaling:
172
  self.attention_softmax_in_fp32 = True
173
  self.layer_number = max(1, layer_number)
 
174
 
175
  projection_size = config.kv_channels * config.num_attention_heads
176
 
@@ -189,91 +190,198 @@ class CoreAttention(torch.nn.Module):
189
  self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
190
 
191
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
192
- pytorch_major_version = int(torch.__version__.split('.')[0])
193
- if pytorch_major_version >= 2:
194
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
195
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
196
- is_causal=True)
197
- else:
198
- if attention_mask is not None:
199
- attention_mask = ~attention_mask
200
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
201
- attention_mask)
202
- context_layer = context_layer.transpose(1, 2).contiguous()
203
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
204
- context_layer = context_layer.reshape(*new_context_layer_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  else:
206
- # Raw attention scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- # [b, np, sq, sk]
209
- output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
210
 
211
- # [b, np, sq, hn] -> [b * np, sq, hn]
212
- query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
213
- # [b, np, sk, hn] -> [b * np, sk, hn]
214
- key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
 
215
 
216
- # preallocting input tensor: [b * np, sq, sk]
217
- matmul_input_buffer = torch.empty(
218
- output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
219
- device=query_layer.device
 
 
 
 
 
 
 
 
 
 
 
220
  )
221
 
222
- # Raw attention scores. [b * np, sq, sk]
223
- matmul_result = torch.baddbmm(
224
- matmul_input_buffer,
225
- query_layer, # [b * np, sq, hn]
226
- key_layer.transpose(1, 2), # [b * np, hn, sk]
227
- beta=0.0,
228
- alpha=(1.0 / self.norm_factor),
 
 
 
 
 
 
 
229
  )
230
 
231
- # change view to [b, np, sq, sk]
232
- attention_scores = matmul_result.view(*output_size)
233
-
234
- # ===========================
235
- # Attention probs and dropout
236
- # ===========================
237
-
238
- # attention scores and attention mask [b, np, sq, sk]
239
- if self.attention_softmax_in_fp32:
240
- attention_scores = attention_scores.float()
241
- if self.coeff is not None:
242
- attention_scores = attention_scores * self.coeff
243
- if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
244
- attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
245
- device=attention_scores.device, dtype=torch.bool)
246
- attention_mask.tril_()
247
- attention_mask = ~attention_mask
248
- if attention_mask is not None:
249
- attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
250
- attention_probs = F.softmax(attention_scores, dim=-1)
251
- attention_probs = attention_probs.type_as(value_layer)
252
-
253
- # This is actually dropping out entire tokens to attend to, which might
254
- # seem a bit unusual, but is taken from the original Transformer paper.
255
- attention_probs = self.attention_dropout(attention_probs)
256
-
257
- # query layer shape: [b * np, sq, hn]
258
- # value layer shape: [b, np, sk, hn]
259
- # attention shape: [b, np, sq, sk]
260
- # context layer shape: [b, np, sq, hn]
261
- output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
262
- # change view [b * np, sk, hn]
263
- value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
264
- # change view [b * np, sq, sk]
265
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
266
- # matmul: [b * np, sq, hn]
267
- context_layer = torch.bmm(attention_probs, value_layer)
268
- # change view [b, np, sq, hn]
269
- context_layer = context_layer.view(*output_size)
270
- # [b, np, sq, hn] --> [b, sq, np, hn]
271
- context_layer = context_layer.transpose(1, 2).contiguous()
272
- # [b, sq, np, hn] --> [b, sq, hp]
273
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
274
- context_layer = context_layer.reshape(*new_context_layer_shape)
275
 
276
- return context_layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
 
279
  class SelfAttention(torch.nn.Module):
@@ -305,7 +413,7 @@ class SelfAttention(torch.nn.Module):
305
  device=device, **_config_to_kwargs(config)
306
  )
307
 
308
- self.core_attention = CoreAttention(config, self.layer_number)
309
 
310
  # Output.
311
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
@@ -383,7 +491,11 @@ class SelfAttention(torch.nn.Module):
383
  key_layer = torch.cat((cache_k, key_layer), dim=2)
384
  value_layer = torch.cat((cache_v, value_layer), dim=2)
385
  if use_cache:
386
- kv_cache = (key_layer, value_layer)
 
 
 
 
387
  else:
388
  kv_cache = None
389
 
@@ -616,7 +728,15 @@ class GLMTransformer(torch.nn.Module):
616
  )
617
  hidden_states, kv_cache = layer_ret
618
  if use_cache:
619
- presents = presents + (kv_cache,)
 
 
 
 
 
 
 
 
620
 
621
  if output_hidden_states:
622
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -639,12 +759,18 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
639
  config_class = ChatGLMConfig
640
  base_model_prefix = "transformer"
641
  _no_split_modules = ["GLMBlock"]
 
 
642
 
643
  def _init_weights(self, module: nn.Module):
644
  """Initialize the weights."""
645
  return
646
 
647
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
 
 
 
 
648
  batch_size, seq_length = input_ids.shape
649
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
650
  full_attention_mask.tril_()
@@ -719,7 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
719
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
720
  )
721
 
722
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
 
723
  device=device, dtype=config.torch_dtype)
724
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
725
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
@@ -740,6 +867,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
740
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
741
  inputs_embeds: Optional[torch.Tensor] = None,
742
  use_cache: Optional[bool] = None,
 
743
  output_hidden_states: Optional[bool] = None,
744
  return_dict: Optional[bool] = None,
745
  ):
@@ -770,6 +898,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
770
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
771
  kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
772
  )
 
 
 
 
 
 
773
 
774
  if not return_dict:
775
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -1145,6 +1279,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1145
  inputs_embeds: Optional[torch.LongTensor] = None,
1146
  labels: Optional[torch.LongTensor] = None,
1147
  use_cache: Optional[bool] = None,
 
1148
  output_hidden_states: Optional[bool] = None,
1149
  return_dict: Optional[bool] = None,
1150
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
@@ -1158,6 +1293,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1158
  past_key_values=past_key_values,
1159
  inputs_embeds=inputs_embeds,
1160
  use_cache=use_cache,
 
1161
  output_hidden_states=output_hidden_states,
1162
  return_dict=return_dict,
1163
  )
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
25
+ is_flash_attn_2_available
26
  from transformers.generation.logits_process import LogitsProcessor
27
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
 
29
  from .configuration_chatglm import ChatGLMConfig
30
 
31
+ if is_flash_attn_2_available():
32
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
33
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
34
+
35
  # flags required to enable jit fusion kernels
36
 
37
+ if sys.platform != 'darwin' and not is_torch_npu_available():
38
  torch._C._jit_set_profiling_mode(False)
39
  torch._C._jit_set_profiling_executor(False)
40
  torch._C._jit_override_can_fuse_on_cpu(True)
 
45
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
46
  _CONFIG_FOR_DOC = "ChatGLMConfig"
47
 
 
 
 
 
 
48
 
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
 
165
  class CoreAttention(torch.nn.Module):
166
  def __init__(self, config: ChatGLMConfig, layer_number):
167
  super(CoreAttention, self).__init__()
168
+ self.config = config
169
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
170
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
171
  if self.apply_query_key_layer_scaling:
172
  self.attention_softmax_in_fp32 = True
173
  self.layer_number = max(1, layer_number)
174
+ self.is_causal = True
175
 
176
  projection_size = config.kv_channels * config.num_attention_heads
177
 
 
190
  self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
191
 
192
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
193
+ # [b, np, sq, sk]
194
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
195
+
196
+ # [b, np, sq, hn] -> [b * np, sq, hn]
197
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
198
+ # [b, np, sk, hn] -> [b * np, sk, hn]
199
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
200
+
201
+ # preallocting input tensor: [b * np, sq, sk]
202
+ matmul_input_buffer = torch.empty(
203
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
204
+ device=query_layer.device
205
+ )
206
+
207
+ # Raw attention scores. [b * np, sq, sk]
208
+ matmul_result = torch.baddbmm(
209
+ matmul_input_buffer,
210
+ query_layer, # [b * np, sq, hn]
211
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
212
+ beta=0.0,
213
+ alpha=(1.0 / self.norm_factor),
214
+ )
215
+
216
+ # change view to [b, np, sq, sk]
217
+ attention_scores = matmul_result.view(*output_size)
218
+
219
+ # ===========================
220
+ # Attention probs and dropout
221
+ # ===========================
222
+
223
+ # attention scores and attention mask [b, np, sq, sk]
224
+ if self.attention_softmax_in_fp32:
225
+ attention_scores = attention_scores.float()
226
+ if self.coeff is not None:
227
+ attention_scores = attention_scores * self.coeff
228
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
229
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
230
+ device=attention_scores.device, dtype=torch.bool)
231
+ attention_mask.tril_()
232
+ attention_mask = ~attention_mask
233
+ if attention_mask is not None:
234
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
235
+ attention_probs = F.softmax(attention_scores, dim=-1)
236
+ attention_probs = attention_probs.type_as(value_layer)
237
+
238
+ # This is actually dropping out entire tokens to attend to, which might
239
+ # seem a bit unusual, but is taken from the original Transformer paper.
240
+ attention_probs = self.attention_dropout(attention_probs)
241
+
242
+ # query layer shape: [b * np, sq, hn]
243
+ # value layer shape: [b, np, sk, hn]
244
+ # attention shape: [b, np, sq, sk]
245
+ # context layer shape: [b, np, sq, hn]
246
+ output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
247
+ # change view [b * np, sk, hn]
248
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
249
+ # change view [b * np, sq, sk]
250
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
251
+ # matmul: [b * np, sq, hn]
252
+ context_layer = torch.bmm(attention_probs, value_layer)
253
+ # change view [b, np, sq, hn]
254
+ context_layer = context_layer.view(*output_size)
255
+ # [b, np, sq, hn] --> [b, sq, np, hn]
256
+ context_layer = context_layer.transpose(1, 2).contiguous()
257
+ # [b, sq, np, hn] --> [b, sq, hp]
258
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
259
+ context_layer = context_layer.reshape(*new_context_layer_shape)
260
+
261
+ return context_layer
262
+
263
+
264
+ class SdpaAttention(CoreAttention):
265
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
266
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
267
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
268
+ is_causal=True,
269
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
270
  else:
271
+ if attention_mask is not None:
272
+ attention_mask = ~attention_mask
273
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
274
+ attention_mask,
275
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
276
+ context_layer = context_layer.transpose(1, 2).contiguous()
277
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
278
+ context_layer = context_layer.reshape(*new_context_layer_shape)
279
+ return context_layer
280
+
281
+
282
+ def _get_unpad_data(attention_mask):
283
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
284
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
285
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
286
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
287
+ return (
288
+ indices,
289
+ cu_seqlens,
290
+ max_seqlen_in_batch,
291
+ )
292
 
 
 
293
 
294
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
295
+ class FlashAttention2(CoreAttention):
296
+ def __init__(self, *args, **kwargs):
297
+ super().__init__(*args, **kwargs)
298
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
299
 
300
+ def forward(self, query_states, key_states, value_states, attention_mask):
301
+ query_states = query_states.transpose(1, 2)
302
+ key_states = key_states.transpose(1, 2)
303
+ value_states = value_states.transpose(1, 2)
304
+ batch_size, query_length = query_states.shape[:2]
305
+ if not self._flash_attn_uses_top_left_mask:
306
+ causal = self.is_causal
307
+ else:
308
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
309
+ causal = self.is_causal and query_length != 1
310
+ dropout = self.config.attention_dropout if self.training else 0.0
311
+ # Contains at least one padding token in the sequence
312
+ if attention_mask is not None:
313
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
314
+ query_states, key_states, value_states, attention_mask, query_length
315
  )
316
 
317
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
318
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
319
+
320
+ attn_output_unpad = flash_attn_varlen_func(
321
+ query_states,
322
+ key_states,
323
+ value_states,
324
+ cu_seqlens_q=cu_seqlens_q,
325
+ cu_seqlens_k=cu_seqlens_k,
326
+ max_seqlen_q=max_seqlen_in_batch_q,
327
+ max_seqlen_k=max_seqlen_in_batch_k,
328
+ dropout_p=dropout,
329
+ softmax_scale=None,
330
+ causal=causal,
331
  )
332
 
333
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
334
+ else:
335
+ attn_output = flash_attn_func(
336
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
337
+ )
338
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
339
+ return attn_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
342
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
343
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
344
+
345
+ key_layer = index_first_axis(
346
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
347
+ )
348
+ value_layer = index_first_axis(
349
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
350
+ )
351
+ if query_length == kv_seq_len:
352
+ query_layer = index_first_axis(
353
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
354
+ )
355
+ cu_seqlens_q = cu_seqlens_k
356
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
357
+ indices_q = indices_k
358
+ elif query_length == 1:
359
+ max_seqlen_in_batch_q = 1
360
+ cu_seqlens_q = torch.arange(
361
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
362
+ ) # There is a memcpy here, that is very bad.
363
+ indices_q = cu_seqlens_q[:-1]
364
+ query_layer = query_layer.squeeze(1)
365
+ else:
366
+ # The -q_len: slice assumes left padding.
367
+ attention_mask = attention_mask[:, -query_length:]
368
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
369
+
370
+ return (
371
+ query_layer,
372
+ key_layer,
373
+ value_layer,
374
+ indices_q,
375
+ (cu_seqlens_q, cu_seqlens_k),
376
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
377
+ )
378
+
379
+
380
+ CORE_ATTENTION_CLASSES = {
381
+ "eager": CoreAttention,
382
+ "sdpa": SdpaAttention,
383
+ "flash_attention_2": FlashAttention2
384
+ }
385
 
386
 
387
  class SelfAttention(torch.nn.Module):
 
413
  device=device, **_config_to_kwargs(config)
414
  )
415
 
416
+ self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
417
 
418
  # Output.
419
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
 
491
  key_layer = torch.cat((cache_k, key_layer), dim=2)
492
  value_layer = torch.cat((cache_v, value_layer), dim=2)
493
  if use_cache:
494
+ if kv_cache is None:
495
+ kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
496
+ dim=1)
497
+ else:
498
+ kv_cache = (key_layer, value_layer)
499
  else:
500
  kv_cache = None
501
 
 
728
  )
729
  hidden_states, kv_cache = layer_ret
730
  if use_cache:
731
+ # token by token decoding, use tuple format
732
+ if kv_caches[0] is not None:
733
+ presents = presents + (kv_cache,)
734
+ # prefilling in decoding, use tensor format to save cuda memory
735
+ else:
736
+ if len(presents) == 0:
737
+ presents = kv_cache
738
+ else:
739
+ presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
740
 
741
  if output_hidden_states:
742
  all_hidden_states = all_hidden_states + (hidden_states,)
 
759
  config_class = ChatGLMConfig
760
  base_model_prefix = "transformer"
761
  _no_split_modules = ["GLMBlock"]
762
+ _supports_flash_attn_2 = True
763
+ _supports_sdpa = True
764
 
765
  def _init_weights(self, module: nn.Module):
766
  """Initialize the weights."""
767
  return
768
 
769
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
770
+ if self.config._attn_implementation == "flash_attention_2":
771
+ if padding_mask is not None and not padding_mask.all():
772
+ return padding_mask
773
+ return None
774
  batch_size, seq_length = input_ids.shape
775
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
776
  full_attention_mask.tril_()
 
845
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
846
  )
847
 
848
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
849
+ original_impl=config.original_rope,
850
  device=device, dtype=config.torch_dtype)
851
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
852
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
 
867
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
868
  inputs_embeds: Optional[torch.Tensor] = None,
869
  use_cache: Optional[bool] = None,
870
+ output_attentions: Optional[bool] = None,
871
  output_hidden_states: Optional[bool] = None,
872
  return_dict: Optional[bool] = None,
873
  ):
 
898
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
899
  kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
900
  )
901
+ if presents is not None and type(presents) is torch.Tensor:
902
+ presents = presents.split(1, dim=0)
903
+ presents = list(presents)
904
+ presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
905
+ presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
906
+ presents = tuple(presents)
907
 
908
  if not return_dict:
909
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
1279
  inputs_embeds: Optional[torch.LongTensor] = None,
1280
  labels: Optional[torch.LongTensor] = None,
1281
  use_cache: Optional[bool] = None,
1282
+ output_attentions: Optional[bool] = None,
1283
  output_hidden_states: Optional[bool] = None,
1284
  return_dict: Optional[bool] = None,
1285
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
 
1293
  past_key_values=past_key_values,
1294
  inputs_embeds=inputs_embeds,
1295
  use_cache=use_cache,
1296
+ output_attentions=output_attentions,
1297
  output_hidden_states=output_hidden_states,
1298
  return_dict=return_dict,
1299
  )