ldwang commited on
Commit
bb9204d
·
verified ·
1 Parent(s): 9b36856

Upload modeling_aquila.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_aquila.py +783 -257
modeling_aquila.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
@@ -17,61 +17,64 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """ PyTorch Aquila model."""
 
 
 
21
  import math
 
22
  from typing import List, Optional, Tuple, Union
23
 
24
  import torch
 
25
  import torch.utils.checkpoint
26
  from torch import nn
27
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
 
29
  from transformers.activations import ACT2FN
30
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 
 
 
 
 
 
 
31
  from transformers.modeling_utils import PreTrainedModel
32
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 
 
 
 
 
 
 
 
33
  from .configuration_aquila import AquilaConfig
34
 
35
 
36
- logger = logging.get_logger(__name__)
37
-
38
- _CONFIG_FOR_DOC = "AquilaConfig"
39
-
40
-
41
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
- def _make_causal_mask(
43
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
- ):
45
- """
46
- Make causal mask used for bi-directional self-attention.
47
- """
48
- bsz, tgt_len = input_ids_shape
49
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
- mask_cond = torch.arange(mask.size(-1), device=device)
51
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
- mask = mask.to(dtype)
53
-
54
- if past_key_values_length > 0:
55
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
 
58
 
59
- # Copied from transformers.models.bart.modeling_bart._expand_mask
60
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
- """
62
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
- """
64
- bsz, src_len = mask.size()
65
- tgt_len = tgt_len if tgt_len is not None else src_len
66
 
67
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
 
69
- inverted_mask = 1.0 - expanded_mask
70
 
71
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
 
 
 
 
 
 
 
72
 
73
 
74
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Aquila
75
  class AquilaRMSNorm(nn.Module):
76
  def __init__(self, hidden_size, eps=1e-6):
77
  """
@@ -83,92 +86,94 @@ class AquilaRMSNorm(nn.Module):
83
 
84
  def forward(self, hidden_states):
85
  input_dtype = hidden_states.dtype
86
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
 
87
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
88
 
89
- return (self.weight * hidden_states).to(input_dtype)
90
 
 
91
 
92
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Aquila
93
- class AquilaRotaryEmbedding(torch.nn.Module):
94
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
- super().__init__()
96
 
 
 
 
 
97
  self.dim = dim
98
  self.max_position_embeddings = max_position_embeddings
99
  self.base = base
100
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
101
  self.register_buffer("inv_freq", inv_freq, persistent=False)
102
-
103
- # Build here to make `torch.jit.trace` work.
104
- self._set_cos_sin_cache(
105
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
- )
107
-
108
- def _set_cos_sin_cache(self, seq_len, device, dtype):
109
- self.max_seq_len_cached = seq_len
110
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
111
-
112
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
113
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
  emb = torch.cat((freqs, freqs), dim=-1)
115
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
116
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
 
 
 
 
 
 
 
117
 
118
- def forward(self, x, seq_len=None):
 
 
 
 
 
 
 
 
 
119
  # x: [bs, num_attention_heads, seq_len, head_size]
120
- if seq_len > self.max_seq_len_cached:
121
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
 
 
 
 
 
 
 
 
 
122
 
123
- return (
124
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
125
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
126
- )
127
 
128
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila
129
  class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding):
130
  """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
131
 
132
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
133
- self.scaling_factor = scaling_factor
134
- super().__init__(dim, max_position_embeddings, base, device)
 
 
135
 
136
- def _set_cos_sin_cache(self, seq_len, device, dtype):
137
- self.max_seq_len_cached = seq_len
138
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
139
- t = t / self.scaling_factor
140
 
141
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
142
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
143
- emb = torch.cat((freqs, freqs), dim=-1)
144
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
145
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
146
-
147
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila
148
  class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding):
149
  """AquilaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
150
 
151
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
152
- self.scaling_factor = scaling_factor
153
- super().__init__(dim, max_position_embeddings, base, device)
154
-
155
- def _set_cos_sin_cache(self, seq_len, device, dtype):
156
- self.max_seq_len_cached = seq_len
157
-
158
  if seq_len > self.max_position_embeddings:
159
  base = self.base * (
160
  (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
161
  ) ** (self.dim / (self.dim - 2))
162
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
163
- self.register_buffer("inv_freq", inv_freq, persistent=False)
164
-
165
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
166
 
167
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
- emb = torch.cat((freqs, freqs), dim=-1)
170
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
171
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
172
 
173
 
174
  def rotate_half(x):
@@ -178,18 +183,33 @@ def rotate_half(x):
178
  return torch.cat((-x2, x1), dim=-1)
179
 
180
 
181
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
182
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
183
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
184
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
185
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
186
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  q_embed = (q * cos) + (rotate_half(q) * sin)
188
  k_embed = (k * cos) + (rotate_half(k) * sin)
189
  return q_embed, k_embed
190
 
191
 
192
- # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila
193
  class AquilaMLP(nn.Module):
194
  def __init__(self, config):
195
  super().__init__()
@@ -236,12 +256,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
236
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
237
 
238
 
239
- # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila
240
  class AquilaAttention(nn.Module):
241
  """Multi-headed attention from 'Attention Is All You Need' paper"""
242
- def __init__(self, config: AquilaConfig):
 
243
  super().__init__()
244
  self.config = config
 
 
 
 
 
 
 
 
 
245
  self.hidden_size = config.hidden_size
246
  self.num_heads = config.num_attention_heads
247
  self.head_dim = self.hidden_size // self.num_heads
@@ -249,16 +278,18 @@ class AquilaAttention(nn.Module):
249
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
250
  self.max_position_embeddings = config.max_position_embeddings
251
  self.rope_theta = config.rope_theta
 
252
 
253
  if (self.head_dim * self.num_heads) != self.hidden_size:
254
  raise ValueError(
255
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
256
  f" and `num_heads`: {self.num_heads})."
257
  )
258
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
259
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
260
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
261
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
262
  self._init_rope()
263
 
264
  def _init_rope(self):
@@ -288,17 +319,16 @@ class AquilaAttention(nn.Module):
288
  else:
289
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
290
 
291
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
292
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
293
-
294
  def forward(
295
  self,
296
  hidden_states: torch.Tensor,
297
  attention_mask: Optional[torch.Tensor] = None,
298
  position_ids: Optional[torch.LongTensor] = None,
299
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
300
  output_attentions: bool = False,
301
  use_cache: bool = False,
 
 
302
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
303
  bsz, q_len, _ = hidden_states.size()
304
 
@@ -328,40 +358,27 @@ class AquilaAttention(nn.Module):
328
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
330
 
331
- kv_seq_len = key_states.shape[-2]
332
- if past_key_value is not None:
333
- kv_seq_len += past_key_value[0].shape[-2]
334
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
336
 
337
  if past_key_value is not None:
338
- # reuse k, v, self_attention
339
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
340
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
341
 
342
- past_key_value = (key_states, value_states) if use_cache else None
343
-
344
- # repeat k/v heads if n_kv_heads < n_heads
345
  key_states = repeat_kv(key_states, self.num_key_value_groups)
346
  value_states = repeat_kv(value_states, self.num_key_value_groups)
347
 
348
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
349
- attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.)
350
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
351
- raise ValueError(
352
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
353
- f" {attn_weights.size()}"
354
- )
355
 
356
- if attention_mask is not None:
357
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
358
- raise ValueError(
359
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
360
- )
361
- attn_weights = attn_weights + attention_mask
362
 
363
  # upcast attention to fp32
364
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
365
  attn_output = torch.matmul(attn_weights, value_states)
366
 
367
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -371,6 +388,7 @@ class AquilaAttention(nn.Module):
371
  )
372
 
373
  attn_output = attn_output.transpose(1, 2).contiguous()
 
374
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
375
 
376
  if self.config.pretraining_tp > 1:
@@ -386,12 +404,301 @@ class AquilaAttention(nn.Module):
386
  return attn_output, attn_weights, past_key_value
387
 
388
 
389
- # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Aquila
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  class AquilaDecoderLayer(nn.Module):
391
- def __init__(self, config: AquilaConfig):
392
  super().__init__()
393
  self.hidden_size = config.hidden_size
394
- self.self_attn = AquilaAttention(config=config)
 
 
395
  self.mlp = AquilaMLP(config)
396
  self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
397
  self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -404,12 +711,15 @@ class AquilaDecoderLayer(nn.Module):
404
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
405
  output_attentions: Optional[bool] = False,
406
  use_cache: Optional[bool] = False,
 
 
407
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
408
  """
409
  Args:
410
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
411
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
412
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 
413
  output_attentions (`bool`, *optional*):
414
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
415
  returned tensors for more detail.
@@ -418,6 +728,10 @@ class AquilaDecoderLayer(nn.Module):
418
  (see `past_key_values`).
419
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
420
  """
 
 
 
 
421
 
422
  residual = hidden_states
423
 
@@ -431,6 +745,8 @@ class AquilaDecoderLayer(nn.Module):
431
  past_key_value=past_key_value,
432
  output_attentions=output_attentions,
433
  use_cache=use_cache,
 
 
434
  )
435
  hidden_states = residual + hidden_states
436
 
@@ -450,6 +766,7 @@ class AquilaDecoderLayer(nn.Module):
450
 
451
  return outputs
452
 
 
453
  AQUILA_START_DOCSTRING = r"""
454
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
455
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -471,13 +788,15 @@ AQUILA_START_DOCSTRING = r"""
471
  "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
472
  AQUILA_START_DOCSTRING,
473
  )
474
- # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Aquila
475
  class AquilaPreTrainedModel(PreTrainedModel):
476
  config_class = AquilaConfig
477
  base_model_prefix = "model"
478
  supports_gradient_checkpointing = True
479
  _no_split_modules = ["AquilaDecoderLayer"]
480
- _skip_keys_device_placement = "past_key_values"
 
 
 
481
 
482
  def _init_weights(self, module):
483
  std = self.config.initializer_range
@@ -490,9 +809,26 @@ class AquilaPreTrainedModel(PreTrainedModel):
490
  if module.padding_idx is not None:
491
  module.weight.data[module.padding_idx].zero_()
492
 
493
- def _set_gradient_checkpointing(self, module, value=False):
494
- if isinstance(module, AquilaModel):
495
- module.gradient_checkpointing = value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
 
498
  AQUILA_INPUTS_DOCSTRING = r"""
@@ -516,7 +852,7 @@ AQUILA_INPUTS_DOCSTRING = r"""
516
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
  [`PreTrainedTokenizer.__call__`] for details.
518
 
519
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
520
  `past_key_values`).
521
 
522
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -530,17 +866,23 @@ AQUILA_INPUTS_DOCSTRING = r"""
530
  config.n_positions - 1]`.
531
 
532
  [What are position IDs?](../glossary#position-ids)
533
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
534
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
535
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
536
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
537
-
538
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
539
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
540
-
541
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
542
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
543
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
 
 
 
 
 
 
544
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
545
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
546
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -556,6 +898,10 @@ AQUILA_INPUTS_DOCSTRING = r"""
556
  more detail.
557
  return_dict (`bool`, *optional*):
558
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
 
 
 
559
  """
560
 
561
 
@@ -563,7 +909,6 @@ AQUILA_INPUTS_DOCSTRING = r"""
563
  "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
564
  AQUILA_START_DOCSTRING,
565
  )
566
- # Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->AQUILA,Llama->Aquila
567
  class AquilaModel(AquilaPreTrainedModel):
568
  """
569
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AquilaDecoderLayer`]
@@ -578,10 +923,12 @@ class AquilaModel(AquilaPreTrainedModel):
578
  self.vocab_size = config.vocab_size
579
 
580
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
581
- self.layers = nn.ModuleList([AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
582
  self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
583
-
584
  self.gradient_checkpointing = False
 
585
  # Initialize weights and apply final processing
586
  self.post_init()
587
 
@@ -591,29 +938,6 @@ class AquilaModel(AquilaPreTrainedModel):
591
  def set_input_embeddings(self, value):
592
  self.embed_tokens = value
593
 
594
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
595
- # create causal mask
596
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
597
- combined_attention_mask = None
598
- if input_shape[-1] > 1:
599
- combined_attention_mask = _make_causal_mask(
600
- input_shape,
601
- inputs_embeds.dtype,
602
- device=inputs_embeds.device,
603
- past_key_values_length=past_key_values_length,
604
- )
605
-
606
- if attention_mask is not None:
607
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
608
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
609
- inputs_embeds.device
610
- )
611
- combined_attention_mask = (
612
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
613
- )
614
-
615
- return combined_attention_mask
616
-
617
  @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING)
618
  def forward(
619
  self,
@@ -626,101 +950,85 @@ class AquilaModel(AquilaPreTrainedModel):
626
  output_attentions: Optional[bool] = None,
627
  output_hidden_states: Optional[bool] = None,
628
  return_dict: Optional[bool] = None,
 
629
  ) -> Union[Tuple, BaseModelOutputWithPast]:
630
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
631
  output_hidden_states = (
632
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
633
  )
634
  use_cache = use_cache if use_cache is not None else self.config.use_cache
635
-
636
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
637
 
638
- # retrieve input_ids and inputs_embeds
639
- if input_ids is not None and inputs_embeds is not None:
640
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
641
- elif input_ids is not None:
642
- batch_size, seq_length = input_ids.shape
643
- elif inputs_embeds is not None:
644
- batch_size, seq_length, _ = inputs_embeds.shape
645
- else:
646
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
647
-
648
- seq_length_with_past = seq_length
649
- past_key_values_length = 0
650
-
651
- if past_key_values is not None:
652
- past_key_values_length = past_key_values[0][0].shape[2]
653
- seq_length_with_past = seq_length_with_past + past_key_values_length
654
 
655
- if position_ids is None:
656
- device = input_ids.device if input_ids is not None else inputs_embeds.device
657
- position_ids = torch.arange(
658
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
659
  )
660
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
661
- else:
662
- position_ids = position_ids.view(-1, seq_length).long()
663
 
664
  if inputs_embeds is None:
665
  inputs_embeds = self.embed_tokens(input_ids)
666
- # embed positions
667
- if attention_mask is None:
668
- attention_mask = torch.ones(
669
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
 
 
 
670
  )
671
- attention_mask = self._prepare_decoder_attention_mask(
672
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
673
- )
674
 
675
- hidden_states = inputs_embeds
 
676
 
677
- if self.gradient_checkpointing and self.training:
678
- if use_cache:
679
- logger.warning_once(
680
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
681
- )
682
- use_cache = False
683
 
684
  # decoder layers
685
  all_hidden_states = () if output_hidden_states else None
686
  all_self_attns = () if output_attentions else None
687
- next_decoder_cache = () if use_cache else None
688
 
689
- for idx, decoder_layer in enumerate(self.layers):
690
  if output_hidden_states:
691
  all_hidden_states += (hidden_states,)
692
 
693
- past_key_value = past_key_values[idx] if past_key_values is not None else None
694
-
695
  if self.gradient_checkpointing and self.training:
696
-
697
- def create_custom_forward(module):
698
- def custom_forward(*inputs):
699
- # None for past_key_value
700
- return module(*inputs, past_key_value, output_attentions)
701
-
702
- return custom_forward
703
-
704
- layer_outputs = torch.utils.checkpoint.checkpoint(
705
- create_custom_forward(decoder_layer),
706
  hidden_states,
707
- attention_mask,
708
  position_ids,
 
 
 
 
709
  )
710
  else:
711
  layer_outputs = decoder_layer(
712
  hidden_states,
713
- attention_mask=attention_mask,
714
  position_ids=position_ids,
715
- past_key_value=past_key_value,
716
  output_attentions=output_attentions,
717
  use_cache=use_cache,
 
718
  )
719
 
720
  hidden_states = layer_outputs[0]
721
 
722
  if use_cache:
723
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
724
 
725
  if output_attentions:
726
  all_self_attns += (layer_outputs[1],)
@@ -731,7 +1039,11 @@ class AquilaModel(AquilaPreTrainedModel):
731
  if output_hidden_states:
732
  all_hidden_states += (hidden_states,)
733
 
734
- next_cache = next_decoder_cache if use_cache else None
 
 
 
 
735
  if not return_dict:
736
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
737
  return BaseModelOutputWithPast(
@@ -741,7 +1053,70 @@ class AquilaModel(AquilaPreTrainedModel):
741
  attentions=all_self_attns,
742
  )
743
 
744
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  class AquilaForCausalLM(AquilaPreTrainedModel):
746
  _tied_weights_keys = ["lm_head.weight"]
747
 
@@ -786,6 +1161,7 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
786
  output_attentions: Optional[bool] = None,
787
  output_hidden_states: Optional[bool] = None,
788
  return_dict: Optional[bool] = None,
 
789
  ) -> Union[Tuple, CausalLMOutputWithPast]:
790
  r"""
791
  Args:
@@ -801,18 +1177,17 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
801
  ```python
802
  >>> from transformers import AutoTokenizer, AquilaForCausalLM
803
 
804
- >>> model = AquilaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
805
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
806
 
807
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
808
  >>> inputs = tokenizer(prompt, return_tensors="pt")
809
 
810
  >>> # Generate
811
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
812
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
813
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
814
  ```"""
815
-
816
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
817
  output_hidden_states = (
818
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -830,6 +1205,7 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
830
  output_attentions=output_attentions,
831
  output_hidden_states=output_hidden_states,
832
  return_dict=return_dict,
 
833
  )
834
 
835
  hidden_states = outputs[0]
@@ -867,10 +1243,49 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
867
  )
868
 
869
  def prepare_inputs_for_generation(
870
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
871
  ):
872
- if past_key_values:
873
- input_ids = input_ids[:, -1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874
 
875
  position_ids = kwargs.get("position_ids", None)
876
  if attention_mask is not None and position_ids is None:
@@ -878,17 +1293,30 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
878
  position_ids = attention_mask.long().cumsum(-1) - 1
879
  position_ids.masked_fill_(attention_mask == 0, 1)
880
  if past_key_values:
881
- position_ids = position_ids[:, -1].unsqueeze(-1)
882
 
883
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
884
  if inputs_embeds is not None and past_key_values is None:
885
  model_inputs = {"inputs_embeds": inputs_embeds}
886
  else:
887
- model_inputs = {"input_ids": input_ids}
 
 
 
 
 
 
 
 
 
 
 
 
888
 
889
  model_inputs.update(
890
  {
891
  "position_ids": position_ids,
 
892
  "past_key_values": past_key_values,
893
  "use_cache": kwargs.get("use_cache"),
894
  "attention_mask": attention_mask,
@@ -905,9 +1333,10 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
905
  )
906
  return reordered_past
907
 
 
908
  @add_start_docstrings(
909
  """
910
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
911
 
912
  [`AquilaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
913
  (e.g. GPT-2) do.
@@ -920,10 +1349,7 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
920
  """,
921
  AQUILA_START_DOCSTRING,
922
  )
923
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->AQUILA,Llama->Aquila
924
  class AquilaForSequenceClassification(AquilaPreTrainedModel):
925
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
926
-
927
  def __init__(self, config):
928
  super().__init__(config)
929
  self.num_labels = config.num_labels
@@ -986,9 +1412,10 @@ class AquilaForSequenceClassification(AquilaPreTrainedModel):
986
  sequence_lengths = -1
987
  else:
988
  if input_ids is not None:
989
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
990
- logits.device
991
- )
 
992
  else:
993
  sequence_lengths = -1
994
 
@@ -1028,3 +1455,102 @@ class AquilaForSequenceClassification(AquilaPreTrainedModel):
1028
  hidden_states=transformer_outputs.hidden_states,
1029
  attentions=transformer_outputs.attentions,
1030
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
 
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
+
21
+ # Most of the source code is adapted from Llama's source code
22
+ """PyTorch Aquila model."""
23
+
24
  import math
25
+ import warnings
26
  from typing import List, Optional, Tuple, Union
27
 
28
  import torch
29
+ import torch.nn.functional as F
30
  import torch.utils.checkpoint
31
  from torch import nn
32
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
 
34
  from transformers.activations import ACT2FN
35
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
36
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
+ QuestionAnsweringModelOutput,
41
+ SequenceClassifierOutputWithPast,
42
+ )
43
  from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
  from .configuration_aquila import AquilaConfig
54
 
55
 
56
+ if is_flash_attn_2_available():
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
+ logger = logging.get_logger(__name__)
 
 
 
 
 
 
62
 
63
+ _CONFIG_FOR_DOC = "AquilaConfig"
64
 
 
65
 
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
 
77
 
 
78
  class AquilaRMSNorm(nn.Module):
79
  def __init__(self, hidden_size, eps=1e-6):
80
  """
 
86
 
87
  def forward(self, hidden_states):
88
  input_dtype = hidden_states.dtype
89
+ hidden_states = hidden_states.to(torch.float32)
90
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
91
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
92
+ return self.weight * hidden_states.to(input_dtype)
93
 
 
94
 
95
+ ALL_LAYERNORM_LAYERS.append(AquilaRMSNorm)
96
 
 
 
 
 
97
 
98
+ class AquilaRotaryEmbedding(nn.Module):
99
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
100
+ super().__init__()
101
+ self.scaling_factor = scaling_factor
102
  self.dim = dim
103
  self.max_position_embeddings = max_position_embeddings
104
  self.base = base
105
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
106
  self.register_buffer("inv_freq", inv_freq, persistent=False)
107
+ # For BC we register cos and sin cached
108
+ self.max_seq_len_cached = max_position_embeddings
109
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
110
+ t = t / self.scaling_factor
111
+ freqs = torch.outer(t, self.inv_freq)
 
 
 
 
 
 
112
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
113
  emb = torch.cat((freqs, freqs), dim=-1)
114
+ self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
115
+ self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
116
+
117
+ @property
118
+ def sin_cached(self):
119
+ logger.warning_once(
120
+ "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
121
+ "the forward method of RoPE from now on instead. It is not used in the `AquilaAttention` class"
122
+ )
123
+ return self._sin_cached
124
 
125
+ @property
126
+ def cos_cached(self):
127
+ logger.warning_once(
128
+ "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
129
+ "the forward method of RoPE from now on instead. It is not used in the `AquilaAttention` class"
130
+ )
131
+ return self._cos_cached
132
+
133
+ @torch.no_grad()
134
+ def forward(self, x, position_ids):
135
  # x: [bs, num_attention_heads, seq_len, head_size]
136
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
137
+ position_ids_expanded = position_ids[:, None, :].float()
138
+ # Force float32 since bfloat16 loses precision on long contexts
139
+ # See https://github.com/huggingface/transformers/pull/29285
140
+ device_type = x.device.type
141
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
142
+ with torch.autocast(device_type=device_type, enabled=False):
143
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
144
+ emb = torch.cat((freqs, freqs), dim=-1)
145
+ cos = emb.cos()
146
+ sin = emb.sin()
147
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
148
 
 
 
 
 
149
 
 
150
  class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding):
151
  """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
152
 
153
+ def forward(self, x, position_ids):
154
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
155
+ position_ids = position_ids.float() / self.scaling_factor
156
+ cos, sin = super().forward(x, position_ids)
157
+ return cos, sin
158
 
 
 
 
 
159
 
 
 
 
 
 
 
 
160
  class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding):
161
  """AquilaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
162
 
163
+ def forward(self, x, position_ids):
164
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
165
+ seq_len = torch.max(position_ids) + 1
 
 
 
 
166
  if seq_len > self.max_position_embeddings:
167
  base = self.base * (
168
  (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
169
  ) ** (self.dim / (self.dim - 2))
170
+ inv_freq = 1.0 / (
171
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
172
+ )
173
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
174
 
175
+ cos, sin = super().forward(x, position_ids)
176
+ return cos, sin
 
 
 
177
 
178
 
179
  def rotate_half(x):
 
183
  return torch.cat((-x2, x1), dim=-1)
184
 
185
 
186
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
187
+ """Applies Rotary Position Embedding to the query and key tensors.
188
+
189
+ Args:
190
+ q (`torch.Tensor`): The query tensor.
191
+ k (`torch.Tensor`): The key tensor.
192
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
193
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
194
+ position_ids (`torch.Tensor`, *optional*):
195
+ Deprecated and unused.
196
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
197
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
198
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
199
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
200
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
201
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
202
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
203
+ Returns:
204
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
205
+ """
206
+ cos = cos.unsqueeze(unsqueeze_dim)
207
+ sin = sin.unsqueeze(unsqueeze_dim)
208
  q_embed = (q * cos) + (rotate_half(q) * sin)
209
  k_embed = (k * cos) + (rotate_half(k) * sin)
210
  return q_embed, k_embed
211
 
212
 
 
213
  class AquilaMLP(nn.Module):
214
  def __init__(self, config):
215
  super().__init__()
 
256
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
257
 
258
 
 
259
  class AquilaAttention(nn.Module):
260
  """Multi-headed attention from 'Attention Is All You Need' paper"""
261
+
262
+ def __init__(self, config: AquilaConfig, layer_idx: Optional[int] = None):
263
  super().__init__()
264
  self.config = config
265
+ self.layer_idx = layer_idx
266
+ if layer_idx is None:
267
+ logger.warning_once(
268
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
269
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
270
+ "when creating this class."
271
+ )
272
+
273
+ self.attention_dropout = config.attention_dropout
274
  self.hidden_size = config.hidden_size
275
  self.num_heads = config.num_attention_heads
276
  self.head_dim = self.hidden_size // self.num_heads
 
278
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
279
  self.max_position_embeddings = config.max_position_embeddings
280
  self.rope_theta = config.rope_theta
281
+ self.is_causal = True
282
 
283
  if (self.head_dim * self.num_heads) != self.hidden_size:
284
  raise ValueError(
285
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
286
  f" and `num_heads`: {self.num_heads})."
287
  )
288
+
289
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
290
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
291
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
292
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
293
  self._init_rope()
294
 
295
  def _init_rope(self):
 
319
  else:
320
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
321
 
 
 
 
322
  def forward(
323
  self,
324
  hidden_states: torch.Tensor,
325
  attention_mask: Optional[torch.Tensor] = None,
326
  position_ids: Optional[torch.LongTensor] = None,
327
+ past_key_value: Optional[Cache] = None,
328
  output_attentions: bool = False,
329
  use_cache: bool = False,
330
+ cache_position: Optional[torch.LongTensor] = None,
331
+ **kwargs,
332
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
333
  bsz, q_len, _ = hidden_states.size()
334
 
 
358
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
360
 
361
+ past_key_value = getattr(self, "past_key_value", past_key_value)
362
+ cos, sin = self.rotary_emb(value_states, position_ids)
363
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
364
 
365
  if past_key_value is not None:
366
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
367
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
368
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
369
 
 
 
 
370
  key_states = repeat_kv(key_states, self.num_key_value_groups)
371
  value_states = repeat_kv(value_states, self.num_key_value_groups)
372
 
373
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
 
 
 
 
374
 
375
+ if attention_mask is not None: # no matter the length, we just slice it
376
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
377
+ attn_weights = attn_weights + causal_mask
 
 
 
378
 
379
  # upcast attention to fp32
380
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
381
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
382
  attn_output = torch.matmul(attn_weights, value_states)
383
 
384
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
388
  )
389
 
390
  attn_output = attn_output.transpose(1, 2).contiguous()
391
+
392
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
393
 
394
  if self.config.pretraining_tp > 1:
 
404
  return attn_output, attn_weights, past_key_value
405
 
406
 
407
+ class AquilaFlashAttention2(AquilaAttention):
408
+ """
409
+ Aquila flash attention module. This module inherits from `AquilaAttention` as the weights of the module stays
410
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
411
+ flash attention and deal with padding tokens in case the input contains any of them.
412
+ """
413
+
414
+ def __init__(self, *args, **kwargs):
415
+ super().__init__(*args, **kwargs)
416
+
417
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
418
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
419
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
420
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.LongTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Cache] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ cache_position: Optional[torch.LongTensor] = None,
431
+ **kwargs,
432
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
433
+ output_attentions = False
434
+
435
+ bsz, q_len, _ = hidden_states.size()
436
+
437
+ query_states = self.q_proj(hidden_states)
438
+ key_states = self.k_proj(hidden_states)
439
+ value_states = self.v_proj(hidden_states)
440
+
441
+ # Flash attention requires the input to have the shape
442
+ # batch_size x seq_length x head_dim x hidden_dim
443
+ # therefore we just need to keep the original shape
444
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
445
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
446
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
447
+
448
+ cos, sin = self.rotary_emb(value_states, position_ids)
449
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
450
+
451
+ past_key_value = getattr(self, "past_key_value", past_key_value)
452
+
453
+ if past_key_value is not None:
454
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
455
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
456
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
457
+
458
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
459
+ # to be able to avoid many of these transpose/reshape/view.
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
+ dropout_rate = self.attention_dropout if self.training else 0.0
465
+
466
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
467
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
468
+ # cast them back in the correct dtype just to be sure everything works as expected.
469
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
470
+ # in fp32. (AquilaRMSNorm handles it correctly)
471
+
472
+ input_dtype = query_states.dtype
473
+ if input_dtype == torch.float32:
474
+ if torch.is_autocast_enabled():
475
+ target_dtype = torch.get_autocast_gpu_dtype()
476
+ # Handle the case where the model is quantized
477
+ elif hasattr(self.config, "_pre_quantization_dtype"):
478
+ target_dtype = self.config._pre_quantization_dtype
479
+ else:
480
+ target_dtype = self.q_proj.weight.dtype
481
+
482
+ logger.warning_once(
483
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
484
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
485
+ f" {target_dtype}."
486
+ )
487
+
488
+ query_states = query_states.to(target_dtype)
489
+ key_states = key_states.to(target_dtype)
490
+ value_states = value_states.to(target_dtype)
491
+
492
+ attn_output = self._flash_attention_forward(
493
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
494
+ )
495
+
496
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
497
+ attn_output = self.o_proj(attn_output)
498
+
499
+ if not output_attentions:
500
+ attn_weights = None
501
+
502
+ return attn_output, attn_weights, past_key_value
503
+
504
+ def _flash_attention_forward(
505
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
506
+ ):
507
+ """
508
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
509
+ first unpad the input, then computes the attention scores and pad the final attention scores.
510
+
511
+ Args:
512
+ query_states (`torch.Tensor`):
513
+ Input query states to be passed to Flash Attention API
514
+ key_states (`torch.Tensor`):
515
+ Input key states to be passed to Flash Attention API
516
+ value_states (`torch.Tensor`):
517
+ Input value states to be passed to Flash Attention API
518
+ attention_mask (`torch.Tensor`):
519
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
520
+ position of padding tokens and 1 for the position of non-padding tokens.
521
+ dropout (`float`):
522
+ Attention dropout
523
+ softmax_scale (`float`, *optional*):
524
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
525
+ """
526
+ if not self._flash_attn_uses_top_left_mask:
527
+ causal = self.is_causal
528
+ else:
529
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in AquilaFlashAttention2 __init__.
530
+ causal = self.is_causal and query_length != 1
531
+
532
+ # Contains at least one padding token in the sequence
533
+ if attention_mask is not None:
534
+ batch_size = query_states.shape[0]
535
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
536
+ query_states, key_states, value_states, attention_mask, query_length
537
+ )
538
+
539
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
540
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
541
+
542
+ attn_output_unpad = flash_attn_varlen_func(
543
+ query_states,
544
+ key_states,
545
+ value_states,
546
+ cu_seqlens_q=cu_seqlens_q,
547
+ cu_seqlens_k=cu_seqlens_k,
548
+ max_seqlen_q=max_seqlen_in_batch_q,
549
+ max_seqlen_k=max_seqlen_in_batch_k,
550
+ dropout_p=dropout,
551
+ softmax_scale=softmax_scale,
552
+ causal=causal,
553
+ )
554
+
555
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
556
+ else:
557
+ attn_output = flash_attn_func(
558
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
559
+ )
560
+
561
+ return attn_output
562
+
563
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
564
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
565
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
566
+
567
+ key_layer = index_first_axis(
568
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
569
+ )
570
+ value_layer = index_first_axis(
571
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
572
+ )
573
+ if query_length == kv_seq_len:
574
+ query_layer = index_first_axis(
575
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
576
+ )
577
+ cu_seqlens_q = cu_seqlens_k
578
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
579
+ indices_q = indices_k
580
+ elif query_length == 1:
581
+ max_seqlen_in_batch_q = 1
582
+ cu_seqlens_q = torch.arange(
583
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
584
+ ) # There is a memcpy here, that is very bad.
585
+ indices_q = cu_seqlens_q[:-1]
586
+ query_layer = query_layer.squeeze(1)
587
+ else:
588
+ # The -q_len: slice assumes left padding.
589
+ attention_mask = attention_mask[:, -query_length:]
590
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
591
+
592
+ return (
593
+ query_layer,
594
+ key_layer,
595
+ value_layer,
596
+ indices_q,
597
+ (cu_seqlens_q, cu_seqlens_k),
598
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
599
+ )
600
+
601
+
602
+ class AquilaSdpaAttention(AquilaAttention):
603
+ """
604
+ Aquila attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
605
+ `AquilaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
606
+ SDPA API.
607
+ """
608
+
609
+ # Adapted from AquilaAttention.forward
610
+ def forward(
611
+ self,
612
+ hidden_states: torch.Tensor,
613
+ attention_mask: Optional[torch.Tensor] = None,
614
+ position_ids: Optional[torch.LongTensor] = None,
615
+ past_key_value: Optional[Cache] = None,
616
+ output_attentions: bool = False,
617
+ use_cache: bool = False,
618
+ cache_position: Optional[torch.LongTensor] = None,
619
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
620
+ if output_attentions:
621
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
622
+ logger.warning_once(
623
+ "AquilaModel is using AquilaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
624
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
625
+ )
626
+ return super().forward(
627
+ hidden_states=hidden_states,
628
+ attention_mask=attention_mask,
629
+ position_ids=position_ids,
630
+ past_key_value=past_key_value,
631
+ output_attentions=output_attentions,
632
+ use_cache=use_cache,
633
+ cache_position=cache_position,
634
+ )
635
+
636
+ bsz, q_len, _ = hidden_states.size()
637
+
638
+ query_states = self.q_proj(hidden_states)
639
+ key_states = self.k_proj(hidden_states)
640
+ value_states = self.v_proj(hidden_states)
641
+
642
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
643
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
644
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
645
+
646
+ cos, sin = self.rotary_emb(value_states, position_ids)
647
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
648
+
649
+ # In case static cache is used, it is an instance attribute.
650
+ past_key_value = getattr(self, "past_key_value", past_key_value)
651
+
652
+ if past_key_value is not None:
653
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
654
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
655
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
656
+
657
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
658
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
659
+
660
+ causal_mask = attention_mask
661
+ # if attention_mask is not None and cache_position is not None:
662
+ if attention_mask is not None:
663
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
664
+
665
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
666
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
667
+ if query_states.device.type == "cuda" and causal_mask is not None:
668
+ query_states = query_states.contiguous()
669
+ key_states = key_states.contiguous()
670
+ value_states = value_states.contiguous()
671
+
672
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
673
+ query_states,
674
+ key_states,
675
+ value_states,
676
+ attn_mask=causal_mask,
677
+ dropout_p=self.attention_dropout if self.training else 0.0,
678
+ )
679
+
680
+ attn_output = attn_output.transpose(1, 2).contiguous()
681
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
682
+
683
+ attn_output = self.o_proj(attn_output)
684
+
685
+ return attn_output, None, past_key_value
686
+
687
+
688
+ AQUILA_ATTENTION_CLASSES = {
689
+ "eager": AquilaAttention,
690
+ "flash_attention_2": AquilaFlashAttention2,
691
+ "sdpa": AquilaSdpaAttention,
692
+ }
693
+
694
+
695
  class AquilaDecoderLayer(nn.Module):
696
+ def __init__(self, config: AquilaConfig, layer_idx: int):
697
  super().__init__()
698
  self.hidden_size = config.hidden_size
699
+
700
+ self.self_attn = AQUILA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
701
+
702
  self.mlp = AquilaMLP(config)
703
  self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
704
  self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
711
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
712
  output_attentions: Optional[bool] = False,
713
  use_cache: Optional[bool] = False,
714
+ cache_position: Optional[torch.LongTensor] = None,
715
+ **kwargs,
716
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
717
  """
718
  Args:
719
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
720
+ attention_mask (`torch.FloatTensor`, *optional*):
721
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
722
+ query_sequence_length, key_sequence_length)` if default attention is used.
723
  output_attentions (`bool`, *optional*):
724
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
725
  returned tensors for more detail.
 
728
  (see `past_key_values`).
729
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
730
  """
731
+ if "padding_mask" in kwargs:
732
+ warnings.warn(
733
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
734
+ )
735
 
736
  residual = hidden_states
737
 
 
745
  past_key_value=past_key_value,
746
  output_attentions=output_attentions,
747
  use_cache=use_cache,
748
+ cache_position=cache_position,
749
+ **kwargs,
750
  )
751
  hidden_states = residual + hidden_states
752
 
 
766
 
767
  return outputs
768
 
769
+
770
  AQUILA_START_DOCSTRING = r"""
771
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
772
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
 
788
  "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
789
  AQUILA_START_DOCSTRING,
790
  )
 
791
  class AquilaPreTrainedModel(PreTrainedModel):
792
  config_class = AquilaConfig
793
  base_model_prefix = "model"
794
  supports_gradient_checkpointing = True
795
  _no_split_modules = ["AquilaDecoderLayer"]
796
+ _skip_keys_device_placement = ["past_key_values"]
797
+ _supports_flash_attn_2 = True
798
+ _supports_sdpa = True
799
+ _supports_cache_class = True
800
 
801
  def _init_weights(self, module):
802
  std = self.config.initializer_range
 
809
  if module.padding_idx is not None:
810
  module.weight.data[module.padding_idx].zero_()
811
 
812
+ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
813
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
814
+ raise ValueError(
815
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
816
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
817
+ )
818
+
819
+ for layer in self.model.layers:
820
+ device = layer.input_layernorm.weight.device
821
+ if hasattr(self.config, "_pre_quantization_dtype"):
822
+ dtype = self.config._pre_quantization_dtype
823
+ else:
824
+ dtype = layer.self_attn.o_proj.weight.dtype
825
+ layer.self_attn.past_key_value = cache_cls(
826
+ self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
827
+ )
828
+
829
+ def _reset_cache(self):
830
+ for layer in self.model.layers:
831
+ layer.self_attn.past_key_value = None
832
 
833
 
834
  AQUILA_INPUTS_DOCSTRING = r"""
 
852
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
853
  [`PreTrainedTokenizer.__call__`] for details.
854
 
855
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
856
  `past_key_values`).
857
 
858
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
 
866
  config.n_positions - 1]`.
867
 
868
  [What are position IDs?](../glossary#position-ids)
869
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
870
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
871
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
872
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
873
+
874
+ Two formats are allowed:
875
+ - a [`~cache_utils.Cache`] instance;
876
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
877
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
878
+ cache format.
879
+
880
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
881
+ legacy cache format will be returned.
882
+
883
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
884
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
885
+ of shape `(batch_size, sequence_length)`.
886
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
887
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
888
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
 
898
  more detail.
899
  return_dict (`bool`, *optional*):
900
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
901
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
902
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
903
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
904
+ the complete sequence length.
905
  """
906
 
907
 
 
909
  "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
910
  AQUILA_START_DOCSTRING,
911
  )
 
912
  class AquilaModel(AquilaPreTrainedModel):
913
  """
914
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AquilaDecoderLayer`]
 
923
  self.vocab_size = config.vocab_size
924
 
925
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
926
+ self.layers = nn.ModuleList(
927
+ [AquilaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
928
+ )
929
  self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
930
  self.gradient_checkpointing = False
931
+
932
  # Initialize weights and apply final processing
933
  self.post_init()
934
 
 
938
  def set_input_embeddings(self, value):
939
  self.embed_tokens = value
940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
  @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING)
942
  def forward(
943
  self,
 
950
  output_attentions: Optional[bool] = None,
951
  output_hidden_states: Optional[bool] = None,
952
  return_dict: Optional[bool] = None,
953
+ cache_position: Optional[torch.LongTensor] = None,
954
  ) -> Union[Tuple, BaseModelOutputWithPast]:
955
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
956
  output_hidden_states = (
957
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
958
  )
959
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
960
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
961
 
962
+ if (input_ids is None) ^ (inputs_embeds is not None):
963
+ raise ValueError(
964
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
965
+ )
 
 
 
 
 
 
 
 
 
 
 
 
966
 
967
+ if self.gradient_checkpointing and self.training and use_cache:
968
+ logger.warning_once(
969
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
 
970
  )
971
+ use_cache = False
 
 
972
 
973
  if inputs_embeds is None:
974
  inputs_embeds = self.embed_tokens(input_ids)
975
+
976
+ past_seen_tokens = 0
977
+ if use_cache: # kept for BC (cache positions)
978
+ if not isinstance(past_key_values, StaticCache):
979
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
980
+ past_seen_tokens = past_key_values.get_seq_length()
981
+
982
+ if cache_position is None:
983
+ if isinstance(past_key_values, StaticCache):
984
+ raise ValueError("cache_position is a required argument when using StaticCache.")
985
+ cache_position = torch.arange(
986
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
987
  )
 
 
 
988
 
989
+ if position_ids is None:
990
+ position_ids = cache_position.unsqueeze(0)
991
 
992
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
993
+
994
+ # embed positions
995
+ hidden_states = inputs_embeds
 
 
996
 
997
  # decoder layers
998
  all_hidden_states = () if output_hidden_states else None
999
  all_self_attns = () if output_attentions else None
1000
+ next_decoder_cache = None
1001
 
1002
+ for decoder_layer in self.layers:
1003
  if output_hidden_states:
1004
  all_hidden_states += (hidden_states,)
1005
 
 
 
1006
  if self.gradient_checkpointing and self.training:
1007
+ layer_outputs = self._gradient_checkpointing_func(
1008
+ decoder_layer.__call__,
 
 
 
 
 
 
 
 
1009
  hidden_states,
1010
+ causal_mask,
1011
  position_ids,
1012
+ past_key_values,
1013
+ output_attentions,
1014
+ use_cache,
1015
+ cache_position,
1016
  )
1017
  else:
1018
  layer_outputs = decoder_layer(
1019
  hidden_states,
1020
+ attention_mask=causal_mask,
1021
  position_ids=position_ids,
1022
+ past_key_value=past_key_values,
1023
  output_attentions=output_attentions,
1024
  use_cache=use_cache,
1025
+ cache_position=cache_position,
1026
  )
1027
 
1028
  hidden_states = layer_outputs[0]
1029
 
1030
  if use_cache:
1031
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1032
 
1033
  if output_attentions:
1034
  all_self_attns += (layer_outputs[1],)
 
1039
  if output_hidden_states:
1040
  all_hidden_states += (hidden_states,)
1041
 
1042
+ next_cache = None
1043
+ if use_cache:
1044
+ next_cache = (
1045
+ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1046
+ )
1047
  if not return_dict:
1048
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1049
  return BaseModelOutputWithPast(
 
1053
  attentions=all_self_attns,
1054
  )
1055
 
1056
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1057
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1058
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1059
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1060
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1061
+ if self.config._attn_implementation == "flash_attention_2":
1062
+ if attention_mask is not None and 0.0 in attention_mask:
1063
+ return attention_mask
1064
+ return None
1065
+
1066
+ dtype, device = input_tensor.dtype, input_tensor.device
1067
+ min_dtype = torch.finfo(dtype).min
1068
+ sequence_length = input_tensor.shape[1]
1069
+ if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
1070
+ target_length = self.config.max_position_embeddings
1071
+ else: # dynamic cache
1072
+ target_length = (
1073
+ attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
1074
+ )
1075
+
1076
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1077
+ if sequence_length != 1:
1078
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1079
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1080
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1081
+ if attention_mask is not None:
1082
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1083
+ if attention_mask.dim() == 2:
1084
+ mask_length = attention_mask.shape[-1]
1085
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1086
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1087
+ elif attention_mask.dim() == 4:
1088
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1089
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1090
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1091
+ offset = cache_position[0]
1092
+ else:
1093
+ offset = 0
1094
+ mask_shape = attention_mask.shape
1095
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1096
+ causal_mask[
1097
+ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
1098
+ ] = mask_slice
1099
+
1100
+ if (
1101
+ self.config._attn_implementation == "sdpa"
1102
+ and attention_mask is not None
1103
+ and attention_mask.device.type == "cuda"
1104
+ ):
1105
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1106
+ is_tracing = (
1107
+ torch.jit.is_tracing()
1108
+ or isinstance(input_tensor, torch.fx.Proxy)
1109
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1110
+ )
1111
+ if not is_tracing and torch.any(attention_mask != 1):
1112
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1113
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1114
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1115
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1116
+
1117
+ return causal_mask
1118
+
1119
+
1120
  class AquilaForCausalLM(AquilaPreTrainedModel):
1121
  _tied_weights_keys = ["lm_head.weight"]
1122
 
 
1161
  output_attentions: Optional[bool] = None,
1162
  output_hidden_states: Optional[bool] = None,
1163
  return_dict: Optional[bool] = None,
1164
+ cache_position: Optional[torch.LongTensor] = None,
1165
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1166
  r"""
1167
  Args:
 
1177
  ```python
1178
  >>> from transformers import AutoTokenizer, AquilaForCausalLM
1179
 
1180
+ >>> model = AquilaForCausalLM.from_pretrained("BAAI/Aquila2-7B")
1181
+ >>> tokenizer = AutoTokenizer.from_pretrained("BAAI/Aquila2-7B")
1182
 
1183
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1184
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1185
 
1186
  >>> # Generate
1187
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1188
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1189
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1190
  ```"""
 
1191
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1192
  output_hidden_states = (
1193
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1205
  output_attentions=output_attentions,
1206
  output_hidden_states=output_hidden_states,
1207
  return_dict=return_dict,
1208
+ cache_position=cache_position,
1209
  )
1210
 
1211
  hidden_states = outputs[0]
 
1243
  )
1244
 
1245
  def prepare_inputs_for_generation(
1246
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
1247
  ):
1248
+ # With static cache, the `past_key_values` is None
1249
+ # TODO joao: standardize interface for the different Cache classes and remove of this if
1250
+ has_static_cache = False
1251
+ if past_key_values is None:
1252
+ past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
1253
+ has_static_cache = past_key_values is not None
1254
+
1255
+ past_length = 0
1256
+ if past_key_values is not None:
1257
+ if isinstance(past_key_values, Cache):
1258
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1259
+ max_cache_length = (
1260
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1261
+ if past_key_values.get_max_length() is not None
1262
+ else None
1263
+ )
1264
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1265
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1266
+ else:
1267
+ cache_length = past_length = past_key_values[0][0].shape[2]
1268
+ max_cache_length = None
1269
+
1270
+ # Keep only the unprocessed tokens:
1271
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1272
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1273
+ # input)
1274
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1275
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1276
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1277
+ # input_ids based on the past_length.
1278
+ elif past_length < input_ids.shape[1]:
1279
+ input_ids = input_ids[:, past_length:]
1280
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1281
+
1282
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1283
+ if (
1284
+ max_cache_length is not None
1285
+ and attention_mask is not None
1286
+ and cache_length + input_ids.shape[1] > max_cache_length
1287
+ ):
1288
+ attention_mask = attention_mask[:, -max_cache_length:]
1289
 
1290
  position_ids = kwargs.get("position_ids", None)
1291
  if attention_mask is not None and position_ids is None:
 
1293
  position_ids = attention_mask.long().cumsum(-1) - 1
1294
  position_ids.masked_fill_(attention_mask == 0, 1)
1295
  if past_key_values:
1296
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1297
 
1298
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1299
  if inputs_embeds is not None and past_key_values is None:
1300
  model_inputs = {"inputs_embeds": inputs_embeds}
1301
  else:
1302
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1303
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1304
+ # TODO: use `next_tokens` directly instead.
1305
+ model_inputs = {"input_ids": input_ids.contiguous()}
1306
+
1307
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1308
+ if cache_position is None:
1309
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1310
+ else:
1311
+ cache_position = cache_position[-input_length:]
1312
+
1313
+ if has_static_cache:
1314
+ past_key_values = None
1315
 
1316
  model_inputs.update(
1317
  {
1318
  "position_ids": position_ids,
1319
+ "cache_position": cache_position,
1320
  "past_key_values": past_key_values,
1321
  "use_cache": kwargs.get("use_cache"),
1322
  "attention_mask": attention_mask,
 
1333
  )
1334
  return reordered_past
1335
 
1336
+
1337
  @add_start_docstrings(
1338
  """
1339
+ The Aquila Model transformer with a sequence classification head on top (linear layer).
1340
 
1341
  [`AquilaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1342
  (e.g. GPT-2) do.
 
1349
  """,
1350
  AQUILA_START_DOCSTRING,
1351
  )
 
1352
  class AquilaForSequenceClassification(AquilaPreTrainedModel):
 
 
1353
  def __init__(self, config):
1354
  super().__init__(config)
1355
  self.num_labels = config.num_labels
 
1412
  sequence_lengths = -1
1413
  else:
1414
  if input_ids is not None:
1415
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1416
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1417
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1418
+ sequence_lengths = sequence_lengths.to(logits.device)
1419
  else:
1420
  sequence_lengths = -1
1421
 
 
1455
  hidden_states=transformer_outputs.hidden_states,
1456
  attentions=transformer_outputs.attentions,
1457
  )
1458
+
1459
+
1460
+ @add_start_docstrings(
1461
+ """
1462
+ The Aquila Model transformer with a span classification head on top for extractive question-answering tasks like
1463
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1464
+ """,
1465
+ AQUILA_START_DOCSTRING,
1466
+ )
1467
+ class AquilaForQuestionAnswering(AquilaPreTrainedModel):
1468
+ base_model_prefix = "transformer"
1469
+
1470
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Aquila
1471
+ def __init__(self, config):
1472
+ super().__init__(config)
1473
+ self.transformer = AquilaModel(config)
1474
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1475
+
1476
+ # Initialize weights and apply final processing
1477
+ self.post_init()
1478
+
1479
+ def get_input_embeddings(self):
1480
+ return self.transformer.embed_tokens
1481
+
1482
+ def set_input_embeddings(self, value):
1483
+ self.transformer.embed_tokens = value
1484
+
1485
+ @add_start_docstrings_to_model_forward(AQUILA_INPUTS_DOCSTRING)
1486
+ def forward(
1487
+ self,
1488
+ input_ids: Optional[torch.LongTensor] = None,
1489
+ attention_mask: Optional[torch.FloatTensor] = None,
1490
+ position_ids: Optional[torch.LongTensor] = None,
1491
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1492
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1493
+ start_positions: Optional[torch.LongTensor] = None,
1494
+ end_positions: Optional[torch.LongTensor] = None,
1495
+ output_attentions: Optional[bool] = None,
1496
+ output_hidden_states: Optional[bool] = None,
1497
+ return_dict: Optional[bool] = None,
1498
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1499
+ r"""
1500
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1501
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1502
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1503
+ are not taken into account for computing the loss.
1504
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1505
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1506
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1507
+ are not taken into account for computing the loss.
1508
+ """
1509
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1510
+
1511
+ outputs = self.transformer(
1512
+ input_ids,
1513
+ attention_mask=attention_mask,
1514
+ position_ids=position_ids,
1515
+ past_key_values=past_key_values,
1516
+ inputs_embeds=inputs_embeds,
1517
+ output_attentions=output_attentions,
1518
+ output_hidden_states=output_hidden_states,
1519
+ return_dict=return_dict,
1520
+ )
1521
+
1522
+ sequence_output = outputs[0]
1523
+
1524
+ logits = self.qa_outputs(sequence_output)
1525
+ start_logits, end_logits = logits.split(1, dim=-1)
1526
+ start_logits = start_logits.squeeze(-1).contiguous()
1527
+ end_logits = end_logits.squeeze(-1).contiguous()
1528
+
1529
+ total_loss = None
1530
+ if start_positions is not None and end_positions is not None:
1531
+ # If we are on multi-GPU, split add a dimension
1532
+ if len(start_positions.size()) > 1:
1533
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1534
+ if len(end_positions.size()) > 1:
1535
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1536
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1537
+ ignored_index = start_logits.size(1)
1538
+ start_positions = start_positions.clamp(0, ignored_index)
1539
+ end_positions = end_positions.clamp(0, ignored_index)
1540
+
1541
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1542
+ start_loss = loss_fct(start_logits, start_positions)
1543
+ end_loss = loss_fct(end_logits, end_positions)
1544
+ total_loss = (start_loss + end_loss) / 2
1545
+
1546
+ if not return_dict:
1547
+ output = (start_logits, end_logits) + outputs[2:]
1548
+ return ((total_loss,) + output) if total_loss is not None else output
1549
+
1550
+ return QuestionAnsweringModelOutput(
1551
+ loss=total_loss,
1552
+ start_logits=start_logits,
1553
+ end_logits=end_logits,
1554
+ hidden_states=outputs.hidden_states,
1555
+ attentions=outputs.attentions,
1556
+ )