Andrei Panferov committed
Commit dfb8eb3
1 Parent(s): 161c13a
config.json CHANGED
@@ -3,9 +3,8 @@
     "LlamaForCausalLM_AQLM"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_llama.LlamaConfig",
-    "AutoModel": "modeling_llama.LlamaModel",
-    "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM"
+    "AutoConfig": "configuration_llama_aqlm.LlamaConfig",
+    "AutoModelForCausalLM": "modeling_llama_aqlm.LlamaForCausalLM"
   },
   "bos_token_id": 1,
   "eos_token_id": 2,
configuration_llama.py → configuration_llama_aqlm.py RENAMED
@@ -6,16 +6,13 @@ class LlamaConfig(OrigLlamaConfig):
 
     def __init__(
         self,
-        nbits_per_codebook: int = 16,
-        num_codebooks: int = 1,
-        out_group_size: int = 1,
-        in_group_size: int = 8,
+        aqlm: dict[str, int] = {
+            "nbits_per_codebook": 16,
+            "num_codebooks": 1,
+            "out_group_size": 8,
+            "in_group_size": 1,
+        },
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.aqlm = {
-            "nbits_per_codebook": nbits_per_codebook,
-            "num_codebooks": num_codebooks,
-            "out_group_size": out_group_size,
-            "in_group_size": in_group_size,
-        }
+        self.aqlm = aqlm
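The four AQLM hyperparameters are now passed as a single `aqlm` dict and stored on the config, from which the model later builds its layers via `FinalizedQuantizedLinear(..., **config.aqlm)`. A sketch of the new signature in use, with illustrative values (note that the mutable dict default is shared between calls, which is harmless here only because the dict is never mutated):

```python
# Illustrative only: how the consolidated `aqlm` dict is passed and consumed.
# Assumes configuration_llama_aqlm.py is importable from the working directory.
from configuration_llama_aqlm import LlamaConfig

config = LlamaConfig(
    aqlm={
        "nbits_per_codebook": 16,
        "num_codebooks": 1,
        "out_group_size": 1,
        "in_group_size": 8,
    }
)
# modeling_llama_aqlm.py then constructs each projection roughly as:
#   FinalizedQuantizedLinear(in_features, out_features, bias=False, **config.aqlm)
print(config.aqlm["nbits_per_codebook"])  # 16
```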
 
 
 
 
 
inference.py CHANGED
@@ -1,4 +1,4 @@
1
- """ Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
2
  import functools
3
  import os
4
  from typing import Optional
@@ -66,9 +66,7 @@ class FinalizedQuantizedLinear(nn.Module):
66
  self.register_parameter("bias", None)
67
 
68
  def forward(self, input: torch.Tensor) -> torch.Tensor:
69
- return forward_pass_quantized_linear(
70
- input, self.codes, self.codebooks, self.scales, self.bias
71
- )
72
 
73
 
74
  def get_int_dtype(nbits: int) -> torch.dtype:
@@ -122,14 +120,11 @@ def _dequantize_weight(
122
  ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
123
 
124
  reconstructed_weight_groupwise = reconstructed_weight_flat.view(
125
- list(codes.shape[:-3])
126
- + [num_out_groups, num_in_groups, out_group_size, in_group_size]
127
  )
128
  if scales is not None:
129
  reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
130
- return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
131
- list(codes.shape[:-3]) + [out_features, in_features]
132
- )
133
 
134
 
135
  def forward_pass_quantized_linear(
@@ -140,10 +135,7 @@ def forward_pass_quantized_linear(
140
  bias: Optional[torch.Tensor],
141
  ) -> torch.Tensor:
142
  if input.is_cuda:
143
- matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
144
- if bias is not None:
145
- matmul_result += bias
146
- return matmul_result
147
  else:
148
  dequantized_weight = _dequantize_weight(
149
  unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
@@ -168,6 +160,7 @@ def forward_pass_quantized_linear(
168
  "in_group_size",
169
  "num_input_groups",
170
  "num_input_groups_next_power_of_2",
 
171
  "compute_in_fp32",
172
  ],
173
  )
@@ -178,6 +171,7 @@ def _aqlm_gemv_simple(
178
  codes_i16_ptr,
179
  codebooks_ptr,
180
  scales_ptr,
 
181
  in_features: tl.constexpr,
182
  out_features: tl.constexpr,
183
  num_codebooks: tl.constexpr,
@@ -187,6 +181,7 @@ def _aqlm_gemv_simple(
187
  num_input_groups: tl.constexpr,
188
  num_input_groups_next_power_of_2: tl.constexpr,
189
  compute_in_fp32: tl.constexpr,
 
190
  UNUSED: tl.constexpr,
191
  ):
192
  # variables ending with "_i" mean "for i-th output unit"
@@ -197,8 +192,7 @@ def _aqlm_gemv_simple(
197
  input_vec_ptr
198
  + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
199
  + tl.arange(0, in_group_size)[None, None, :],
200
- mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None]
201
- < num_input_groups,
202
  )
203
  # [in_features//in_group_size, 1, group_size]
204
  # Note: we could simply load input_vec then reshape
@@ -216,14 +210,10 @@ def _aqlm_gemv_simple(
216
  )
217
  codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
218
 
219
- codes_i = tl.load(
220
- codes_i_ptrs, mask=codes_i_mask_1d[:, None]
221
- ) # [in_features//in_group_size, num_codebooks]
222
  if codes_i.dtype == tl.int16:
223
  codes_i = codes_i.to(tl.int32)
224
- codes_i = (codes_i) + (
225
- codes_i < 0
226
- ) * codebook_size # aka 2 ** nbits_per_codebook
227
  # ^-- (because codes are int16 tensors that contain uint data)
228
 
229
  # The following alternative does not work:
@@ -232,9 +222,7 @@ def _aqlm_gemv_simple(
232
  codes_i = codes_i.to(tl.int32)
233
 
234
  # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
235
- codes_i += (
236
- tl.arange(0, num_codebooks)[None, :] * codebook_size
237
- ) # aka 2 ** nbits_per_codebook
238
  # ^-- [in_group_size, num_codebooks]
239
 
240
  # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
@@ -249,9 +237,7 @@ def _aqlm_gemv_simple(
249
  )
250
 
251
  # Stage 4: reconstruct weights, multiply by inputs and write out
252
- weights_i = tl.load(
253
- weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
254
- )
255
  if compute_in_fp32:
256
  weights_i = weights_i.to(tl.float32)
257
  input_vec = input_vec.to(tl.float32)
@@ -262,16 +248,15 @@ def _aqlm_gemv_simple(
262
  if out_group_size == 1:
263
  scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
264
  output_i = tl.sum(weights_i * input_vec) * scale
 
 
265
  tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
266
  else:
267
- output_i = tl.sum(
268
- tl.sum(weights_i * input_vec, axis=2), axis=0
269
- ) # [out_group_size]
270
  output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
271
- tl.store(
272
- output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
273
- output_i.to(input_vec.dtype),
274
- )
275
 
276
 
277
  def next_power_of_2(x):
@@ -283,6 +268,7 @@ def aqlm_gemv_simple(
283
  codes_i16: torch.ShortTensor,
284
  codebooks: torch.Tensor,
285
  scales: torch.Tensor,
 
286
  compute_in_fp32: bool = True,
287
  ):
288
 
@@ -305,6 +291,7 @@ def aqlm_gemv_simple(
305
  codes_i16,
306
  codebooks,
307
  scales,
 
308
  in_features,
309
  out_features,
310
  num_codebooks,
@@ -314,6 +301,7 @@ def aqlm_gemv_simple(
314
  num_input_groups,
315
  next_power_of_2(num_input_groups),
316
  compute_in_fp32,
 
317
  )
318
 
319
  return output_vec
@@ -324,15 +312,14 @@ def aqlm_gemm_stupid(
324
  codes_i16: torch.ShortTensor,
325
  codebooks: torch.Tensor,
326
  scales: torch.Tensor,
 
327
  compute_in_fp32: bool = True,
328
  ):
329
  original_shape = input.shape
330
  input = input.reshape(-1, original_shape[-1])
331
  return torch.cat(
332
  [
333
- aqlm_gemv_simple(
334
- input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32
335
- )
336
  for input_vec in input
337
  ]
338
  ).reshape(original_shape[:-1] + (-1,))
 
1
+ """ This file serves as the single entry point to efficiently run FinalizedQuantizedLinear layers"""
2
  import functools
3
  import os
4
  from typing import Optional
 
66
  self.register_parameter("bias", None)
67
 
68
  def forward(self, input: torch.Tensor) -> torch.Tensor:
69
+ return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)
 
 
70
 
71
 
72
  def get_int_dtype(nbits: int) -> torch.dtype:
 
120
  ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
121
 
122
  reconstructed_weight_groupwise = reconstructed_weight_flat.view(
123
+ list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
 
124
  )
125
  if scales is not None:
126
  reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
127
+ return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
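For orientation, the `view`/`swapaxes`/`reshape` chain above only rearranges group dimensions back into a dense `[out_features, in_features]` matrix (the real function also carries the leading `codes.shape[:-3]` dims). A small standalone shape check with made-up sizes:

```python
# Shape bookkeeping behind _dequantize_weight's final reshape (illustrative sizes only).
import torch

out_features, in_features = 32, 64
out_group_size, in_group_size = 1, 8
num_out_groups = out_features // out_group_size  # 32
num_in_groups = in_features // in_group_size     # 8

# Stand-in for the per-group reconstruction built from the codebooks.
groupwise = torch.randn(num_out_groups, num_in_groups, out_group_size, in_group_size)

# Same transform as the return statement above: swap the inner group axes, then flatten.
dense = groupwise.swapaxes(-3, -2).reshape(out_features, in_features)
assert dense.shape == (out_features, in_features)
```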
 
 
128
 
129
 
130
  def forward_pass_quantized_linear(
 
135
  bias: Optional[torch.Tensor],
136
  ) -> torch.Tensor:
137
  if input.is_cuda:
138
+ return aqlm_gemm_stupid(input, codes, codebooks, scales, bias)
 
 
 
139
  else:
140
  dequantized_weight = _dequantize_weight(
141
  unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
 
160
  "in_group_size",
161
  "num_input_groups",
162
  "num_input_groups_next_power_of_2",
163
+ "has_bias",
164
  "compute_in_fp32",
165
  ],
166
  )
 
171
  codes_i16_ptr,
172
  codebooks_ptr,
173
  scales_ptr,
174
+ bias_ptr,
175
  in_features: tl.constexpr,
176
  out_features: tl.constexpr,
177
  num_codebooks: tl.constexpr,
 
181
  num_input_groups: tl.constexpr,
182
  num_input_groups_next_power_of_2: tl.constexpr,
183
  compute_in_fp32: tl.constexpr,
184
+ has_bias: tl.constexpr,
185
  UNUSED: tl.constexpr,
186
  ):
187
  # variables ending with "_i" mean "for i-th output unit"
 
192
  input_vec_ptr
193
  + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
194
  + tl.arange(0, in_group_size)[None, None, :],
195
+ mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
 
196
  )
197
  # [in_features//in_group_size, 1, group_size]
198
  # Note: we could simply load input_vec then reshape
 
210
  )
211
  codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
212
 
213
+ codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None]) # [in_features//in_group_size, num_codebooks]
 
 
214
  if codes_i.dtype == tl.int16:
215
  codes_i = codes_i.to(tl.int32)
216
+ codes_i = (codes_i) + (codes_i < 0) * codebook_size # aka 2 ** nbits_per_codebook
 
 
217
  # ^-- (because codes are int16 tensors that contain uint data)
218
 
219
  # The following alternative does not work:
 
222
  codes_i = codes_i.to(tl.int32)
223
 
224
  # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
225
+ codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size # aka 2 ** nbits_per_codebook
 
 
226
  # ^-- [in_group_size, num_codebooks]
227
 
228
  # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
 
237
  )
238
 
239
  # Stage 4: reconstruct weights, multiply by inputs and write out
240
+ weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
 
 
241
  if compute_in_fp32:
242
  weights_i = weights_i.to(tl.float32)
243
  input_vec = input_vec.to(tl.float32)
 
248
  if out_group_size == 1:
249
  scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
250
  output_i = tl.sum(weights_i * input_vec) * scale
251
+ if bias_ptr:
252
+ output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
253
  tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
254
  else:
255
+ output_i = tl.sum(tl.sum(weights_i, axis=2) * input_vec, axis=0) # [out_group_size]
 
 
256
  output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
257
+ if bias_ptr:
258
+ output_i += tl.load(bias_ptr + pid).to(weights_i.dtype)
259
+ tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
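With `bias_ptr`/`has_bias` now passed in, the kernel applies the bias right after scaling, so the caller no longer adds it in Python. A plain-PyTorch sketch of the per-output-group epilogue the kernel implements (shapes are simplified assumptions, not the Triton code itself):

```python
# Reference epilogue for one output group (PyTorch sketch of the kernel's final stage).
from typing import Optional
import torch

def epilogue(weights_i: torch.Tensor,   # [num_input_groups, out_group_size, in_group_size], codebooks already folded
             input_vec: torch.Tensor,   # [num_input_groups, 1, in_group_size]
             scale: torch.Tensor,       # scalar scale for this output group
             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = (weights_i * input_vec).sum(dim=(0, 2)) * scale  # accumulate, then scale -> [out_group_size]
    if bias is not None:                                    # mirrors the new `if bias_ptr:` branch
        out = out + bias
    return out
```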
 
260
 
261
 
262
  def next_power_of_2(x):
 
268
  codes_i16: torch.ShortTensor,
269
  codebooks: torch.Tensor,
270
  scales: torch.Tensor,
271
+ bias: Optional[torch.Tensor],
272
  compute_in_fp32: bool = True,
273
  ):
274
 
 
291
  codes_i16,
292
  codebooks,
293
  scales,
294
+ bias,
295
  in_features,
296
  out_features,
297
  num_codebooks,
 
301
  num_input_groups,
302
  next_power_of_2(num_input_groups),
303
  compute_in_fp32,
304
+ bias is not None,
305
  )
306
 
307
  return output_vec
 
312
  codes_i16: torch.ShortTensor,
313
  codebooks: torch.Tensor,
314
  scales: torch.Tensor,
315
+ bias: Optional[torch.Tensor],
316
  compute_in_fp32: bool = True,
317
  ):
318
  original_shape = input.shape
319
  input = input.reshape(-1, original_shape[-1])
320
  return torch.cat(
321
  [
322
+ aqlm_gemv_simple(input_vec.unsqueeze(0), codes_i16, codebooks, scales, bias, compute_in_fp32)
 
 
323
  for input_vec in input
324
  ]
325
  ).reshape(original_shape[:-1] + (-1,))
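Net effect of the inference.py changes: `bias` is threaded from `FinalizedQuantizedLinear.forward` through `aqlm_gemm_stupid`/`aqlm_gemv_simple` into the Triton kernel instead of being added after the matmul on the Python side, while the CPU fallback still dequantizes and calls `F.linear` with the bias. A condensed paraphrase of the resulting dispatch (not the literal file contents):

```python
# Condensed dispatch after this commit (paraphrase of forward_pass_quantized_linear).
import torch
import torch.nn.functional as F

def forward_pass_sketch(input, codes, codebooks, scales, bias, dequantize, gemm):
    # `dequantize` stands in for _dequantize_weight, `gemm` for aqlm_gemm_stupid.
    if input.is_cuda:
        # bias is now applied inside the Triton kernel epilogue
        return gemm(input, codes, codebooks, scales, bias)
    # CPU fallback: materialize the dense weight, then a regular linear with bias
    weight = dequantize(codes, codebooks, scales)
    return F.linear(input, weight, bias)
```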
modeling_llama.py → modeling_llama_aqlm.py RENAMED
@@ -19,6 +19,7 @@
19
  # limitations under the License.
20
  """ PyTorch LLaMA model."""
21
  import math
 
22
  from typing import List, Optional, Tuple, Union
23
 
24
  import torch
@@ -27,23 +28,45 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers.activations import ACT2FN
30
- from transformers.modeling_outputs import (BaseModelOutputWithPast,
31
- CausalLMOutputWithPast,
32
- SequenceClassifierOutputWithPast)
 
 
 
 
 
 
 
 
 
33
  from transformers.modeling_utils import PreTrainedModel
34
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
35
- from transformers.utils import (add_start_docstrings,
36
- add_start_docstrings_to_model_forward,
37
- is_flash_attn_available, logging,
38
- replace_return_docstrings)
 
 
 
 
 
39
 
40
- from .configuration_llama import LlamaConfig
41
  from .inference import FinalizedQuantizedLinear
42
 
43
- if is_flash_attn_available():
44
  from flash_attn import flash_attn_func, flash_attn_varlen_func
45
- from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
46
- unpad_input)
 
 
 
 
 
 
 
 
47
 
48
 
49
  logger = logging.get_logger(__name__)
@@ -51,13 +74,11 @@ logger = logging.get_logger(__name__)
51
  _CONFIG_FOR_DOC = "LlamaConfig"
52
 
53
 
54
- def _get_unpad_data(padding_mask):
55
- seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
56
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
57
  max_seqlen_in_batch = seqlens_in_batch.max().item()
58
- cu_seqlens = F.pad(
59
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
60
- )
61
  return (
62
  indices,
63
  cu_seqlens,
@@ -65,51 +86,21 @@ def _get_unpad_data(padding_mask):
65
  )
66
 
67
 
68
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
69
- def _make_causal_mask(
70
- input_ids_shape: torch.Size,
71
- dtype: torch.dtype,
72
- device: torch.device,
73
- past_key_values_length: int = 0,
74
- ):
75
- """
76
- Make causal mask used for bi-directional self-attention.
77
- """
78
- bsz, tgt_len = input_ids_shape
79
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
80
- mask_cond = torch.arange(mask.size(-1), device=device)
81
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
82
- mask = mask.to(dtype)
83
-
84
- if past_key_values_length > 0:
85
- mask = torch.cat(
86
- [
87
- torch.zeros(
88
- tgt_len, past_key_values_length, dtype=dtype, device=device
89
- ),
90
- mask,
91
- ],
92
- dim=-1,
93
- )
94
- return mask[None, None, :, :].expand(
95
- bsz, 1, tgt_len, tgt_len + past_key_values_length
96
- )
97
-
98
-
99
- # Copied from transformers.models.bart.modeling_bart._expand_mask
100
  def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
101
- """
102
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
103
- """
104
- bsz, src_len = mask.size()
105
- tgt_len = tgt_len if tgt_len is not None else src_len
106
-
107
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
108
 
109
- inverted_mask = 1.0 - expanded_mask
110
 
111
- return inverted_mask.masked_fill(
112
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
 
 
 
 
 
 
113
  )
114
 
115
 
@@ -140,33 +131,23 @@ class LlamaRotaryEmbedding(nn.Module):
140
  self.dim = dim
141
  self.max_position_embeddings = max_position_embeddings
142
  self.base = base
143
- inv_freq = 1.0 / (
144
- self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
145
- )
146
  self.register_buffer("inv_freq", inv_freq, persistent=False)
147
 
148
  # Build here to make `torch.jit.trace` work.
149
  self._set_cos_sin_cache(
150
- seq_len=max_position_embeddings,
151
- device=self.inv_freq.device,
152
- dtype=torch.get_default_dtype(),
153
  )
154
 
155
  def _set_cos_sin_cache(self, seq_len, device, dtype):
156
  self.max_seq_len_cached = seq_len
157
- t = torch.arange(
158
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
159
- )
160
 
161
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
162
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
163
  emb = torch.cat((freqs, freqs), dim=-1)
164
- self.register_buffer(
165
- "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
166
- )
167
- self.register_buffer(
168
- "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
169
- )
170
 
171
  def forward(self, x, seq_len=None):
172
  # x: [bs, num_attention_heads, seq_len, head_size]
@@ -174,54 +155,34 @@ class LlamaRotaryEmbedding(nn.Module):
174
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
175
 
176
  return (
177
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
178
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
179
  )
180
 
181
 
182
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
183
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
184
 
185
- def __init__(
186
- self,
187
- dim,
188
- max_position_embeddings=2048,
189
- base=10000,
190
- device=None,
191
- scaling_factor=1.0,
192
- ):
193
  self.scaling_factor = scaling_factor
194
  super().__init__(dim, max_position_embeddings, base, device)
195
 
196
  def _set_cos_sin_cache(self, seq_len, device, dtype):
197
  self.max_seq_len_cached = seq_len
198
- t = torch.arange(
199
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
200
- )
201
  t = t / self.scaling_factor
202
 
203
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
204
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
205
  emb = torch.cat((freqs, freqs), dim=-1)
206
- self.register_buffer(
207
- "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
208
- )
209
- self.register_buffer(
210
- "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
211
- )
212
 
213
 
214
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
215
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
216
 
217
- def __init__(
218
- self,
219
- dim,
220
- max_position_embeddings=2048,
221
- base=10000,
222
- device=None,
223
- scaling_factor=1.0,
224
- ):
225
  self.scaling_factor = scaling_factor
226
  super().__init__(dim, max_position_embeddings, base, device)
227
 
@@ -230,27 +191,18 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
230
 
231
  if seq_len > self.max_position_embeddings:
232
  base = self.base * (
233
- (self.scaling_factor * seq_len / self.max_position_embeddings)
234
- - (self.scaling_factor - 1)
235
  ) ** (self.dim / (self.dim - 2))
236
- inv_freq = 1.0 / (
237
- base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
238
- )
239
  self.register_buffer("inv_freq", inv_freq, persistent=False)
240
 
241
- t = torch.arange(
242
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
243
- )
244
 
245
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
246
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
247
  emb = torch.cat((freqs, freqs), dim=-1)
248
- self.register_buffer(
249
- "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
250
- )
251
- self.register_buffer(
252
- "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
253
- )
254
 
255
 
256
  def rotate_half(x):
@@ -260,12 +212,29 @@ def rotate_half(x):
260
  return torch.cat((-x2, x1), dim=-1)
261
 
262
 
263
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
264
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
265
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
266
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
267
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
268
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  q_embed = (q * cos) + (rotate_half(q) * sin)
270
  k_embed = (k * cos) + (rotate_half(k) * sin)
271
  return q_embed, k_embed
@@ -277,15 +246,9 @@ class LlamaMLP(nn.Module):
277
  self.config = config
278
  self.hidden_size = config.hidden_size
279
  self.intermediate_size = config.intermediate_size
280
- self.gate_proj = FinalizedQuantizedLinear(
281
- self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
282
- )
283
- self.up_proj = FinalizedQuantizedLinear(
284
- self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
285
- )
286
- self.down_proj = FinalizedQuantizedLinear(
287
- self.intermediate_size, self.hidden_size, bias=False, **config.aqlm
288
- )
289
  self.act_fn = ACT2FN[config.hidden_act]
290
 
291
  def forward(self, x):
@@ -295,25 +258,12 @@ class LlamaMLP(nn.Module):
295
  up_proj_slices = self.up_proj.weight.split(slice, dim=0)
296
  down_proj_slices = self.down_proj.weight.split(slice, dim=1)
297
 
298
- gate_proj = torch.cat(
299
- [
300
- F.linear(x, gate_proj_slices[i])
301
- for i in range(self.config.pretraining_tp)
302
- ],
303
- dim=-1,
304
- )
305
- up_proj = torch.cat(
306
- [
307
- F.linear(x, up_proj_slices[i])
308
- for i in range(self.config.pretraining_tp)
309
- ],
310
- dim=-1,
311
- )
312
 
313
  intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
314
  down_proj = [
315
- F.linear(intermediate_states[i], down_proj_slices[i])
316
- for i in range(self.config.pretraining_tp)
317
  ]
318
  down_proj = sum(down_proj)
319
  else:
@@ -330,18 +280,25 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
330
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
331
  if n_rep == 1:
332
  return hidden_states
333
- hidden_states = hidden_states[:, :, None, :, :].expand(
334
- batch, num_key_value_heads, n_rep, slen, head_dim
335
- )
336
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
337
 
338
 
339
  class LlamaAttention(nn.Module):
340
  """Multi-headed attention from 'Attention Is All You Need' paper"""
341
 
342
- def __init__(self, config: LlamaConfig):
343
  super().__init__()
344
  self.config = config
 
 
 
 
 
 
 
 
 
345
  self.hidden_size = config.hidden_size
346
  self.num_heads = config.num_attention_heads
347
  self.head_dim = self.hidden_size // self.num_heads
@@ -349,35 +306,25 @@ class LlamaAttention(nn.Module):
349
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
350
  self.max_position_embeddings = config.max_position_embeddings
351
  self.rope_theta = config.rope_theta
 
352
 
353
  if (self.head_dim * self.num_heads) != self.hidden_size:
354
  raise ValueError(
355
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
356
  f" and `num_heads`: {self.num_heads})."
357
  )
 
358
  self.q_proj = FinalizedQuantizedLinear(
359
- self.hidden_size,
360
- self.num_heads * self.head_dim,
361
- bias=config.attention_bias,
362
- **config.aqlm,
363
  )
364
  self.k_proj = FinalizedQuantizedLinear(
365
- self.hidden_size,
366
- self.num_key_value_heads * self.head_dim,
367
- bias=config.attention_bias,
368
- **config.aqlm,
369
  )
370
  self.v_proj = FinalizedQuantizedLinear(
371
- self.hidden_size,
372
- self.num_key_value_heads * self.head_dim,
373
- bias=config.attention_bias,
374
- **config.aqlm,
375
  )
376
  self.o_proj = FinalizedQuantizedLinear(
377
- self.num_heads * self.head_dim,
378
- self.hidden_size,
379
- bias=config.attention_bias,
380
- **config.aqlm,
381
  )
382
  self._init_rope()
383
 
@@ -409,50 +356,40 @@ class LlamaAttention(nn.Module):
409
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
410
 
411
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
412
- return (
413
- tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
414
- .transpose(1, 2)
415
- .contiguous()
416
- )
417
 
418
  def forward(
419
  self,
420
  hidden_states: torch.Tensor,
421
  attention_mask: Optional[torch.Tensor] = None,
422
  position_ids: Optional[torch.LongTensor] = None,
423
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
424
  output_attentions: bool = False,
425
  use_cache: bool = False,
426
- padding_mask: Optional[torch.LongTensor] = None,
427
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
 
428
  bsz, q_len, _ = hidden_states.size()
429
 
430
  if self.config.pretraining_tp > 1:
431
- key_value_slicing = (
432
- self.num_key_value_heads * self.head_dim
433
- ) // self.config.pretraining_tp
434
  query_slices = self.q_proj.weight.split(
435
  (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
436
  )
437
  key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
438
  value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
439
 
440
- query_states = [
441
- F.linear(hidden_states, query_slices[i])
442
- for i in range(self.config.pretraining_tp)
443
- ]
444
  query_states = torch.cat(query_states, dim=-1)
445
 
446
- key_states = [
447
- F.linear(hidden_states, key_slices[i])
448
- for i in range(self.config.pretraining_tp)
449
- ]
450
  key_states = torch.cat(key_states, dim=-1)
451
 
452
- value_states = [
453
- F.linear(hidden_states, value_slices[i])
454
- for i in range(self.config.pretraining_tp)
455
- ]
456
  value_states = torch.cat(value_states, dim=-1)
457
 
458
  else:
@@ -460,37 +397,30 @@ class LlamaAttention(nn.Module):
460
  key_states = self.k_proj(hidden_states)
461
  value_states = self.v_proj(hidden_states)
462
 
463
- query_states = query_states.view(
464
- bsz, q_len, self.num_heads, self.head_dim
465
- ).transpose(1, 2)
466
- key_states = key_states.view(
467
- bsz, q_len, self.num_key_value_heads, self.head_dim
468
- ).transpose(1, 2)
469
- value_states = value_states.view(
470
- bsz, q_len, self.num_key_value_heads, self.head_dim
471
- ).transpose(1, 2)
472
 
473
  kv_seq_len = key_states.shape[-2]
474
  if past_key_value is not None:
475
- kv_seq_len += past_key_value[0].shape[-2]
 
 
 
 
 
 
476
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
477
- query_states, key_states = apply_rotary_pos_emb(
478
- query_states, key_states, cos, sin, position_ids
479
- )
480
 
481
  if past_key_value is not None:
482
- # reuse k, v, self_attention
483
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
484
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
485
-
486
- past_key_value = (key_states, value_states) if use_cache else None
487
 
488
  key_states = repeat_kv(key_states, self.num_key_value_groups)
489
  value_states = repeat_kv(value_states, self.num_key_value_groups)
490
 
491
- attn_weights = torch.matmul(
492
- query_states, key_states.transpose(2, 3)
493
- ) / math.sqrt(self.head_dim)
494
 
495
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
496
  raise ValueError(
@@ -506,9 +436,8 @@ class LlamaAttention(nn.Module):
506
  attn_weights = attn_weights + attention_mask
507
 
508
  # upcast attention to fp32
509
- attn_weights = nn.functional.softmax(
510
- attn_weights, dim=-1, dtype=torch.float32
511
- ).to(query_states.dtype)
512
  attn_output = torch.matmul(attn_weights, value_states)
513
 
514
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -522,18 +451,9 @@ class LlamaAttention(nn.Module):
522
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
523
 
524
  if self.config.pretraining_tp > 1:
525
- attn_output = attn_output.split(
526
- self.hidden_size // self.config.pretraining_tp, dim=2
527
- )
528
- o_proj_slices = self.o_proj.weight.split(
529
- self.hidden_size // self.config.pretraining_tp, dim=1
530
- )
531
- attn_output = sum(
532
- [
533
- F.linear(attn_output[i], o_proj_slices[i])
534
- for i in range(self.config.pretraining_tp)
535
- ]
536
- )
537
  else:
538
  attn_output = self.o_proj(attn_output)
539
 
@@ -550,17 +470,33 @@ class LlamaFlashAttention2(LlamaAttention):
550
  flash attention and deal with padding tokens in case the input contains any of them.
551
  """
552
 
 
 
 
 
 
 
 
 
553
  def forward(
554
  self,
555
  hidden_states: torch.Tensor,
556
- attention_mask: Optional[torch.Tensor] = None,
557
  position_ids: Optional[torch.LongTensor] = None,
558
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
559
  output_attentions: bool = False,
560
  use_cache: bool = False,
561
- padding_mask: Optional[torch.LongTensor] = None,
562
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
563
  # LlamaFlashAttention2 attention does not support output_attentions
 
 
 
 
 
 
 
 
564
  output_attentions = False
565
 
566
  bsz, q_len, _ = hidden_states.size()
@@ -570,68 +506,56 @@ class LlamaFlashAttention2(LlamaAttention):
570
  value_states = self.v_proj(hidden_states)
571
 
572
  # Flash attention requires the input to have the shape
573
- # batch_size x seq_length x head_dime x hidden_dim
574
  # therefore we just need to keep the original shape
575
- query_states = query_states.view(
576
- bsz, q_len, self.num_heads, self.head_dim
577
- ).transpose(1, 2)
578
- key_states = key_states.view(
579
- bsz, q_len, self.num_key_value_heads, self.head_dim
580
- ).transpose(1, 2)
581
- value_states = value_states.view(
582
- bsz, q_len, self.num_key_value_heads, self.head_dim
583
- ).transpose(1, 2)
584
 
585
  kv_seq_len = key_states.shape[-2]
586
  if past_key_value is not None:
587
- kv_seq_len += past_key_value[0].shape[-2]
588
-
589
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
590
-
591
- query_states, key_states = apply_rotary_pos_emb(
592
- query_states, key_states, cos, sin, position_ids
593
- )
594
 
595
  if past_key_value is not None:
596
- # reuse k, v, self_attention
597
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
598
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
599
-
600
- past_key_value = (key_states, value_states) if use_cache else None
601
 
 
 
602
  query_states = query_states.transpose(1, 2)
603
  key_states = key_states.transpose(1, 2)
604
  value_states = value_states.transpose(1, 2)
605
 
606
- # TODO: llama does not have dropout in the config??
607
- # It is recommended to use dropout with FA according to the docs
608
- # when training.
609
- dropout_rate = 0.0 # if not self.training else self.attn_dropout
610
 
611
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
612
  # therefore the input hidden states gets silently casted in float32. Hence, we need
613
- # cast them back in float16 just to be sure everything works as expected.
614
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
615
  # in fp32. (LlamaRMSNorm handles it correctly)
 
616
  input_dtype = query_states.dtype
617
  if input_dtype == torch.float32:
 
 
 
 
 
 
618
  logger.warning_once(
619
- "The input hidden states seems to be silently casted in float32, this might be related to"
620
- " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
621
- " float16."
622
  )
623
 
624
- query_states = query_states.to(torch.float16)
625
- key_states = key_states.to(torch.float16)
626
- value_states = value_states.to(torch.float16)
627
 
628
  attn_output = self._flash_attention_forward(
629
- query_states,
630
- key_states,
631
- value_states,
632
- padding_mask,
633
- q_len,
634
- dropout=dropout_rate,
635
  )
636
 
637
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -643,14 +567,7 @@ class LlamaFlashAttention2(LlamaAttention):
643
  return attn_output, attn_weights, past_key_value
644
 
645
  def _flash_attention_forward(
646
- self,
647
- query_states,
648
- key_states,
649
- value_states,
650
- padding_mask,
651
- query_length,
652
- dropout=0.0,
653
- softmax_scale=None,
654
  ):
655
  """
656
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -663,7 +580,7 @@ class LlamaFlashAttention2(LlamaAttention):
663
  Input key states to be passed to Flash Attention API
664
  value_states (`torch.Tensor`):
665
  Input value states to be passed to Flash Attention API
666
- padding_mask (`torch.Tensor`):
667
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
668
  position of padding tokens and 1 for the position of non-padding tokens.
669
  dropout (`int`, *optional*):
@@ -671,18 +588,17 @@ class LlamaFlashAttention2(LlamaAttention):
671
  softmax_scale (`float`, *optional*):
672
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
673
  """
 
 
 
 
 
 
674
  # Contains at least one padding token in the sequence
675
- if padding_mask is not None:
676
  batch_size = query_states.shape[0]
677
- (
678
- query_states,
679
- key_states,
680
- value_states,
681
- indices_q,
682
- cu_seq_lens,
683
- max_seq_lens,
684
- ) = self._upad_input(
685
- query_states, key_states, value_states, padding_mask, query_length
686
  )
687
 
688
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
@@ -698,42 +614,30 @@ class LlamaFlashAttention2(LlamaAttention):
698
  max_seqlen_k=max_seqlen_in_batch_k,
699
  dropout_p=dropout,
700
  softmax_scale=softmax_scale,
701
- causal=True,
702
  )
703
 
704
- attn_output = pad_input(
705
- attn_output_unpad, indices_q, batch_size, query_length
706
- )
707
  else:
708
  attn_output = flash_attn_func(
709
- query_states,
710
- key_states,
711
- value_states,
712
- dropout,
713
- softmax_scale=softmax_scale,
714
- causal=True,
715
  )
716
 
717
  return attn_output
718
 
719
- def _upad_input(
720
- self, query_layer, key_layer, value_layer, padding_mask, query_length
721
- ):
722
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
723
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
724
 
725
  key_layer = index_first_axis(
726
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
727
- indices_k,
728
  )
729
  value_layer = index_first_axis(
730
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
731
- indices_k,
732
  )
733
  if query_length == kv_seq_len:
734
  query_layer = index_first_axis(
735
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
736
- indices_k,
737
  )
738
  cu_seqlens_q = cu_seqlens_k
739
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -747,10 +651,8 @@ class LlamaFlashAttention2(LlamaAttention):
747
  query_layer = query_layer.squeeze(1)
748
  else:
749
  # The -q_len: slice assumes left padding.
750
- padding_mask = padding_mask[:, -query_length:]
751
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
752
- query_layer, padding_mask
753
- )
754
 
755
  return (
756
  query_layer,
@@ -762,20 +664,110 @@ class LlamaFlashAttention2(LlamaAttention):
762
  )
763
 
764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  class LlamaDecoderLayer(nn.Module):
766
- def __init__(self, config: LlamaConfig):
767
  super().__init__()
768
  self.hidden_size = config.hidden_size
769
- self.self_attn = (
770
- LlamaAttention(config=config)
771
- if not getattr(config, "_flash_attn_2_enabled", False)
772
- else LlamaFlashAttention2(config=config)
773
- )
774
  self.mlp = LlamaMLP(config)
775
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
776
- self.post_attention_layernorm = LlamaRMSNorm(
777
- config.hidden_size, eps=config.rms_norm_eps
778
- )
779
 
780
  def forward(
781
  self,
@@ -785,15 +777,14 @@ class LlamaDecoderLayer(nn.Module):
785
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
786
  output_attentions: Optional[bool] = False,
787
  use_cache: Optional[bool] = False,
788
- padding_mask: Optional[torch.LongTensor] = None,
789
- ) -> Tuple[
790
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
791
- ]:
792
  """
793
  Args:
794
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
795
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
796
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 
797
  output_attentions (`bool`, *optional*):
798
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
799
  returned tensors for more detail.
@@ -802,6 +793,10 @@ class LlamaDecoderLayer(nn.Module):
802
  (see `past_key_values`).
803
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
804
  """
 
 
 
 
805
 
806
  residual = hidden_states
807
 
@@ -815,7 +810,7 @@ class LlamaDecoderLayer(nn.Module):
815
  past_key_value=past_key_value,
816
  output_attentions=output_attentions,
817
  use_cache=use_cache,
818
- padding_mask=padding_mask,
819
  )
820
  hidden_states = residual + hidden_states
821
 
@@ -864,6 +859,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
864
  _no_split_modules = ["LlamaDecoderLayer"]
865
  _skip_keys_device_placement = "past_key_values"
866
  _supports_flash_attn_2 = True
 
 
867
 
868
  def _init_weights(self, module):
869
  std = self.config.initializer_range
@@ -876,10 +873,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
876
  if module.padding_idx is not None:
877
  module.weight.data[module.padding_idx].zero_()
878
 
879
- def _set_gradient_checkpointing(self, module, value=False):
880
- if isinstance(module, LlamaModel):
881
- module.gradient_checkpointing = value
882
-
883
 
884
  LLAMA_INPUTS_DOCSTRING = r"""
885
  Args:
@@ -916,13 +909,19 @@ LLAMA_INPUTS_DOCSTRING = r"""
916
  config.n_positions - 1]`.
917
 
918
  [What are position IDs?](../glossary#position-ids)
919
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
920
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
921
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
922
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
923
 
924
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
925
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
 
 
 
 
 
926
 
927
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
928
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -962,12 +961,12 @@ class LlamaModel(LlamaPreTrainedModel):
962
  self.padding_idx = config.pad_token_id
963
  self.vocab_size = config.vocab_size
964
 
965
- self.embed_tokens = nn.Embedding(
966
- config.vocab_size, config.hidden_size, self.padding_idx
967
- )
968
  self.layers = nn.ModuleList(
969
- [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
970
  )
 
 
971
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
972
 
973
  self.gradient_checkpointing = False
@@ -980,34 +979,6 @@ class LlamaModel(LlamaPreTrainedModel):
980
  def set_input_embeddings(self, value):
981
  self.embed_tokens = value
982
 
983
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
984
- def _prepare_decoder_attention_mask(
985
- self, attention_mask, input_shape, inputs_embeds, past_key_values_length
986
- ):
987
- # create causal mask
988
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
989
- combined_attention_mask = None
990
- if input_shape[-1] > 1:
991
- combined_attention_mask = _make_causal_mask(
992
- input_shape,
993
- inputs_embeds.dtype,
994
- device=inputs_embeds.device,
995
- past_key_values_length=past_key_values_length,
996
- )
997
-
998
- if attention_mask is not None:
999
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1000
- expanded_attn_mask = _expand_mask(
1001
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1002
- ).to(inputs_embeds.device)
1003
- combined_attention_mask = (
1004
- expanded_attn_mask
1005
- if combined_attention_mask is None
1006
- else expanded_attn_mask + combined_attention_mask
1007
- )
1008
-
1009
- return combined_attention_mask
1010
-
1011
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1012
  def forward(
1013
  self,
@@ -1021,133 +992,102 @@ class LlamaModel(LlamaPreTrainedModel):
1021
  output_hidden_states: Optional[bool] = None,
1022
  return_dict: Optional[bool] = None,
1023
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1024
- output_attentions = (
1025
- output_attentions
1026
- if output_attentions is not None
1027
- else self.config.output_attentions
1028
- )
1029
  output_hidden_states = (
1030
- output_hidden_states
1031
- if output_hidden_states is not None
1032
- else self.config.output_hidden_states
1033
  )
1034
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1035
 
1036
- return_dict = (
1037
- return_dict if return_dict is not None else self.config.use_return_dict
1038
- )
1039
 
1040
  # retrieve input_ids and inputs_embeds
1041
  if input_ids is not None and inputs_embeds is not None:
1042
- raise ValueError(
1043
- "You cannot specify both input_ids and inputs_embeds at the same time"
1044
- )
1045
  elif input_ids is not None:
1046
- batch_size, seq_length = input_ids.shape
1047
  elif inputs_embeds is not None:
1048
- batch_size, seq_length, _ = inputs_embeds.shape
1049
  else:
1050
  raise ValueError("You have to specify either input_ids or inputs_embeds")
1051
 
1052
- seq_length_with_past = seq_length
1053
- past_key_values_length = 0
 
 
 
 
1054
 
1055
- if past_key_values is not None:
1056
- past_key_values_length = past_key_values[0][0].shape[2]
1057
- seq_length_with_past = seq_length_with_past + past_key_values_length
 
 
 
1058
 
1059
  if position_ids is None:
1060
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1061
  position_ids = torch.arange(
1062
- past_key_values_length,
1063
- seq_length + past_key_values_length,
1064
- dtype=torch.long,
1065
- device=device,
1066
  )
1067
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1068
- else:
1069
- position_ids = position_ids.view(-1, seq_length).long()
1070
 
1071
  if inputs_embeds is None:
1072
  inputs_embeds = self.embed_tokens(input_ids)
1073
- # embed positions
1074
- if attention_mask is None:
1075
- attention_mask = torch.ones(
1076
- (batch_size, seq_length_with_past),
1077
- dtype=torch.bool,
1078
- device=inputs_embeds.device,
 
 
 
 
 
 
1079
  )
1080
- padding_mask = None
1081
  else:
1082
- if 0 in attention_mask:
1083
- padding_mask = attention_mask
1084
- else:
1085
- padding_mask = None
1086
-
1087
- attention_mask = self._prepare_decoder_attention_mask(
1088
- attention_mask,
1089
- (batch_size, seq_length),
1090
- inputs_embeds,
1091
- past_key_values_length,
1092
- )
1093
 
 
1094
  hidden_states = inputs_embeds
1095
 
1096
- if self.gradient_checkpointing and self.training:
1097
- if use_cache:
1098
- logger.warning_once(
1099
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1100
- )
1101
- use_cache = False
1102
-
1103
  # decoder layers
1104
  all_hidden_states = () if output_hidden_states else None
1105
  all_self_attns = () if output_attentions else None
1106
- next_decoder_cache = () if use_cache else None
1107
 
1108
- for idx, decoder_layer in enumerate(self.layers):
1109
  if output_hidden_states:
1110
  all_hidden_states += (hidden_states,)
1111
 
1112
- past_key_value = (
1113
- past_key_values[idx] if past_key_values is not None else None
1114
- )
1115
-
1116
  if self.gradient_checkpointing and self.training:
1117
-
1118
- def create_custom_forward(module):
1119
- def custom_forward(*inputs):
1120
- # None for past_key_value
1121
- return module(
1122
- *inputs,
1123
- past_key_value,
1124
- output_attentions,
1125
- padding_mask=padding_mask,
1126
- )
1127
-
1128
- return custom_forward
1129
-
1130
- layer_outputs = torch.utils.checkpoint.checkpoint(
1131
- create_custom_forward(decoder_layer),
1132
  hidden_states,
1133
  attention_mask,
1134
  position_ids,
 
 
 
1135
  )
1136
  else:
1137
  layer_outputs = decoder_layer(
1138
  hidden_states,
1139
  attention_mask=attention_mask,
1140
  position_ids=position_ids,
1141
- past_key_value=past_key_value,
1142
  output_attentions=output_attentions,
1143
  use_cache=use_cache,
1144
- padding_mask=padding_mask,
1145
  )
1146
 
1147
  hidden_states = layer_outputs[0]
1148
 
1149
  if use_cache:
1150
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1151
 
1152
  if output_attentions:
1153
  all_self_attns += (layer_outputs[1],)
@@ -1158,13 +1098,11 @@ class LlamaModel(LlamaPreTrainedModel):
1158
  if output_hidden_states:
1159
  all_hidden_states += (hidden_states,)
1160
 
1161
- next_cache = next_decoder_cache if use_cache else None
 
 
1162
  if not return_dict:
1163
- return tuple(
1164
- v
1165
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1166
- if v is not None
1167
- )
1168
  return BaseModelOutputWithPast(
1169
  last_hidden_state=hidden_states,
1170
  past_key_values=next_cache,
@@ -1204,9 +1142,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1204
  return self.model
1205
 
1206
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1207
- @replace_return_docstrings(
1208
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1209
- )
1210
  def forward(
1211
  self,
1212
  input_ids: torch.LongTensor = None,
@@ -1245,20 +1181,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1245
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1246
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1247
  ```"""
1248
-
1249
- output_attentions = (
1250
- output_attentions
1251
- if output_attentions is not None
1252
- else self.config.output_attentions
1253
- )
1254
  output_hidden_states = (
1255
- output_hidden_states
1256
- if output_hidden_states is not None
1257
- else self.config.output_hidden_states
1258
- )
1259
- return_dict = (
1260
- return_dict if return_dict is not None else self.config.use_return_dict
1261
  )
 
1262
 
1263
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1264
  outputs = self.model(
@@ -1275,13 +1202,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1275
 
1276
  hidden_states = outputs[0]
1277
  if self.config.pretraining_tp > 1:
1278
- lm_head_slices = self.lm_head.weight.split(
1279
- self.vocab_size // self.config.pretraining_tp, dim=0
1280
- )
1281
- logits = [
1282
- F.linear(hidden_states, lm_head_slices[i])
1283
- for i in range(self.config.pretraining_tp)
1284
- ]
1285
  logits = torch.cat(logits, dim=-1)
1286
  else:
1287
  logits = self.lm_head(hidden_states)
@@ -1313,15 +1235,36 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1313
  )
1314
 
1315
  def prepare_inputs_for_generation(
1316
- self,
1317
- input_ids,
1318
- past_key_values=None,
1319
- attention_mask=None,
1320
- inputs_embeds=None,
1321
- **kwargs,
1322
  ):
1323
- if past_key_values:
1324
- input_ids = input_ids[:, -1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1325
 
1326
  position_ids = kwargs.get("position_ids", None)
1327
  if attention_mask is not None and position_ids is None:
@@ -1329,7 +1272,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1329
  position_ids = attention_mask.long().cumsum(-1) - 1
1330
  position_ids.masked_fill_(attention_mask == 0, 1)
1331
  if past_key_values:
1332
- position_ids = position_ids[:, -1].unsqueeze(-1)
1333
 
1334
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1335
  if inputs_embeds is not None and past_key_values is None:
@@ -1352,10 +1295,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1352
  reordered_past = ()
1353
  for layer_past in past_key_values:
1354
  reordered_past += (
1355
- tuple(
1356
- past_state.index_select(0, beam_idx.to(past_state.device))
1357
- for past_state in layer_past
1358
- ),
1359
  )
1360
  return reordered_past
1361
 
@@ -1411,9 +1351,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1411
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1412
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1413
  """
1414
- return_dict = (
1415
- return_dict if return_dict is not None else self.config.use_return_dict
1416
- )
1417
 
1418
  transformer_outputs = self.model(
1419
  input_ids,
@@ -1435,22 +1373,18 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1435
  batch_size = inputs_embeds.shape[0]
1436
 
1437
  if self.config.pad_token_id is None and batch_size != 1:
1438
- raise ValueError(
1439
- "Cannot handle batch sizes > 1 if no padding token is defined."
1440
- )
1441
  if self.config.pad_token_id is None:
1442
  sequence_lengths = -1
1443
  else:
1444
  if input_ids is not None:
1445
- sequence_lengths = (
1446
- torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1447
- ).to(logits.device)
1448
  else:
1449
  sequence_lengths = -1
1450
 
1451
- pooled_logits = logits[
1452
- torch.arange(batch_size, device=logits.device), sequence_lengths
1453
- ]
1454
 
1455
  loss = None
1456
  if labels is not None:
@@ -1458,9 +1392,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1458
  if self.config.problem_type is None:
1459
  if self.num_labels == 1:
1460
  self.config.problem_type = "regression"
1461
- elif self.num_labels > 1 and (
1462
- labels.dtype == torch.long or labels.dtype == torch.int
1463
- ):
1464
  self.config.problem_type = "single_label_classification"
1465
  else:
1466
  self.config.problem_type = "multi_label_classification"
@@ -1473,9 +1405,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1473
  loss = loss_fct(pooled_logits, labels)
1474
  elif self.config.problem_type == "single_label_classification":
1475
  loss_fct = CrossEntropyLoss()
1476
- loss = loss_fct(
1477
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
- )
1479
  elif self.config.problem_type == "multi_label_classification":
1480
  loss_fct = BCEWithLogitsLoss()
1481
  loss = loss_fct(pooled_logits, labels)
 
19
  # limitations under the License.
20
  """ PyTorch LLaMA model."""
21
  import math
22
+ import warnings
23
  from typing import List, Optional, Tuple, Union
24
 
25
  import torch
 
28
  from torch import nn
29
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
  from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache
32
+ from transformers.modeling_attn_mask_utils import (
33
+ AttentionMaskConverter,
34
+ _prepare_4d_attention_mask,
35
+ _prepare_4d_causal_attention_mask,
36
+ _prepare_4d_causal_attention_mask_for_sdpa,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ CausalLMOutputWithPast,
41
+ SequenceClassifierOutputWithPast,
42
+ )
43
  from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
+ from transformers.utils.import_utils import is_torch_fx_available
54
 
55
+ from .configuration_llama_aqlm import LlamaConfig
56
  from .inference import FinalizedQuantizedLinear
57
 
58
+ if is_flash_attn_2_available():
59
  from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+
62
+
63
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
64
+ # It means that the function will not be traced through and simply appear as a node in the graph.
65
+ if is_torch_fx_available():
66
+ if not is_torch_greater_or_equal_than_1_13:
67
+ import torch.fx
68
+
69
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
70
 
71
 
72
  logger = logging.get_logger(__name__)
 
74
  _CONFIG_FOR_DOC = "LlamaConfig"
75
 
76
 
77
+ def _get_unpad_data(attention_mask):
78
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
79
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
80
  max_seqlen_in_batch = seqlens_in_batch.max().item()
81
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
82
  return (
83
  indices,
84
  cu_seqlens,
 
86
  )
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
90
+ warnings.warn(
91
+ "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
92
+ )
93
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
 
 
 
94
 
 
95
 
96
+ def _make_causal_mask(
97
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
98
+ ):
99
+ warnings.warn(
100
+ "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
101
+ )
102
+ return AttentionMaskConverter._make_causal_mask(
103
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
104
  )
105
 
106
 
 
131
  self.dim = dim
132
  self.max_position_embeddings = max_position_embeddings
133
  self.base = base
134
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
135
  self.register_buffer("inv_freq", inv_freq, persistent=False)
136
 
137
  # Build here to make `torch.jit.trace` work.
138
  self._set_cos_sin_cache(
139
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
 
 
140
  )
141
 
142
  def _set_cos_sin_cache(self, seq_len, device, dtype):
143
  self.max_seq_len_cached = seq_len
144
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
145
 
146
+ freqs = torch.outer(t, self.inv_freq)
147
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
148
  emb = torch.cat((freqs, freqs), dim=-1)
149
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
150
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
 
 
151
 
152
  def forward(self, x, seq_len=None):
153
  # x: [bs, num_attention_heads, seq_len, head_size]
 
155
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
156
 
157
  return (
158
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
159
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
160
  )
161
 
162
 
163
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
164
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
165
 
166
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
167
  self.scaling_factor = scaling_factor
168
  super().__init__(dim, max_position_embeddings, base, device)
169
 
170
  def _set_cos_sin_cache(self, seq_len, device, dtype):
171
  self.max_seq_len_cached = seq_len
172
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
173
  t = t / self.scaling_factor
174
 
175
+ freqs = torch.outer(t, self.inv_freq)
176
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
177
  emb = torch.cat((freqs, freqs), dim=-1)
178
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
179
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
 
 
180
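The only change relative to the base class is dividing the position index by scaling_factor; a tiny illustration (not part of the modeling file):

import torch

scaling_factor = 2.0
t = torch.arange(8, dtype=torch.float32) / scaling_factor
print(t)  # tensor([0.0000, 0.5000, 1.0000, ..., 3.5000]): 8 positions reuse the range trained for 4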
 
181
 
182
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
183
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
184
 
185
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
186
  self.scaling_factor = scaling_factor
187
  super().__init__(dim, max_position_embeddings, base, device)
188

189
  def _set_cos_sin_cache(self, seq_len, device, dtype):
190
  self.max_seq_len_cached = seq_len
191

192
  if seq_len > self.max_position_embeddings:
193
  base = self.base * (
194
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
 
195
  ) ** (self.dim / (self.dim - 2))
196
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
197
  self.register_buffer("inv_freq", inv_freq, persistent=False)
198
 
199
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
200
 
201
+ freqs = torch.outer(t, self.inv_freq)
202
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
203
  emb = torch.cat((freqs, freqs), dim=-1)
204
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
205
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
 
 
206
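A rough numeric check (illustrative only, with made-up sizes) of the base rescaling above: once seq_len exceeds the trained context, the rotary base grows, which lowers the frequencies so longer sequences fit the same phase range.

base, dim, max_position_embeddings, scaling_factor = 10000.0, 128, 4096, 1.0
seq_len = 8192  # twice the trained context
new_base = base * (
    (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))
print(new_base)  # roughly 2.02e4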
 
207
 
208
  def rotate_half(x):
 
212
  return torch.cat((-x2, x1), dim=-1)
213
 
214
 
215
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
216
+ """Applies Rotary Position Embedding to the query and key tensors.
217
+
218
+ Args:
219
+ q (`torch.Tensor`): The query tensor.
220
+ k (`torch.Tensor`): The key tensor.
221
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
222
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
223
+ position_ids (`torch.Tensor`):
224
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
225
+ used to pass offset position ids when working with a KV-cache.
226
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
227
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
228
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
229
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
230
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
231
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
232
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
233
+ Returns:
234
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
235
+ """
236
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
237
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
238
  q_embed = (q * cos) + (rotate_half(q) * sin)
239
  k_embed = (k * cos) + (rotate_half(k) * sin)
240
  return q_embed, k_embed
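An illustrative usage sketch of the broadcasting described in the docstring above (not part of the modeling file; rotate_half is redefined only to keep the snippet self-contained):

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
cos = torch.randn(seq_len, head_dim)   # stands in for cos_cached[:seq_len]
sin = torch.randn(seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)

cos_b = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim], broadcasts over the heads dimension
sin_b = sin[position_ids].unsqueeze(1)
q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
assert q_embed.shape == q.shape and k_embed.shape == k.shape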
 
246
  self.config = config
247
  self.hidden_size = config.hidden_size
248
  self.intermediate_size = config.intermediate_size
249
+ self.gate_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
250
+ self.up_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
251
+ self.down_proj = FinalizedQuantizedLinear(self.intermediate_size, self.hidden_size, bias=False, **config.aqlm)
 
 
 
 
 
 
252
  self.act_fn = ACT2FN[config.hidden_act]
253
 
254
  def forward(self, x):
 
258
  up_proj_slices = self.up_proj.weight.split(slice, dim=0)
259
  down_proj_slices = self.down_proj.weight.split(slice, dim=1)
260
 
261
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
262
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
263
 
264
  intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
265
  down_proj = [
266
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
 
267
  ]
268
  down_proj = sum(down_proj)
269
  else:
 
280
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
281
  if n_rep == 1:
282
  return hidden_states
283
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
284
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
285
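A shape-only illustration (not part of the modeling file) of the grouped-query expansion performed above:

import torch

batch, num_key_value_heads, slen, head_dim, n_rep = 2, 2, 5, 8, 4
kv = torch.randn(batch, num_key_value_heads, slen, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
out = expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
assert out.shape == (2, 8, 5, 8)  # 2 KV heads repeated 4 times -> 8 heads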
 
286
 
287
  class LlamaAttention(nn.Module):
288
  """Multi-headed attention from 'Attention Is All You Need' paper"""
289
 
290
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
291
  super().__init__()
292
  self.config = config
293
+ self.layer_idx = layer_idx
294
+ if layer_idx is None:
295
+ logger.warning_once(
296
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
297
+ "to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
298
+ "when creating this class."
299
+ )
300
+
301
+ self.attention_dropout = config.attention_dropout
302
  self.hidden_size = config.hidden_size
303
  self.num_heads = config.num_attention_heads
304
  self.head_dim = self.hidden_size // self.num_heads
 
306
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
307
  self.max_position_embeddings = config.max_position_embeddings
308
  self.rope_theta = config.rope_theta
309
+ self.is_causal = True
310
 
311
  if (self.head_dim * self.num_heads) != self.hidden_size:
312
  raise ValueError(
313
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
314
  f" and `num_heads`: {self.num_heads})."
315
  )
316
+
317
  self.q_proj = FinalizedQuantizedLinear(
318
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
319
  )
320
  self.k_proj = FinalizedQuantizedLinear(
321
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
322
  )
323
  self.v_proj = FinalizedQuantizedLinear(
324
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
325
  )
326
  self.o_proj = FinalizedQuantizedLinear(
327
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm
 
 
 
328
  )
329
  self._init_rope()
330
 
 
356
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
357
 
358
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
359
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
360
 
361
  def forward(
362
  self,
363
  hidden_states: torch.Tensor,
364
  attention_mask: Optional[torch.Tensor] = None,
365
  position_ids: Optional[torch.LongTensor] = None,
366
+ past_key_value: Optional[Cache] = None,
367
  output_attentions: bool = False,
368
  use_cache: bool = False,
369
+ **kwargs,
370
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
371
+ if "padding_mask" in kwargs:
372
+ warnings.warn(
373
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
374
+ )
375
+
376
  bsz, q_len, _ = hidden_states.size()
377
 
378
  if self.config.pretraining_tp > 1:
379
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
 
 
380
  query_slices = self.q_proj.weight.split(
381
  (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
382
  )
383
  key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
384
  value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
385
 
386
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
387
  query_states = torch.cat(query_states, dim=-1)
388
 
389
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
390
  key_states = torch.cat(key_states, dim=-1)
391
 
392
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
393
  value_states = torch.cat(value_states, dim=-1)
394
 
395
  else:
 
397
  key_states = self.k_proj(hidden_states)
398
  value_states = self.v_proj(hidden_states)
399
 
400
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
401
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
402
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
403
 
404
  kv_seq_len = key_states.shape[-2]
405
  if past_key_value is not None:
406
+ if self.layer_idx is None:
407
+ raise ValueError(
408
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
409
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
410
+ "with a layer index."
411
+ )
412
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
413
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
414
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
415
 
416
  if past_key_value is not None:
417
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
418
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
419
 
420
  key_states = repeat_kv(key_states, self.num_key_value_groups)
421
  value_states = repeat_kv(value_states, self.num_key_value_groups)
422
 
423
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
424
 
425
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
426
  raise ValueError(
 
436
  attn_weights = attn_weights + attention_mask
437
 
438
  # upcast attention to fp32
439
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
440
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
441
  attn_output = torch.matmul(attn_weights, value_states)
442
 
443
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
451
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
452
 
453
  if self.config.pretraining_tp > 1:
454
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
455
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
456
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
457
  else:
458
  attn_output = self.o_proj(attn_output)
459
 
 
470
  flash attention and deal with padding tokens in case the input contains any of them.
471
  """
472
 
473
+ def __init__(self, *args, **kwargs):
474
+ super().__init__(*args, **kwargs)
475
+
476
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
477
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
478
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
479
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
480
+
481
  def forward(
482
  self,
483
  hidden_states: torch.Tensor,
484
+ attention_mask: Optional[torch.LongTensor] = None,
485
  position_ids: Optional[torch.LongTensor] = None,
486
+ past_key_value: Optional[Cache] = None,
487
  output_attentions: bool = False,
488
  use_cache: bool = False,
489
+ **kwargs,
490
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
491
  # LlamaFlashAttention2 attention does not support output_attentions
492
+ if "padding_mask" in kwargs:
493
+ warnings.warn(
494
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
495
+ )
496
+
497
+ # overwrite attention_mask with padding_mask
498
+ attention_mask = kwargs.pop("padding_mask")
499
+
500
  output_attentions = False
501
 
502
  bsz, q_len, _ = hidden_states.size()
 
506
  value_states = self.v_proj(hidden_states)
507
 
508
  # Flash attention requires the input to have the shape
509
+ # batch_size x seq_length x num_heads x head_dim
510
  # therefore we just need to keep the original shape
511
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
512
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
513
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
514
 
515
  kv_seq_len = key_states.shape[-2]
516
  if past_key_value is not None:
517
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
518
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
519
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
 
520
 
521
  if past_key_value is not None:
522
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
523
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
524
 
525
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
526
+ # to be able to avoid many of these transpose/reshape/view.
527
  query_states = query_states.transpose(1, 2)
528
  key_states = key_states.transpose(1, 2)
529
  value_states = value_states.transpose(1, 2)
530
 
531
+ dropout_rate = self.attention_dropout if self.training else 0.0
 
 
 
532
 
533
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
534
  # therefore the input hidden states get silently cast to float32. Hence, we need to
535
+ # cast them back to the correct dtype just to be sure everything works as expected.
536
  # This might slow down training & inference, so it is recommended not to cast the LayerNorms
537
  # in fp32. (LlamaRMSNorm handles it correctly)
538
+
539
  input_dtype = query_states.dtype
540
  if input_dtype == torch.float32:
541
+ # Handle the case where the model is quantized
542
+ if hasattr(self.config, "_pre_quantization_dtype"):
543
+ target_dtype = self.config._pre_quantization_dtype
544
+ else:
545
+ target_dtype = self.q_proj.weight.dtype
546
+
547
  logger.warning_once(
548
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
549
+ f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast the input back to"
550
+ f" {target_dtype}."
551
  )
552
 
553
+ query_states = query_states.to(target_dtype)
554
+ key_states = key_states.to(target_dtype)
555
+ value_states = value_states.to(target_dtype)
556
 
557
  attn_output = self._flash_attention_forward(
558
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
 
 
 
 
 
559
  )
560
 
561
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
 
567
  return attn_output, attn_weights, past_key_value
568
 
569
  def _flash_attention_forward(
570
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
 
 
 
 
 
 
 
571
  ):
572
  """
573
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 
580
  Input key states to be passed to Flash Attention API
581
  value_states (`torch.Tensor`):
582
  Input value states to be passed to Flash Attention API
583
+ attention_mask (`torch.Tensor`):
584
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
585
  position of padding tokens and 1 for the position of non-padding tokens.
586
  dropout (`float`, *optional*):
 
588
  softmax_scale (`float`, *optional*):
589
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
590
  """
591
+ if not self._flash_attn_uses_top_left_mask:
592
+ causal = self.is_causal
593
+ else:
594
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
595
+ causal = self.is_causal and query_length != 1
596
+
597
  # Contains at least one padding token in the sequence
598
+ if attention_mask is not None:
599
  batch_size = query_states.shape[0]
600
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
601
+ query_states, key_states, value_states, attention_mask, query_length
 
 
 
 
 
 
 
602
  )
603
 
604
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
 
614
  max_seqlen_k=max_seqlen_in_batch_k,
615
  dropout_p=dropout,
616
  softmax_scale=softmax_scale,
617
+ causal=causal,
618
  )
619
 
620
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
621
  else:
622
  attn_output = flash_attn_func(
623
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
 
 
 
 
624
  )
625
 
626
  return attn_output
627
 
628
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
629
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
 
630
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
631
 
632
  key_layer = index_first_axis(
633
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
634
  )
635
  value_layer = index_first_axis(
636
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
637
  )
638
  if query_length == kv_seq_len:
639
  query_layer = index_first_axis(
640
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
641
  )
642
  cu_seqlens_q = cu_seqlens_k
643
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
651
  query_layer = query_layer.squeeze(1)
652
  else:
653
  # The -q_len: slice assumes left padding.
654
+ attention_mask = attention_mask[:, -query_length:]
655
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
656
 
657
  return (
658
  query_layer,
 
664
  )
665
 
666
 
667
+ class LlamaSdpaAttention(LlamaAttention):
668
+ """
669
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
670
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
671
+ SDPA API.
672
+ """
673
+
674
+ # Adapted from LlamaAttention.forward
675
+ def forward(
676
+ self,
677
+ hidden_states: torch.Tensor,
678
+ attention_mask: Optional[torch.Tensor] = None,
679
+ position_ids: Optional[torch.LongTensor] = None,
680
+ past_key_value: Optional[Cache] = None,
681
+ output_attentions: bool = False,
682
+ use_cache: bool = False,
683
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
684
+ if output_attentions:
685
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
686
+ logger.warning_once(
687
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
688
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
689
+ )
690
+ return super().forward(
691
+ hidden_states=hidden_states,
692
+ attention_mask=attention_mask,
693
+ position_ids=position_ids,
694
+ past_key_value=past_key_value,
695
+ output_attentions=output_attentions,
696
+ use_cache=use_cache,
697
+ )
698
+
699
+ bsz, q_len, _ = hidden_states.size()
700
+
701
+ query_states = self.q_proj(hidden_states)
702
+ key_states = self.k_proj(hidden_states)
703
+ value_states = self.v_proj(hidden_states)
704
+
705
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
706
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
707
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
708
+
709
+ kv_seq_len = key_states.shape[-2]
710
+ if past_key_value is not None:
711
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
712
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
713
+
714
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
715
+
716
+ if past_key_value is not None:
717
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
718
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
719
+
720
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
721
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
722
+
723
+ if attention_mask is not None:
724
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
725
+ raise ValueError(
726
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
727
+ )
728
+
729
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
730
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
731
+ if query_states.device.type == "cuda" and attention_mask is not None:
732
+ query_states = query_states.contiguous()
733
+ key_states = key_states.contiguous()
734
+ value_states = value_states.contiguous()
735
+
736
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
737
+ query_states,
738
+ key_states,
739
+ value_states,
740
+ attn_mask=attention_mask,
741
+ dropout_p=self.attention_dropout if self.training else 0.0,
742
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
743
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
744
+ )
745
+
746
+ attn_output = attn_output.transpose(1, 2).contiguous()
747
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
748
+
749
+ attn_output = self.o_proj(attn_output)
750
+
751
+ return attn_output, None, past_key_value
752
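A hedged numerical sketch (illustrative only, not part of the modeling file): without masking or dropout, the SDPA call above matches the eager softmax(QK^T / sqrt(d))V path.

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 5, 8)
k = torch.randn(1, 4, 5, 8)
v = torch.randn(1, 4, 5, 8)
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
eager_out = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(q.shape[-1]), dim=-1) @ v
assert torch.allclose(sdpa_out, eager_out, atol=1e-5)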
+
753
+
754
+ LLAMA_ATTENTION_CLASSES = {
755
+ "eager": LlamaAttention,
756
+ "flash_attention_2": LlamaFlashAttention2,
757
+ "sdpa": LlamaSdpaAttention,
758
+ }
759
+
760
+
761
  class LlamaDecoderLayer(nn.Module):
762
+ def __init__(self, config: LlamaConfig, layer_idx: int):
763
  super().__init__()
764
  self.hidden_size = config.hidden_size
765
+
766
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
767
+
 
 
768
  self.mlp = LlamaMLP(config)
769
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
770
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
771
 
772
  def forward(
773
  self,
 
777
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
778
  output_attentions: Optional[bool] = False,
779
  use_cache: Optional[bool] = False,
780
+ **kwargs,
781
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
782
  """
783
  Args:
784
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
785
+ attention_mask (`torch.FloatTensor`, *optional*):
786
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
787
+ query_sequence_length, key_sequence_length)` if default attention is used.
788
  output_attentions (`bool`, *optional*):
789
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
790
  returned tensors for more detail.
 
793
  (see `past_key_values`).
794
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
795
  """
796
+ if "padding_mask" in kwargs:
797
+ warnings.warn(
798
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
799
+ )
800
 
801
  residual = hidden_states
802
 
 
810
  past_key_value=past_key_value,
811
  output_attentions=output_attentions,
812
  use_cache=use_cache,
813
+ **kwargs,
814
  )
815
  hidden_states = residual + hidden_states
816
 
 
859
  _no_split_modules = ["LlamaDecoderLayer"]
860
  _skip_keys_device_placement = "past_key_values"
861
  _supports_flash_attn_2 = True
862
+ _supports_sdpa = True
863
+ _supports_cache_class = True
864
 
865
  def _init_weights(self, module):
866
  std = self.config.initializer_range
 
873
  if module.padding_idx is not None:
874
  module.weight.data[module.padding_idx].zero_()
875
 
 
 
 
 
876
 
877
  LLAMA_INPUTS_DOCSTRING = r"""
878
  Args:
 
909
  config.n_positions - 1]`.
910
 
911
  [What are position IDs?](../glossary#position-ids)
912
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
913
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
914
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
915
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
916
 
917
+ Two formats are allowed:
918
+ - a [`~cache_utils.Cache`] instance;
919
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
920
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
921
+ cache format.
922
+
923
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
924
+ legacy cache format will be returned.
925
 
926
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
927
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
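A minimal sketch (illustrative only) of the two cache formats described above, assuming `DynamicCache` from transformers v4.36+:

import torch
from transformers import DynamicCache

batch, heads, seq, dim_per_head, n_layers = 1, 4, 3, 8, 2
legacy = tuple(
    (torch.randn(batch, heads, seq, dim_per_head), torch.randn(batch, heads, seq, dim_per_head))
    for _ in range(n_layers)
)                                                # legacy format: one (key, value) pair per layer
cache = DynamicCache.from_legacy_cache(legacy)   # a [`~cache_utils.Cache`] instance wrapping the same tensors
legacy_again = cache.to_legacy_cache()           # back to the legacy tuple-of-tuples format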
 
961
  self.padding_idx = config.pad_token_id
962
  self.vocab_size = config.vocab_size
963
 
964
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
965
  self.layers = nn.ModuleList(
966
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
967
  )
968
+ self._use_sdpa = config._attn_implementation == "sdpa"
969
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
970
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
971
 
972
  self.gradient_checkpointing = False
 
979
  def set_input_embeddings(self, value):
980
  self.embed_tokens = value
981
982
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
983
  def forward(
984
  self,
 
992
  output_hidden_states: Optional[bool] = None,
993
  return_dict: Optional[bool] = None,
994
  ) -> Union[Tuple, BaseModelOutputWithPast]:
995
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
996
  output_hidden_states = (
997
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
998
  )
999
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1000
 
1001
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1002
 
1003
  # retrieve input_ids and inputs_embeds
1004
  if input_ids is not None and inputs_embeds is not None:
1005
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
1006
  elif input_ids is not None:
1007
+ batch_size, seq_length = input_ids.shape[:2]
1008
  elif inputs_embeds is not None:
1009
+ batch_size, seq_length = inputs_embeds.shape[:2]
1010
  else:
1011
  raise ValueError("You have to specify either input_ids or inputs_embeds")
1012
 
1013
+ if self.gradient_checkpointing and self.training:
1014
+ if use_cache:
1015
+ logger.warning_once(
1016
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1017
+ )
1018
+ use_cache = False
1019
 
1020
+ past_key_values_length = 0
1021
+ if use_cache:
1022
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1023
+ if use_legacy_cache:
1024
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1025
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1026
 
1027
  if position_ids is None:
1028
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1029
  position_ids = torch.arange(
1030
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
 
 
 
1031
  )
1032
+ position_ids = position_ids.unsqueeze(0)
 
 
1033
 
1034
  if inputs_embeds is None:
1035
  inputs_embeds = self.embed_tokens(input_ids)
1036
+
1037
+ if self._use_flash_attention_2:
1038
+ # 2d mask is passed through the layers
1039
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1040
+ elif self._use_sdpa and not output_attentions:
1041
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1042
+ # the manual implementation that requires a 4D causal mask in all cases.
1043
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1044
+ attention_mask,
1045
+ (batch_size, seq_length),
1046
+ inputs_embeds,
1047
+ past_key_values_length,
1048
  )
 
1049
  else:
1050
+ # 4d mask is passed through the layers
1051
+ attention_mask = _prepare_4d_causal_attention_mask(
1052
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1053
+ )
 
 
 
 
 
 
 
1054
 
1055
+ # embed positions
1056
  hidden_states = inputs_embeds
1057
 
 
 
 
 
 
 
 
1058
  # decoder layers
1059
  all_hidden_states = () if output_hidden_states else None
1060
  all_self_attns = () if output_attentions else None
1061
+ next_decoder_cache = None
1062
 
1063
+ for decoder_layer in self.layers:
1064
  if output_hidden_states:
1065
  all_hidden_states += (hidden_states,)
1066
 
 
 
 
 
1067
  if self.gradient_checkpointing and self.training:
1068
+ layer_outputs = self._gradient_checkpointing_func(
1069
+ decoder_layer.__call__,
1070
  hidden_states,
1071
  attention_mask,
1072
  position_ids,
1073
+ past_key_values,
1074
+ output_attentions,
1075
+ use_cache,
1076
  )
1077
  else:
1078
  layer_outputs = decoder_layer(
1079
  hidden_states,
1080
  attention_mask=attention_mask,
1081
  position_ids=position_ids,
1082
+ past_key_value=past_key_values,
1083
  output_attentions=output_attentions,
1084
  use_cache=use_cache,
 
1085
  )
1086
 
1087
  hidden_states = layer_outputs[0]
1088
 
1089
  if use_cache:
1090
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1091
 
1092
  if output_attentions:
1093
  all_self_attns += (layer_outputs[1],)
 
1098
  if output_hidden_states:
1099
  all_hidden_states += (hidden_states,)
1100
 
1101
+ next_cache = None
1102
+ if use_cache:
1103
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1104
  if not return_dict:
1105
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
1106
  return BaseModelOutputWithPast(
1107
  last_hidden_state=hidden_states,
1108
  past_key_values=next_cache,
 
1142
  return self.model
1143
 
1144
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1145
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1146
  def forward(
1147
  self,
1148
  input_ids: torch.LongTensor = None,
 
1181
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1182
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1183
  ```"""
1184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
 
1185
  output_hidden_states = (
1186
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1187
  )
1188
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1189
 
1190
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1191
  outputs = self.model(
 
1202
 
1203
  hidden_states = outputs[0]
1204
  if self.config.pretraining_tp > 1:
1205
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1206
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
 
 
1207
  logits = torch.cat(logits, dim=-1)
1208
  else:
1209
  logits = self.lm_head(hidden_states)
 
1235
  )
1236
 
1237
  def prepare_inputs_for_generation(
1238
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
1239
  ):
1240
+ if past_key_values is not None:
1241
+ if isinstance(past_key_values, Cache):
1242
+ cache_length = past_key_values.get_seq_length()
1243
+ past_length = past_key_values.seen_tokens
1244
+ max_cache_length = past_key_values.get_max_length()
1245
+ else:
1246
+ cache_length = past_length = past_key_values[0][0].shape[2]
1247
+ max_cache_length = None
1248
+
1249
+ # Keep only the unprocessed tokens:
1250
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1251
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1252
+ # input)
1253
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1254
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1255
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1256
+ # input_ids based on the past_length.
1257
+ elif past_length < input_ids.shape[1]:
1258
+ input_ids = input_ids[:, past_length:]
1259
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1260
+
1261
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1262
+ if (
1263
+ max_cache_length is not None
1264
+ and attention_mask is not None
1265
+ and cache_length + input_ids.shape[1] > max_cache_length
1266
+ ):
1267
+ attention_mask = attention_mask[:, -max_cache_length:]
1268
 
1269
  position_ids = kwargs.get("position_ids", None)
1270
  if attention_mask is not None and position_ids is None:
 
1272
  position_ids = attention_mask.long().cumsum(-1) - 1
1273
  position_ids.masked_fill_(attention_mask == 0, 1)
1274
  if past_key_values:
1275
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1276
 
1277
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1278
  if inputs_embeds is not None and past_key_values is None:
 
1295
  reordered_past = ()
1296
  for layer_past in past_key_values:
1297
  reordered_past += (
1298
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
1299
  )
1300
  return reordered_past
1301
 
 
1351
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1352
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1353
  """
1354
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1355
 
1356
  transformer_outputs = self.model(
1357
  input_ids,
 
1373
  batch_size = inputs_embeds.shape[0]
1374
 
1375
  if self.config.pad_token_id is None and batch_size != 1:
1376
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1377
  if self.config.pad_token_id is None:
1378
  sequence_lengths = -1
1379
  else:
1380
  if input_ids is not None:
1381
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1382
+ logits.device
1383
+ )
1384
  else:
1385
  sequence_lengths = -1
1386
 
1387
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
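# Illustration only (not part of the commit): with pad_token_id = 0 and input_ids = [[5, 6, 7, 0, 0]],
# torch.eq(input_ids, 0).int() -> [[0, 0, 0, 1, 1]]; argmax(-1) -> 3 (first pad position); minus 1 -> 2,
# so pooled_logits takes the logits at the last non-padding position of each sequence.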
 
 
1388
 
1389
  loss = None
1390
  if labels is not None:
 
1392
  if self.config.problem_type is None:
1393
  if self.num_labels == 1:
1394
  self.config.problem_type = "regression"
1395
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1396
  self.config.problem_type = "single_label_classification"
1397
  else:
1398
  self.config.problem_type = "multi_label_classification"
 
1405
  loss = loss_fct(pooled_logits, labels)
1406
  elif self.config.problem_type == "single_label_classification":
1407
  loss_fct = CrossEntropyLoss()
1408
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1409
  elif self.config.problem_type == "multi_label_classification":
1410
  loss_fct = BCEWithLogitsLoss()
1411
  loss = loss_fct(pooled_logits, labels)