cognitivess commited on
Commit
5336c3e
·
verified ·
1 Parent(s): 189915a

Update cognitivess_model/modeling_Cognitivess.py

Browse files
cognitivess_model/modeling_Cognitivess.py CHANGED
@@ -1,10 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Cognitivess AI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -17,30 +12,31 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """PyTorch Cognitivess model."""
21
-
22
  import math
23
  from typing import List, Optional, Tuple, Union
24
 
25
  import torch
 
26
  import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
  from ...activations import ACT2FN
31
- from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
32
  from ...modeling_attn_mask_utils import AttentionMaskConverter
 
33
  from ...modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
 
36
  SequenceClassifierOutputWithPast,
37
  TokenClassifierOutput,
38
  )
39
  from ...modeling_utils import PreTrainedModel
 
40
  from ...utils import (
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
43
- is_flash_attn_2_available,
44
  is_flash_attn_greater_or_equal_2_10,
45
  logging,
46
  replace_return_docstrings,
@@ -48,15 +44,11 @@ from ...utils import (
48
  from .configuration_Cognitivess import CognitivessConfig
49
 
50
 
51
- if is_flash_attn_2_available():
52
- from ...modeling_flash_attention_utils import _flash_attention_forward
53
-
54
  logger = logging.get_logger(__name__)
55
 
56
  _CONFIG_FOR_DOC = "CognitivessConfig"
57
 
58
 
59
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Cognitivess
60
  class CognitivessRMSNorm(nn.Module):
61
  def __init__(self, hidden_size, eps=1e-6):
62
  """
@@ -74,18 +66,22 @@ class CognitivessRMSNorm(nn.Module):
74
  return self.weight * hidden_states.to(input_dtype)
75
 
76
 
 
 
 
77
  class CognitivessRotaryEmbedding(nn.Module):
78
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
79
  super().__init__()
80
-
81
  self.dim = dim
82
  self.max_position_embeddings = max_position_embeddings
83
  self.base = base
84
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
85
  self.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
86
 
87
  @torch.no_grad()
88
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
89
  def forward(self, x, position_ids):
90
  # x: [bs, num_attention_heads, seq_len, head_size]
91
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
@@ -102,7 +98,35 @@ class CognitivessRotaryEmbedding(nn.Module):
102
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
103
 
104
 
105
- # Copied from transformers.models.llama.modeling_llama.rotate_half
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def rotate_half(x):
107
  """Rotates half the hidden dims of the input."""
108
  x1 = x[..., : x.shape[-1] // 2]
@@ -110,7 +134,6 @@ def rotate_half(x):
110
  return torch.cat((-x2, x1), dim=-1)
111
 
112
 
113
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
114
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
115
  """Applies Rotary Position Embedding to the query and key tensors.
116
 
@@ -141,18 +164,37 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
141
  class CognitivessMLP(nn.Module):
142
  def __init__(self, config):
143
  super().__init__()
 
144
  self.hidden_size = config.hidden_size
145
  self.intermediate_size = config.intermediate_size
146
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
147
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
148
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
149
  self.act_fn = ACT2FN[config.hidden_act]
150
 
151
- def forward(self, hidden_state):
152
- return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
 
155
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
156
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
157
  """
158
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -166,10 +208,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
166
 
167
 
168
  class CognitivessAttention(nn.Module):
169
- """
170
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
171
- and "Generating Long Sequences with Sparse Transformers".
172
- """
173
 
174
  def __init__(self, config: CognitivessConfig, layer_idx: Optional[int] = None):
175
  super().__init__()
@@ -185,23 +224,51 @@ class CognitivessAttention(nn.Module):
185
  self.attention_dropout = config.attention_dropout
186
  self.hidden_size = config.hidden_size
187
  self.num_heads = config.num_attention_heads
188
- self.head_dim = config.head_dim
189
  self.num_key_value_heads = config.num_key_value_heads
190
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
191
  self.max_position_embeddings = config.max_position_embeddings
192
  self.rope_theta = config.rope_theta
193
  self.is_causal = True
194
 
195
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
196
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
197
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
198
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
199
 
200
- self.rotary_emb = CognitivessRotaryEmbedding(
201
- self.head_dim,
202
- max_position_embeddings=self.max_position_embeddings,
203
- base=self.rope_theta,
204
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  def forward(
207
  self,
@@ -212,12 +279,31 @@ class CognitivessAttention(nn.Module):
212
  output_attentions: bool = False,
213
  use_cache: bool = False,
214
  cache_position: Optional[torch.LongTensor] = None,
 
215
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
216
  bsz, q_len, _ = hidden_states.size()
217
 
218
- query_states = self.q_proj(hidden_states)
219
- key_states = self.k_proj(hidden_states)
220
- value_states = self.v_proj(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
223
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -253,8 +339,14 @@ class CognitivessAttention(nn.Module):
253
 
254
  attn_output = attn_output.transpose(1, 2).contiguous()
255
 
256
- attn_output = attn_output.view(bsz, q_len, -1)
257
- attn_output = self.o_proj(attn_output)
 
 
 
 
 
 
258
 
259
  if not output_attentions:
260
  attn_weights = None
@@ -269,7 +361,6 @@ class CognitivessFlashAttention2(CognitivessAttention):
269
  flash attention and deal with padding tokens in case the input contains any of them.
270
  """
271
 
272
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
273
  def __init__(self, *args, **kwargs):
274
  super().__init__(*args, **kwargs)
275
 
@@ -281,13 +372,13 @@ class CognitivessFlashAttention2(CognitivessAttention):
281
  def forward(
282
  self,
283
  hidden_states: torch.Tensor,
284
- attention_mask: Optional[torch.Tensor] = None,
285
  position_ids: Optional[torch.LongTensor] = None,
286
  past_key_value: Optional[Cache] = None,
287
  output_attentions: bool = False,
288
  use_cache: bool = False,
289
  cache_position: Optional[torch.LongTensor] = None,
290
- ):
291
  if isinstance(past_key_value, StaticCache):
292
  raise ValueError(
293
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
@@ -302,54 +393,35 @@ class CognitivessFlashAttention2(CognitivessAttention):
302
  key_states = self.k_proj(hidden_states)
303
  value_states = self.v_proj(hidden_states)
304
 
 
 
 
305
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
306
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
307
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
308
 
309
- kv_seq_len = key_states.shape[-2]
310
- if past_key_value is not None:
311
- kv_seq_len += cache_position[0]
312
-
313
  cos, sin = self.rotary_emb(value_states, position_ids)
314
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
315
 
316
  if past_key_value is not None:
317
- # Activate slicing cache only if the config has a value `sliding_windows` attribute
318
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
319
- if (
320
- getattr(self.config, "sliding_window", None) is not None
321
- and kv_seq_len > self.config.sliding_window
322
- and cache_has_contents
323
- ):
324
- slicing_tokens = 1 - self.config.sliding_window
325
-
326
- past_key = past_key_value[self.layer_idx][0]
327
- past_value = past_key_value[self.layer_idx][1]
328
-
329
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
330
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
331
-
332
- if past_key.shape[-2] != self.config.sliding_window - 1:
333
- raise ValueError(
334
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
335
- f" {past_key.shape}"
336
- )
337
-
338
- if attention_mask is not None:
339
- attention_mask = attention_mask[:, slicing_tokens:]
340
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
341
-
342
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
343
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
344
 
345
- # repeat k/v heads if n_kv_heads < n_heads
346
- key_states = repeat_kv(key_states, self.num_key_value_groups)
347
- value_states = repeat_kv(value_states, self.num_key_value_groups)
348
- dropout_rate = 0.0 if not self.training else self.attention_dropout
 
 
 
349
 
350
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
351
  # therefore the input hidden states gets silently casted in float32. Hence, we need
352
- # cast them back in float16 just to be sure everything works as expected.
 
 
 
353
  input_dtype = query_states.dtype
354
  if input_dtype == torch.float32:
355
  if torch.is_autocast_enabled():
@@ -370,11 +442,6 @@ class CognitivessFlashAttention2(CognitivessAttention):
370
  key_states = key_states.to(target_dtype)
371
  value_states = value_states.to(target_dtype)
372
 
373
- # Reashape to the expected shape for Flash Attention
374
- query_states = query_states.transpose(1, 2)
375
- key_states = key_states.transpose(1, 2)
376
- value_states = value_states.transpose(1, 2)
377
-
378
  attn_output = _flash_attention_forward(
379
  query_states,
380
  key_states,
@@ -382,12 +449,12 @@ class CognitivessFlashAttention2(CognitivessAttention):
382
  attention_mask,
383
  q_len,
384
  dropout=dropout_rate,
385
- sliding_window=getattr(self.config, "sliding_window", None),
386
  use_top_left_mask=self._flash_attn_uses_top_left_mask,
387
  is_causal=self.is_causal,
388
  )
389
 
390
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
391
  attn_output = self.o_proj(attn_output)
392
 
393
  if not output_attentions:
@@ -396,7 +463,6 @@ class CognitivessFlashAttention2(CognitivessAttention):
396
  return attn_output, attn_weights, past_key_value
397
 
398
 
399
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Cognitivess
400
  class CognitivessSdpaAttention(CognitivessAttention):
401
  """
402
  Cognitivess attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -492,7 +558,6 @@ Cognitivess_ATTENTION_CLASSES = {
492
  }
493
 
494
 
495
- # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Cognitivess, LLAMA->Cognitivess
496
  class CognitivessDecoderLayer(nn.Module):
497
  def __init__(self, config: CognitivessConfig, layer_idx: int):
498
  super().__init__()
@@ -594,10 +659,11 @@ class CognitivessPreTrainedModel(PreTrainedModel):
594
  base_model_prefix = "model"
595
  supports_gradient_checkpointing = True
596
  _no_split_modules = ["CognitivessDecoderLayer"]
597
- _skip_keys_device_placement = "past_key_values"
598
  _supports_flash_attn_2 = True
599
  _supports_sdpa = True
600
  _supports_cache_class = True
 
601
  _supports_static_cache = True
602
 
603
  def _init_weights(self, module):
@@ -633,7 +699,7 @@ Cognitivess_INPUTS_DOCSTRING = r"""
633
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
634
  [`PreTrainedTokenizer.__call__`] for details.
635
 
636
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
637
  `past_key_values`).
638
 
639
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -679,6 +745,10 @@ Cognitivess_INPUTS_DOCSTRING = r"""
679
  more detail.
680
  return_dict (`bool`, *optional*):
681
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
 
 
 
682
  """
683
 
684
 
@@ -703,10 +773,9 @@ class CognitivessModel(CognitivessPreTrainedModel):
703
  self.layers = nn.ModuleList(
704
  [CognitivessDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
705
  )
706
- self._attn_implementation = config._attn_implementation
707
  self.norm = CognitivessRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
708
-
709
  self.gradient_checkpointing = False
 
710
  # Initialize weights and apply final processing
711
  self.post_init()
712
 
@@ -735,10 +804,8 @@ class CognitivessModel(CognitivessPreTrainedModel):
735
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
736
  )
737
  use_cache = use_cache if use_cache is not None else self.config.use_cache
738
-
739
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
740
 
741
- # retrieve input_ids and inputs_embeds
742
  if (input_ids is None) ^ (inputs_embeds is not None):
743
  raise ValueError(
744
  "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -746,7 +813,7 @@ class CognitivessModel(CognitivessPreTrainedModel):
746
 
747
  if self.gradient_checkpointing and self.training and use_cache:
748
  logger.warning_once(
749
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
750
  )
751
  use_cache = False
752
 
@@ -754,9 +821,9 @@ class CognitivessModel(CognitivessPreTrainedModel):
754
  inputs_embeds = self.embed_tokens(input_ids)
755
 
756
  return_legacy_cache = False
757
- if use_cache and not isinstance(past_key_values, Cache):
758
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
759
  return_legacy_cache = True
 
760
  logger.warning_once(
761
  "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
762
  "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
@@ -767,14 +834,14 @@ class CognitivessModel(CognitivessPreTrainedModel):
767
  cache_position = torch.arange(
768
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
769
  )
770
-
771
  if position_ids is None:
772
  position_ids = cache_position.unsqueeze(0)
773
 
774
  causal_mask = self._update_causal_mask(
775
- attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
776
  )
777
 
 
778
  hidden_states = inputs_embeds
779
 
780
  # decoder layers
@@ -841,7 +908,6 @@ class CognitivessModel(CognitivessPreTrainedModel):
841
  input_tensor: torch.Tensor,
842
  cache_position: torch.Tensor,
843
  past_key_values: Cache,
844
- use_cache: bool,
845
  output_attentions: bool,
846
  ):
847
  # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
@@ -849,15 +915,7 @@ class CognitivessModel(CognitivessPreTrainedModel):
849
  # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
850
  # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
851
 
852
- if self._attn_implementation == "flash_attention_2":
853
- if attention_mask is not None and use_cache:
854
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
855
- if is_padding_right:
856
- raise ValueError(
857
- "You are attempting to perform batched generation with padding_side='right'"
858
- " this may lead to unexpected behaviour for Flash Attention version of Cognitivess. Make sure to "
859
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
860
- )
861
  if attention_mask is not None and 0.0 in attention_mask:
862
  return attention_mask
863
  return None
@@ -865,22 +923,15 @@ class CognitivessModel(CognitivessPreTrainedModel):
865
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
866
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
867
  # to infer the attention mask.
868
-
869
- # cache_position must be valid here no matter which cache we use
870
- past_seen_tokens = cache_position[0] if past_key_values is not None else 0
871
  using_static_cache = isinstance(past_key_values, StaticCache)
872
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
873
 
874
- if (
875
- self.config._attn_implementation == "sdpa"
876
- and not (using_static_cache or using_sliding_window_cache)
877
- and not output_attentions
878
- ):
879
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
880
  attention_mask,
881
  inputs_embeds=input_tensor,
882
  past_key_values_length=past_seen_tokens,
883
- sliding_window=self.config.sliding_window,
884
  is_training=self.training,
885
  ):
886
  return None
@@ -888,13 +939,8 @@ class CognitivessModel(CognitivessPreTrainedModel):
888
  dtype, device = input_tensor.dtype, input_tensor.device
889
  min_dtype = torch.finfo(dtype).min
890
  sequence_length = input_tensor.shape[1]
891
- # SlidingWindowCache
892
- if using_sliding_window_cache:
893
- target_length = max(sequence_length, self.config.sliding_window)
894
- # StaticCache
895
- elif using_static_cache:
896
  target_length = past_key_values.get_max_length()
897
- # DynamicCache or no cache
898
  else:
899
  target_length = (
900
  attention_mask.shape[-1]
@@ -911,25 +957,18 @@ class CognitivessModel(CognitivessPreTrainedModel):
911
  causal_mask = torch.full(
912
  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
913
  )
914
- exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
915
- if self.config.sliding_window is not None:
916
- if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
917
- exclude_mask.bitwise_or_(
918
- torch.arange(target_length, device=device)
919
- <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
920
- )
921
- causal_mask *= exclude_mask
922
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
923
  if attention_mask is not None:
924
  causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
925
- if attention_mask.dim() == 2:
926
- mask_length = attention_mask.shape[-1]
927
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
928
- padding_mask = padding_mask == 0
929
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
930
- padding_mask, min_dtype
931
- )
932
-
933
  if (
934
  self.config._attn_implementation == "sdpa"
935
  and attention_mask is not None
@@ -1015,7 +1054,6 @@ class CognitivessForCausalLM(CognitivessPreTrainedModel):
1015
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1016
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1017
  ```"""
1018
-
1019
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1020
  output_hidden_states = (
1021
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1037,7 +1075,12 @@ class CognitivessForCausalLM(CognitivessPreTrainedModel):
1037
  )
1038
 
1039
  hidden_states = outputs[0]
1040
- logits = self.lm_head(hidden_states)
 
 
 
 
 
1041
  logits = logits.float()
1042
 
1043
  loss = None
@@ -1046,11 +1089,11 @@ class CognitivessForCausalLM(CognitivessPreTrainedModel):
1046
  shift_logits = logits[..., :-1, :].contiguous()
1047
  shift_labels = labels[..., 1:].contiguous()
1048
  # Flatten the tokens
 
1049
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1050
  shift_labels = shift_labels.view(-1)
1051
- # Ensure tensors are on the same device
1052
  shift_labels = shift_labels.to(shift_logits.device)
1053
- loss_fct = CrossEntropyLoss()
1054
  loss = loss_fct(shift_logits, shift_labels)
1055
 
1056
  if not return_dict:
@@ -1065,7 +1108,6 @@ class CognitivessForCausalLM(CognitivessPreTrainedModel):
1065
  attentions=outputs.attentions,
1066
  )
1067
 
1068
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1069
  def prepare_inputs_for_generation(
1070
  self,
1071
  input_ids,
@@ -1126,7 +1168,6 @@ class CognitivessForCausalLM(CognitivessPreTrainedModel):
1126
  """,
1127
  Cognitivess_START_DOCSTRING,
1128
  )
1129
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Cognitivess, LLAMA->Cognitivess
1130
  class CognitivessForSequenceClassification(CognitivessPreTrainedModel):
1131
  def __init__(self, config):
1132
  super().__init__(config)
@@ -1235,6 +1276,105 @@ class CognitivessForSequenceClassification(CognitivessPreTrainedModel):
1235
  )
1236
 
1237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1238
  @add_start_docstrings(
1239
  """
1240
  The Cognitivess Model transformer with a token classification head on top (a linear layer on top of the hidden-states
@@ -1242,7 +1382,6 @@ class CognitivessForSequenceClassification(CognitivessPreTrainedModel):
1242
  """,
1243
  Cognitivess_START_DOCSTRING,
1244
  )
1245
- # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Cognitivess, LLAMA->Cognitivess
1246
  class CognitivessForTokenClassification(CognitivessPreTrainedModel):
1247
  def __init__(self, config):
1248
  super().__init__(config)
 
1
  # coding=utf-8
2
+ # Copyright 2022 Cognitivess and the HuggingFace Inc. team. All rights reserved.
 
 
 
 
 
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
 
 
15
  import math
16
  from typing import List, Optional, Tuple, Union
17
 
18
  import torch
19
+ import torch.nn.functional as F
20
  import torch.utils.checkpoint
21
  from torch import nn
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
 
24
  from ...activations import ACT2FN
25
+ from ...cache_utils import Cache, DynamicCache, StaticCache
26
  from ...modeling_attn_mask_utils import AttentionMaskConverter
27
+ from ...modeling_flash_attention_utils import _flash_attention_forward
28
  from ...modeling_outputs import (
29
  BaseModelOutputWithPast,
30
  CausalLMOutputWithPast,
31
+ QuestionAnsweringModelOutput,
32
  SequenceClassifierOutputWithPast,
33
  TokenClassifierOutput,
34
  )
35
  from ...modeling_utils import PreTrainedModel
36
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
37
  from ...utils import (
38
  add_start_docstrings,
39
  add_start_docstrings_to_model_forward,
 
40
  is_flash_attn_greater_or_equal_2_10,
41
  logging,
42
  replace_return_docstrings,
 
44
  from .configuration_Cognitivess import CognitivessConfig
45
 
46
 
 
 
 
47
  logger = logging.get_logger(__name__)
48
 
49
  _CONFIG_FOR_DOC = "CognitivessConfig"
50
 
51
 
 
52
  class CognitivessRMSNorm(nn.Module):
53
  def __init__(self, hidden_size, eps=1e-6):
54
  """
 
66
  return self.weight * hidden_states.to(input_dtype)
67
 
68
 
69
+ ALL_LAYERNORM_LAYERS.append(CognitivessRMSNorm)
70
+
71
+
72
  class CognitivessRotaryEmbedding(nn.Module):
73
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
74
  super().__init__()
75
+ self.scaling_factor = scaling_factor
76
  self.dim = dim
77
  self.max_position_embeddings = max_position_embeddings
78
  self.base = base
79
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
80
  self.register_buffer("inv_freq", inv_freq, persistent=False)
81
+ # For BC we register cos and sin cached
82
+ self.max_seq_len_cached = max_position_embeddings
83
 
84
  @torch.no_grad()
 
85
  def forward(self, x, position_ids):
86
  # x: [bs, num_attention_heads, seq_len, head_size]
87
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
 
98
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
99
 
100
 
101
+ class CognitivessLinearScalingRotaryEmbedding(CognitivessRotaryEmbedding):
102
+ """CognitivessRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
103
+
104
+ def forward(self, x, position_ids):
105
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
106
+ position_ids = position_ids.float() / self.scaling_factor
107
+ cos, sin = super().forward(x, position_ids)
108
+ return cos, sin
109
+
110
+
111
+ class CognitivessDynamicNTKScalingRotaryEmbedding(CognitivessRotaryEmbedding):
112
+ """CognitivessRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
113
+
114
+ def forward(self, x, position_ids):
115
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
116
+ seq_len = torch.max(position_ids) + 1
117
+ if seq_len > self.max_position_embeddings:
118
+ base = self.base * (
119
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
120
+ ) ** (self.dim / (self.dim - 2))
121
+ inv_freq = 1.0 / (
122
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
123
+ )
124
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
125
+
126
+ cos, sin = super().forward(x, position_ids)
127
+ return cos, sin
128
+
129
+
130
  def rotate_half(x):
131
  """Rotates half the hidden dims of the input."""
132
  x1 = x[..., : x.shape[-1] // 2]
 
134
  return torch.cat((-x2, x1), dim=-1)
135
 
136
 
 
137
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
138
  """Applies Rotary Position Embedding to the query and key tensors.
139
 
 
164
  class CognitivessMLP(nn.Module):
165
  def __init__(self, config):
166
  super().__init__()
167
+ self.config = config
168
  self.hidden_size = config.hidden_size
169
  self.intermediate_size = config.intermediate_size
170
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
171
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
172
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
173
  self.act_fn = ACT2FN[config.hidden_act]
174
 
175
+ def forward(self, x):
176
+ if self.config.pretraining_tp > 1:
177
+ slice = self.intermediate_size // self.config.pretraining_tp
178
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
179
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
180
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
181
+
182
+ gate_proj = torch.cat(
183
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
184
+ )
185
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
186
+
187
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
188
+ down_proj = [
189
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
190
+ ]
191
+ down_proj = sum(down_proj)
192
+ else:
193
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
194
+
195
+ return down_proj
196
 
197
 
 
198
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
199
  """
200
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
208
 
209
 
210
  class CognitivessAttention(nn.Module):
211
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
 
 
 
212
 
213
  def __init__(self, config: CognitivessConfig, layer_idx: Optional[int] = None):
214
  super().__init__()
 
224
  self.attention_dropout = config.attention_dropout
225
  self.hidden_size = config.hidden_size
226
  self.num_heads = config.num_attention_heads
227
+ self.head_dim = self.hidden_size // self.num_heads
228
  self.num_key_value_heads = config.num_key_value_heads
229
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
230
  self.max_position_embeddings = config.max_position_embeddings
231
  self.rope_theta = config.rope_theta
232
  self.is_causal = True
233
 
234
+ if (self.head_dim * self.num_heads) != self.hidden_size:
235
+ raise ValueError(
236
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
237
+ f" and `num_heads`: {self.num_heads})."
238
+ )
239
 
240
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
241
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
242
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
243
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
244
+ self._init_rope()
245
+
246
+ def _init_rope(self):
247
+ if self.config.rope_scaling is None:
248
+ self.rotary_emb = CognitivessRotaryEmbedding(
249
+ self.head_dim,
250
+ max_position_embeddings=self.max_position_embeddings,
251
+ base=self.rope_theta,
252
+ )
253
+ else:
254
+ scaling_type = self.config.rope_scaling["type"]
255
+ scaling_factor = self.config.rope_scaling["factor"]
256
+ if scaling_type == "linear":
257
+ self.rotary_emb = CognitivessLinearScalingRotaryEmbedding(
258
+ self.head_dim,
259
+ max_position_embeddings=self.max_position_embeddings,
260
+ scaling_factor=scaling_factor,
261
+ base=self.rope_theta,
262
+ )
263
+ elif scaling_type == "dynamic":
264
+ self.rotary_emb = CognitivessDynamicNTKScalingRotaryEmbedding(
265
+ self.head_dim,
266
+ max_position_embeddings=self.max_position_embeddings,
267
+ scaling_factor=scaling_factor,
268
+ base=self.rope_theta,
269
+ )
270
+ else:
271
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
272
 
273
  def forward(
274
  self,
 
279
  output_attentions: bool = False,
280
  use_cache: bool = False,
281
  cache_position: Optional[torch.LongTensor] = None,
282
+ **kwargs,
283
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
284
  bsz, q_len, _ = hidden_states.size()
285
 
286
+ if self.config.pretraining_tp > 1:
287
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
288
+ query_slices = self.q_proj.weight.split(
289
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
290
+ )
291
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
292
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
293
+
294
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
295
+ query_states = torch.cat(query_states, dim=-1)
296
+
297
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
298
+ key_states = torch.cat(key_states, dim=-1)
299
+
300
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
301
+ value_states = torch.cat(value_states, dim=-1)
302
+
303
+ else:
304
+ query_states = self.q_proj(hidden_states)
305
+ key_states = self.k_proj(hidden_states)
306
+ value_states = self.v_proj(hidden_states)
307
 
308
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
309
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
339
 
340
  attn_output = attn_output.transpose(1, 2).contiguous()
341
 
342
+ attn_output = attn_output.reshape(bsz, q_len, -1)
343
+
344
+ if self.config.pretraining_tp > 1:
345
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
346
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
347
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
348
+ else:
349
+ attn_output = self.o_proj(attn_output)
350
 
351
  if not output_attentions:
352
  attn_weights = None
 
361
  flash attention and deal with padding tokens in case the input contains any of them.
362
  """
363
 
 
364
  def __init__(self, *args, **kwargs):
365
  super().__init__(*args, **kwargs)
366
 
 
372
  def forward(
373
  self,
374
  hidden_states: torch.Tensor,
375
+ attention_mask: Optional[torch.LongTensor] = None,
376
  position_ids: Optional[torch.LongTensor] = None,
377
  past_key_value: Optional[Cache] = None,
378
  output_attentions: bool = False,
379
  use_cache: bool = False,
380
  cache_position: Optional[torch.LongTensor] = None,
381
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
382
  if isinstance(past_key_value, StaticCache):
383
  raise ValueError(
384
  "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
 
393
  key_states = self.k_proj(hidden_states)
394
  value_states = self.v_proj(hidden_states)
395
 
396
+ # Flash attention requires the input to have the shape
397
+ # batch_size x seq_length x head_dim x hidden_dim
398
+ # therefore we just need to keep the original shape
399
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
400
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
401
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
402
 
 
 
 
 
403
  cos, sin = self.rotary_emb(value_states, position_ids)
404
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
405
 
406
  if past_key_value is not None:
407
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
408
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
410
 
411
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
412
+ # to be able to avoid many of these transpose/reshape/view.
413
+ query_states = query_states.transpose(1, 2)
414
+ key_states = key_states.transpose(1, 2)
415
+ value_states = value_states.transpose(1, 2)
416
+
417
+ dropout_rate = self.attention_dropout if self.training else 0.0
418
 
419
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
420
  # therefore the input hidden states gets silently casted in float32. Hence, we need
421
+ # cast them back in the correct dtype just to be sure everything works as expected.
422
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
423
+ # in fp32. (CognitivessRMSNorm handles it correctly)
424
+
425
  input_dtype = query_states.dtype
426
  if input_dtype == torch.float32:
427
  if torch.is_autocast_enabled():
 
442
  key_states = key_states.to(target_dtype)
443
  value_states = value_states.to(target_dtype)
444
 
 
 
 
 
 
445
  attn_output = _flash_attention_forward(
446
  query_states,
447
  key_states,
 
449
  attention_mask,
450
  q_len,
451
  dropout=dropout_rate,
452
+ sliding_window=getattr(self, "sliding_window", None),
453
  use_top_left_mask=self._flash_attn_uses_top_left_mask,
454
  is_causal=self.is_causal,
455
  )
456
 
457
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
458
  attn_output = self.o_proj(attn_output)
459
 
460
  if not output_attentions:
 
463
  return attn_output, attn_weights, past_key_value
464
 
465
 
 
466
  class CognitivessSdpaAttention(CognitivessAttention):
467
  """
468
  Cognitivess attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
 
558
  }
559
 
560
 
 
561
  class CognitivessDecoderLayer(nn.Module):
562
  def __init__(self, config: CognitivessConfig, layer_idx: int):
563
  super().__init__()
 
659
  base_model_prefix = "model"
660
  supports_gradient_checkpointing = True
661
  _no_split_modules = ["CognitivessDecoderLayer"]
662
+ _skip_keys_device_placement = ["past_key_values"]
663
  _supports_flash_attn_2 = True
664
  _supports_sdpa = True
665
  _supports_cache_class = True
666
+ _supports_quantized_cache = True
667
  _supports_static_cache = True
668
 
669
  def _init_weights(self, module):
 
699
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
700
  [`PreTrainedTokenizer.__call__`] for details.
701
 
702
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
703
  `past_key_values`).
704
 
705
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
 
745
  more detail.
746
  return_dict (`bool`, *optional*):
747
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
748
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
749
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
750
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
751
+ the complete sequence length.
752
  """
753
 
754
 
 
773
  self.layers = nn.ModuleList(
774
  [CognitivessDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
775
  )
 
776
  self.norm = CognitivessRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
777
  self.gradient_checkpointing = False
778
+
779
  # Initialize weights and apply final processing
780
  self.post_init()
781
 
 
804
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
  )
806
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
807
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
 
 
809
  if (input_ids is None) ^ (inputs_embeds is not None):
810
  raise ValueError(
811
  "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
 
813
 
814
  if self.gradient_checkpointing and self.training and use_cache:
815
  logger.warning_once(
816
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
817
  )
818
  use_cache = False
819
 
 
821
  inputs_embeds = self.embed_tokens(input_ids)
822
 
823
  return_legacy_cache = False
824
+ if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
 
825
  return_legacy_cache = True
826
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
827
  logger.warning_once(
828
  "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
829
  "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
 
834
  cache_position = torch.arange(
835
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
836
  )
 
837
  if position_ids is None:
838
  position_ids = cache_position.unsqueeze(0)
839
 
840
  causal_mask = self._update_causal_mask(
841
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
842
  )
843
 
844
+ # embed positions
845
  hidden_states = inputs_embeds
846
 
847
  # decoder layers
 
908
  input_tensor: torch.Tensor,
909
  cache_position: torch.Tensor,
910
  past_key_values: Cache,
 
911
  output_attentions: bool,
912
  ):
913
  # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
 
915
  # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
916
  # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
917
 
918
+ if self.config._attn_implementation == "flash_attention_2":
 
 
 
 
 
 
 
 
919
  if attention_mask is not None and 0.0 in attention_mask:
920
  return attention_mask
921
  return None
 
923
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
924
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
925
  # to infer the attention mask.
926
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
927
  using_static_cache = isinstance(past_key_values, StaticCache)
 
928
 
929
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
930
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
 
 
 
931
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
932
  attention_mask,
933
  inputs_embeds=input_tensor,
934
  past_key_values_length=past_seen_tokens,
 
935
  is_training=self.training,
936
  ):
937
  return None
 
939
  dtype, device = input_tensor.dtype, input_tensor.device
940
  min_dtype = torch.finfo(dtype).min
941
  sequence_length = input_tensor.shape[1]
942
+ if using_static_cache:
 
 
 
 
943
  target_length = past_key_values.get_max_length()
 
944
  else:
945
  target_length = (
946
  attention_mask.shape[-1]
 
957
  causal_mask = torch.full(
958
  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
959
  )
960
+ if sequence_length != 1:
961
+ causal_mask = torch.triu(causal_mask, diagonal=1)
962
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
 
 
 
 
 
963
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
964
  if attention_mask is not None:
965
  causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
966
+ mask_length = attention_mask.shape[-1]
967
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
968
+ padding_mask = padding_mask == 0
969
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
970
+ padding_mask, min_dtype
971
+ )
 
 
972
  if (
973
  self.config._attn_implementation == "sdpa"
974
  and attention_mask is not None
 
1054
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1055
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1056
  ```"""
 
1057
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1058
  output_hidden_states = (
1059
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1075
  )
1076
 
1077
  hidden_states = outputs[0]
1078
+ if self.config.pretraining_tp > 1:
1079
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1080
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1081
+ logits = torch.cat(logits, dim=-1)
1082
+ else:
1083
+ logits = self.lm_head(hidden_states)
1084
  logits = logits.float()
1085
 
1086
  loss = None
 
1089
  shift_logits = logits[..., :-1, :].contiguous()
1090
  shift_labels = labels[..., 1:].contiguous()
1091
  # Flatten the tokens
1092
+ loss_fct = CrossEntropyLoss()
1093
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1094
  shift_labels = shift_labels.view(-1)
1095
+ # Enable model parallelism
1096
  shift_labels = shift_labels.to(shift_logits.device)
 
1097
  loss = loss_fct(shift_logits, shift_labels)
1098
 
1099
  if not return_dict:
 
1108
  attentions=outputs.attentions,
1109
  )
1110
 
 
1111
  def prepare_inputs_for_generation(
1112
  self,
1113
  input_ids,
 
1168
  """,
1169
  Cognitivess_START_DOCSTRING,
1170
  )
 
1171
  class CognitivessForSequenceClassification(CognitivessPreTrainedModel):
1172
  def __init__(self, config):
1173
  super().__init__(config)
 
1276
  )
1277
 
1278
 
1279
+ @add_start_docstrings(
1280
+ """
1281
+ The Cognitivess Model transformer with a span classification head on top for extractive question-answering tasks like
1282
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1283
+ """,
1284
+ Cognitivess_START_DOCSTRING,
1285
+ )
1286
+ class CognitivessForQuestionAnswering(CognitivessPreTrainedModel):
1287
+ base_model_prefix = "transformer"
1288
+
1289
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Cognitivess
1290
+ def __init__(self, config):
1291
+ super().__init__(config)
1292
+ self.transformer = CognitivessModel(config)
1293
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1294
+
1295
+ # Initialize weights and apply final processing
1296
+ self.post_init()
1297
+
1298
+ def get_input_embeddings(self):
1299
+ return self.transformer.embed_tokens
1300
+
1301
+ def set_input_embeddings(self, value):
1302
+ self.transformer.embed_tokens = value
1303
+
1304
+ @add_start_docstrings_to_model_forward(Cognitivess_INPUTS_DOCSTRING)
1305
+ def forward(
1306
+ self,
1307
+ input_ids: Optional[torch.LongTensor] = None,
1308
+ attention_mask: Optional[torch.FloatTensor] = None,
1309
+ position_ids: Optional[torch.LongTensor] = None,
1310
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1311
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1312
+ start_positions: Optional[torch.LongTensor] = None,
1313
+ end_positions: Optional[torch.LongTensor] = None,
1314
+ output_attentions: Optional[bool] = None,
1315
+ output_hidden_states: Optional[bool] = None,
1316
+ return_dict: Optional[bool] = None,
1317
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1318
+ r"""
1319
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1320
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1321
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1322
+ are not taken into account for computing the loss.
1323
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1324
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1325
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1326
+ are not taken into account for computing the loss.
1327
+ """
1328
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1329
+
1330
+ outputs = self.transformer(
1331
+ input_ids,
1332
+ attention_mask=attention_mask,
1333
+ position_ids=position_ids,
1334
+ past_key_values=past_key_values,
1335
+ inputs_embeds=inputs_embeds,
1336
+ output_attentions=output_attentions,
1337
+ output_hidden_states=output_hidden_states,
1338
+ return_dict=return_dict,
1339
+ )
1340
+
1341
+ sequence_output = outputs[0]
1342
+
1343
+ logits = self.qa_outputs(sequence_output)
1344
+ start_logits, end_logits = logits.split(1, dim=-1)
1345
+ start_logits = start_logits.squeeze(-1).contiguous()
1346
+ end_logits = end_logits.squeeze(-1).contiguous()
1347
+
1348
+ total_loss = None
1349
+ if start_positions is not None and end_positions is not None:
1350
+ # If we are on multi-GPU, split add a dimension
1351
+ if len(start_positions.size()) > 1:
1352
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1353
+ if len(end_positions.size()) > 1:
1354
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1355
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1356
+ ignored_index = start_logits.size(1)
1357
+ start_positions = start_positions.clamp(0, ignored_index)
1358
+ end_positions = end_positions.clamp(0, ignored_index)
1359
+
1360
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1361
+ start_loss = loss_fct(start_logits, start_positions)
1362
+ end_loss = loss_fct(end_logits, end_positions)
1363
+ total_loss = (start_loss + end_loss) / 2
1364
+
1365
+ if not return_dict:
1366
+ output = (start_logits, end_logits) + outputs[2:]
1367
+ return ((total_loss,) + output) if total_loss is not None else output
1368
+
1369
+ return QuestionAnsweringModelOutput(
1370
+ loss=total_loss,
1371
+ start_logits=start_logits,
1372
+ end_logits=end_logits,
1373
+ hidden_states=outputs.hidden_states,
1374
+ attentions=outputs.attentions,
1375
+ )
1376
+
1377
+
1378
  @add_start_docstrings(
1379
  """
1380
  The Cognitivess Model transformer with a token classification head on top (a linear layer on top of the hidden-states
 
1382
  """,
1383
  Cognitivess_START_DOCSTRING,
1384
  )
 
1385
  class CognitivessForTokenClassification(CognitivessPreTrainedModel):
1386
  def __init__(self, config):
1387
  super().__init__(config)