emozilla committed on
Commit ec4e634
1 Parent(s): 0a3ddd0

Delete modeling_llama_yarn.py

Files changed (1)
  1. modeling_llama_yarn.py +0 -1410
modeling_llama_yarn.py DELETED
@@ -1,1410 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ PyTorch LLaMA model."""
21
- import math
22
- from typing import List, Optional, Tuple, Union
23
-
24
- import torch
25
- import torch.nn.functional as F
26
- import torch.utils.checkpoint
27
- from torch import nn
28
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
-
30
- from transformers.activations import ACT2FN
31
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
32
- from transformers.modeling_utils import PreTrainedModel
33
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
34
- from transformers.utils import (
35
- add_start_docstrings,
36
- add_start_docstrings_to_model_forward,
37
- is_flash_attn_2_available,
38
- logging,
39
- replace_return_docstrings,
40
- )
41
- from .configuration_llama import LlamaConfig
42
-
43
-
44
- if is_flash_attn_2_available():
45
- from flash_attn import flash_attn_func, flash_attn_varlen_func
46
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
47
-
48
-
49
- logger = logging.get_logger(__name__)
50
-
51
- _CONFIG_FOR_DOC = "LlamaConfig"
52
-
53
-
54
- def _get_unpad_data(padding_mask):
55
- seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
56
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
57
- max_seqlen_in_batch = seqlens_in_batch.max().item()
58
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
59
- return (
60
- indices,
61
- cu_seqlens,
62
- max_seqlen_in_batch,
63
- )
64
-
65
-
66
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
67
- def _make_causal_mask(
68
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
69
- ):
70
- """
71
- Make causal mask used for bi-directional self-attention.
72
- """
73
- bsz, tgt_len = input_ids_shape
74
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
75
- mask_cond = torch.arange(mask.size(-1), device=device)
76
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
77
- mask = mask.to(dtype)
78
-
79
- if past_key_values_length > 0:
80
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
81
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
82
-
83
-
84
- # Copied from transformers.models.bart.modeling_bart._expand_mask
85
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
- """
87
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
88
- """
89
- bsz, src_len = mask.size()
90
- tgt_len = tgt_len if tgt_len is not None else src_len
91
-
92
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
-
94
- inverted_mask = 1.0 - expanded_mask
95
-
96
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
97
-
98
- # Inverse dim formula to find dim based on number of rotations
99
- def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
100
- return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
101
-
102
- # Find dim range bounds based on rotations
103
- def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
104
- low = math.floor(_yarn_find_correction_dim(
105
- low_rot, dim, base, max_position_embeddings))
106
- high = math.ceil(_yarn_find_correction_dim(
107
- high_rot, dim, base, max_position_embeddings))
108
- return max(low, 0), min(high, dim-1) # Clamp values just in case
109
-
110
- def _yarn_linear_ramp_mask(min, max, dim):
111
- if min == max:
112
- max += 0.001 # Prevent singularity
113
-
114
- linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
115
- ramp_func = torch.clamp(linear_func, 0, 1)
116
- return ramp_func
117
-
118
- def _yarn_get_mscale(scale=1):
119
- if scale <= 1:
120
- return 1.0
121
- return 0.1 * math.log(scale) + 1.0
122
-
123
- class LlamaRMSNorm(nn.Module):
124
- def __init__(self, hidden_size, eps=1e-6):
125
- """
126
- LlamaRMSNorm is equivalent to T5LayerNorm
127
- """
128
- super().__init__()
129
- self.weight = nn.Parameter(torch.ones(hidden_size))
130
- self.variance_epsilon = eps
131
-
132
- def forward(self, hidden_states):
133
- input_dtype = hidden_states.dtype
134
- hidden_states = hidden_states.to(torch.float32)
135
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
136
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
137
- return self.weight * hidden_states.to(input_dtype)
138
-
139
-
140
- ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
141
-
142
-
143
- class LlamaRotaryEmbedding(nn.Module):
144
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
145
- super().__init__()
146
-
147
- self.dim = dim
148
- self.max_position_embeddings = max_position_embeddings
149
- self.base = base
150
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
151
- self.register_buffer("inv_freq", inv_freq, persistent=False)
152
-
153
- # Build here to make `torch.jit.trace` work.
154
- self._set_cos_sin_cache(
155
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
156
- )
157
-
158
- def _set_cos_sin_cache(self, seq_len, device, dtype):
159
- self.max_seq_len_cached = seq_len
160
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
161
-
162
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
163
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
164
- emb = torch.cat((freqs, freqs), dim=-1)
165
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
166
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
167
-
168
- def forward(self, x, seq_len=None):
169
- # x: [bs, num_attention_heads, seq_len, head_size]
170
- if seq_len > self.max_seq_len_cached:
171
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
172
-
173
- return (
174
- self.cos_cached[:seq_len].to(dtype=x.dtype),
175
- self.sin_cached[:seq_len].to(dtype=x.dtype),
176
- )
177
-
178
-
179
- class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
180
- """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
181
-
182
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
183
- self.scaling_factor = scaling_factor
184
- super().__init__(dim, max_position_embeddings, base, device)
185
-
186
- def _set_cos_sin_cache(self, seq_len, device, dtype):
187
- self.max_seq_len_cached = seq_len
188
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
189
- t = t / self.scaling_factor
190
-
191
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
192
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
- emb = torch.cat((freqs, freqs), dim=-1)
194
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
195
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
-
197
-
198
- class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
199
- """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
200
-
201
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
202
- self.scaling_factor = scaling_factor
203
- super().__init__(dim, max_position_embeddings, base, device)
204
-
205
- def _set_cos_sin_cache(self, seq_len, device, dtype):
206
- self.max_seq_len_cached = seq_len
207
-
208
- if seq_len > self.max_position_embeddings:
209
- base = self.base * (
210
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
211
- ) ** (self.dim / (self.dim - 2))
212
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
213
- self.register_buffer("inv_freq", inv_freq, persistent=False)
214
-
215
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
216
-
217
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
218
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
219
- emb = torch.cat((freqs, freqs), dim=-1)
220
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
221
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
222
-
223
-
224
- class LlamaYaRNScaledRotaryEmbedding(torch.nn.Module):
225
- def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, original_max_position_embeddings=2048, extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, finetuned=False, device=None):
226
- super().__init__()
227
-
228
- self.dim = dim
229
- self.max_position_embeddings = max_position_embeddings
230
- self.base = base
231
- self.scale = scale
232
- self.original_max_position_embeddings = original_max_position_embeddings
233
- self.extrapolation_factor = extrapolation_factor
234
- self.attn_factor = attn_factor
235
- self.beta_fast = beta_fast
236
- self.beta_slow = beta_slow
237
-
238
- self.yarn(device)
239
-
240
- # Build here to make `torch.jit.trace` work.
241
- self.max_seq_len_cached = max_position_embeddings
242
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
243
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
244
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
245
- emb = torch.cat((freqs, freqs), dim=-1)
246
- dtype = torch.get_default_dtype()
247
-
248
- self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
249
- self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
250
-
251
- def forward(self, x, seq_len=None):
252
- # x: [bs, num_attention_heads, seq_len, head_size]
253
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
254
- if seq_len > self.max_seq_len_cached:
255
- self.max_seq_len_cached = seq_len
256
-
257
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
258
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
259
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
260
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
261
-
262
- self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
263
- self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
264
- return (
265
- self.cos_cached[:seq_len].to(dtype=x.dtype),
266
- self.sin_cached[:seq_len].to(dtype=x.dtype),
267
- )
268
-
269
- def yarn(self, device):
270
- pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
271
- inv_freq_extrapolation = 1.0 / pos_freqs
272
- inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
273
-
274
- low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
275
- inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
276
- inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
277
-
278
- self.register_buffer("inv_freq", inv_freq, persistent=False)
279
- self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
280
-
281
-
282
- class LlamaDynamicYaRNScaledRotaryEmbedding(torch.nn.Module):
283
- def __init__(self, dim, max_position_embeddings=2048, base=10000, original_max_position_embeddings=2048, extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, finetuned=False, device=None):
284
- super().__init__()
285
-
286
- self.dim = dim
287
- self.max_position_embeddings = max_position_embeddings
288
- self.base = base
289
- self.original_max_position_embeddings = original_max_position_embeddings
290
- self.extrapolation_factor = extrapolation_factor
291
- self.attn_factor = attn_factor
292
- self.beta_fast = beta_fast
293
- self.beta_slow = beta_slow
294
-
295
- if finetuned:
296
- self.yarn(self.max_position_embeddings / self.original_max_position_embeddings, device)
297
- else:
298
- inv_freq = 1.0 / \
299
- (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
300
- self.register_buffer("inv_freq", inv_freq, persistent=False)
301
- self.mscale = 1
302
-
303
- # Build here to make `torch.jit.trace` work.
304
- self.max_seq_len_cached = max_position_embeddings
305
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
306
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
307
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
308
- emb = torch.cat((freqs, freqs), dim=-1)
309
- dtype = torch.get_default_dtype()
310
-
311
- self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
312
- self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
313
-
314
- def forward(self, x, seq_len=None):
315
- # x: [bs, num_attention_heads, seq_len, head_size]
316
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
317
- if seq_len > self.max_seq_len_cached:
318
- self.max_seq_len_cached = seq_len
319
-
320
- self.yarn(seq_len / self.max_position_embeddings, x.device)
321
-
322
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
323
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
324
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
325
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
326
-
327
- self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False)
328
- self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False)
329
- return (
330
- self.cos_cached[:seq_len].to(dtype=x.dtype),
331
- self.sin_cached[:seq_len].to(dtype=x.dtype),
332
- )
333
-
334
- def yarn(self, scale, device):
335
- pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
336
- inv_freq_extrapolation = 1.0 / pos_freqs
337
- inv_freq_interpolation = 1.0 / (scale * pos_freqs)
338
-
339
- low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
340
- inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
341
- inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
342
-
343
- self.register_buffer("inv_freq", inv_freq, persistent=False)
344
- self.mscale = float(_yarn_get_mscale(scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
345
-
346
-
347
- def rotate_half(x):
348
- """Rotates half the hidden dims of the input."""
349
- x1 = x[..., : x.shape[-1] // 2]
350
- x2 = x[..., x.shape[-1] // 2 :]
351
- return torch.cat((-x2, x1), dim=-1)
352
-
353
-
354
- # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
355
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
356
- cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
357
- sin = sin[position_ids].unsqueeze(1)
358
- q_embed = (q * cos) + (rotate_half(q) * sin)
359
- k_embed = (k * cos) + (rotate_half(k) * sin)
360
- return q_embed, k_embed
361
-
362
-
363
- class LlamaMLP(nn.Module):
364
- def __init__(self, config):
365
- super().__init__()
366
- self.config = config
367
- self.hidden_size = config.hidden_size
368
- self.intermediate_size = config.intermediate_size
369
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
370
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
371
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
372
- self.act_fn = ACT2FN[config.hidden_act]
373
-
374
- def forward(self, x):
375
- if self.config.pretraining_tp > 1:
376
- slice = self.intermediate_size // self.config.pretraining_tp
377
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
378
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
379
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
380
-
381
- gate_proj = torch.cat(
382
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
383
- )
384
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
385
-
386
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
387
- down_proj = [
388
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
389
- ]
390
- down_proj = sum(down_proj)
391
- else:
392
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
393
-
394
- return down_proj
395
-
396
-
397
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
398
- """
399
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
400
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
401
- """
402
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
403
- if n_rep == 1:
404
- return hidden_states
405
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
406
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
407
-
408
-
409
- class LlamaAttention(nn.Module):
410
- """Multi-headed attention from 'Attention Is All You Need' paper"""
411
-
412
- def __init__(self, config: LlamaConfig):
413
- super().__init__()
414
- self.config = config
415
- self.hidden_size = config.hidden_size
416
- self.num_heads = config.num_attention_heads
417
- self.head_dim = self.hidden_size // self.num_heads
418
- self.num_key_value_heads = config.num_key_value_heads
419
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
420
- self.max_position_embeddings = config.max_position_embeddings
421
- self.rope_theta = config.rope_theta
422
-
423
- if (self.head_dim * self.num_heads) != self.hidden_size:
424
- raise ValueError(
425
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
426
- f" and `num_heads`: {self.num_heads})."
427
- )
428
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
429
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
430
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
431
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
432
- self._init_rope()
433
-
434
- def _init_rope(self):
435
- if self.config.rope_scaling is None:
436
- self.rotary_emb = LlamaRotaryEmbedding(
437
- self.head_dim,
438
- max_position_embeddings=self.max_position_embeddings,
439
- base=self.rope_theta,
440
- )
441
- else:
442
- scaling_type = self.config.rope_scaling["type"]
443
- scaling_factor = self.config.rope_scaling["factor"]
444
- if scaling_type == "linear":
445
- self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
446
- self.head_dim,
447
- max_position_embeddings=self.max_position_embeddings,
448
- scaling_factor=scaling_factor,
449
- base=self.rope_theta,
450
- )
451
- elif scaling_type == "dynamic":
452
- self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
453
- self.head_dim,
454
- max_position_embeddings=self.max_position_embeddings,
455
- scaling_factor=scaling_factor,
456
- base=self.rope_theta,
457
- )
458
- elif scaling_type == "yarn":
459
- original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
460
- self.rotary_emb = LlamaYaRNScaledRotaryEmbedding(
461
- self.head_dim,
462
- max_position_embeddings=self.max_position_embeddings,
463
- scale=scaling_factor,
464
- original_max_position_embeddings=original_max_position_embeddings
465
- )
466
- elif scaling_type == "dynamic-yarn":
467
- original_max_position_embeddings = self.config.rope_scaling["original_max_position_embeddings"]
468
- self.rotary_emb = LlamaDynamicYaRNScaledRotaryEmbedding(
469
- self.head_dim,
470
- max_position_embeddings=self.max_position_embeddings,
471
- original_max_position_embeddings=original_max_position_embeddings
472
- )
473
- else:
474
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
475
-
476
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
477
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
478
-
479
- def forward(
480
- self,
481
- hidden_states: torch.Tensor,
482
- attention_mask: Optional[torch.Tensor] = None,
483
- position_ids: Optional[torch.LongTensor] = None,
484
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
485
- output_attentions: bool = False,
486
- use_cache: bool = False,
487
- padding_mask: Optional[torch.LongTensor] = None,
488
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
489
- bsz, q_len, _ = hidden_states.size()
490
-
491
- if self.config.pretraining_tp > 1:
492
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
493
- query_slices = self.q_proj.weight.split(
494
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
495
- )
496
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
497
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
498
-
499
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
500
- query_states = torch.cat(query_states, dim=-1)
501
-
502
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
503
- key_states = torch.cat(key_states, dim=-1)
504
-
505
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
506
- value_states = torch.cat(value_states, dim=-1)
507
-
508
- else:
509
- query_states = self.q_proj(hidden_states)
510
- key_states = self.k_proj(hidden_states)
511
- value_states = self.v_proj(hidden_states)
512
-
513
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
514
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
515
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
516
-
517
- kv_seq_len = key_states.shape[-2]
518
- if past_key_value is not None:
519
- kv_seq_len += past_key_value[0].shape[-2]
520
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
521
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
522
-
523
- if past_key_value is not None:
524
- # reuse k, v, self_attention
525
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
526
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
527
-
528
- past_key_value = (key_states, value_states) if use_cache else None
529
-
530
- key_states = repeat_kv(key_states, self.num_key_value_groups)
531
- value_states = repeat_kv(value_states, self.num_key_value_groups)
532
-
533
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
534
-
535
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
536
- raise ValueError(
537
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
538
- f" {attn_weights.size()}"
539
- )
540
-
541
- if attention_mask is not None:
542
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
543
- raise ValueError(
544
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
545
- )
546
- attn_weights = attn_weights + attention_mask
547
-
548
- # upcast attention to fp32
549
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
550
- attn_output = torch.matmul(attn_weights, value_states)
551
-
552
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
553
- raise ValueError(
554
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
555
- f" {attn_output.size()}"
556
- )
557
-
558
- attn_output = attn_output.transpose(1, 2).contiguous()
559
-
560
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
561
-
562
- if self.config.pretraining_tp > 1:
563
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
564
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
565
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
566
- else:
567
- attn_output = self.o_proj(attn_output)
568
-
569
- if not output_attentions:
570
- attn_weights = None
571
-
572
- return attn_output, attn_weights, past_key_value
573
-
574
-
575
- class LlamaFlashAttention2(LlamaAttention):
576
- """
577
- Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
578
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
579
- flash attention and deal with padding tokens in case the input contains any of them.
580
- """
581
-
582
- def forward(
583
- self,
584
- hidden_states: torch.Tensor,
585
- attention_mask: Optional[torch.Tensor] = None,
586
- position_ids: Optional[torch.LongTensor] = None,
587
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
588
- output_attentions: bool = False,
589
- use_cache: bool = False,
590
- padding_mask: Optional[torch.LongTensor] = None,
591
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
592
- # LlamaFlashAttention2 attention does not support output_attentions
593
- output_attentions = False
594
-
595
- bsz, q_len, _ = hidden_states.size()
596
-
597
- query_states = self.q_proj(hidden_states)
598
- key_states = self.k_proj(hidden_states)
599
- value_states = self.v_proj(hidden_states)
600
-
601
- # Flash attention requires the input to have the shape
602
- # batch_size x seq_length x num_heads x head_dim
603
- # therefore we just need to keep the original shape
604
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
605
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
606
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
607
-
608
- kv_seq_len = key_states.shape[-2]
609
- if past_key_value is not None:
610
- kv_seq_len += past_key_value[0].shape[-2]
611
-
612
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
613
-
614
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
615
-
616
- if past_key_value is not None:
617
- # reuse k, v, self_attention
618
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
619
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
620
-
621
- past_key_value = (key_states, value_states) if use_cache else None
622
-
623
- query_states = query_states.transpose(1, 2)
624
- key_states = key_states.transpose(1, 2)
625
- value_states = value_states.transpose(1, 2)
626
-
627
- # TODO: llama does not have dropout in the config??
628
- # It is recommended to use dropout with FA according to the docs
629
- # when training.
630
- dropout_rate = 0.0 # if not self.training else self.attn_dropout
631
-
632
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons;
633
- # therefore the input hidden states get silently cast to float32. Hence, we need to
634
- # cast them back to float16 just to be sure everything works as expected.
635
- # This might slow down training & inference, so it is recommended not to cast the LayerNorms
636
- # to fp32. (LlamaRMSNorm handles it correctly)
637
- input_dtype = query_states.dtype
638
- if input_dtype == torch.float32:
639
- logger.warning_once(
640
- "The input hidden states seems to be silently casted in float32, this might be related to"
641
- " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
642
- " float16."
643
- )
644
-
645
- query_states = query_states.to(torch.float16)
646
- key_states = key_states.to(torch.float16)
647
- value_states = value_states.to(torch.float16)
648
-
649
- attn_output = self._flash_attention_forward(
650
- query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
651
- )
652
-
653
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
654
- attn_output = self.o_proj(attn_output)
655
-
656
- if not output_attentions:
657
- attn_weights = None
658
-
659
- return attn_output, attn_weights, past_key_value
660
-
661
- def _flash_attention_forward(
662
- self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
663
- ):
664
- """
665
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
666
- first unpads the input, then computes the attention scores and pads the final attention scores.
667
-
668
- Args:
669
- query_states (`torch.Tensor`):
670
- Input query states to be passed to Flash Attention API
671
- key_states (`torch.Tensor`):
672
- Input key states to be passed to Flash Attention API
673
- value_states (`torch.Tensor`):
674
- Input value states to be passed to Flash Attention API
675
- padding_mask (`torch.Tensor`):
676
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
677
- position of padding tokens and 1 for the position of non-padding tokens.
678
- dropout (`float`, *optional*):
679
- Attention dropout
680
- softmax_scale (`float`, *optional*):
681
- The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
682
- """
683
- # Contains at least one padding token in the sequence
684
- if padding_mask is not None:
685
- batch_size = query_states.shape[0]
686
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
687
- query_states, key_states, value_states, padding_mask, query_length
688
- )
689
-
690
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
691
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
692
-
693
- attn_output_unpad = flash_attn_varlen_func(
694
- query_states,
695
- key_states,
696
- value_states,
697
- cu_seqlens_q=cu_seqlens_q,
698
- cu_seqlens_k=cu_seqlens_k,
699
- max_seqlen_q=max_seqlen_in_batch_q,
700
- max_seqlen_k=max_seqlen_in_batch_k,
701
- dropout_p=dropout,
702
- softmax_scale=softmax_scale,
703
- causal=True,
704
- )
705
-
706
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
707
- else:
708
- attn_output = flash_attn_func(
709
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
710
- )
711
-
712
- return attn_output
713
-
714
- def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
715
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
716
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
717
-
718
- key_layer = index_first_axis(
719
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
720
- )
721
- value_layer = index_first_axis(
722
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
723
- )
724
- if query_length == kv_seq_len:
725
- query_layer = index_first_axis(
726
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
727
- )
728
- cu_seqlens_q = cu_seqlens_k
729
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
730
- indices_q = indices_k
731
- elif query_length == 1:
732
- max_seqlen_in_batch_q = 1
733
- cu_seqlens_q = torch.arange(
734
- batch_size + 1, dtype=torch.int32, device=query_layer.device
735
- ) # There is a memcpy here, that is very bad.
736
- indices_q = cu_seqlens_q[:-1]
737
- query_layer = query_layer.squeeze(1)
738
- else:
739
- # The -q_len: slice assumes left padding.
740
- padding_mask = padding_mask[:, -query_length:]
741
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
742
-
743
- return (
744
- query_layer,
745
- key_layer,
746
- value_layer,
747
- indices_q,
748
- (cu_seqlens_q, cu_seqlens_k),
749
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
750
- )
751
-
752
-
753
- class LlamaDecoderLayer(nn.Module):
754
- def __init__(self, config: LlamaConfig):
755
- super().__init__()
756
- self.hidden_size = config.hidden_size
757
- self.self_attn = (
758
- LlamaAttention(config=config)
759
- if not getattr(config, "_flash_attn_2_enabled", False)
760
- else LlamaFlashAttention2(config=config)
761
- )
762
- self.mlp = LlamaMLP(config)
763
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
764
- self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
765
-
766
- def forward(
767
- self,
768
- hidden_states: torch.Tensor,
769
- attention_mask: Optional[torch.Tensor] = None,
770
- position_ids: Optional[torch.LongTensor] = None,
771
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
772
- output_attentions: Optional[bool] = False,
773
- use_cache: Optional[bool] = False,
774
- padding_mask: Optional[torch.LongTensor] = None,
775
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
776
- """
777
- Args:
778
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
779
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
780
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
781
- output_attentions (`bool`, *optional*):
782
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
783
- returned tensors for more detail.
784
- use_cache (`bool`, *optional*):
785
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
786
- (see `past_key_values`).
787
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
788
- """
789
-
790
- residual = hidden_states
791
-
792
- hidden_states = self.input_layernorm(hidden_states)
793
-
794
- # Self Attention
795
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
796
- hidden_states=hidden_states,
797
- attention_mask=attention_mask,
798
- position_ids=position_ids,
799
- past_key_value=past_key_value,
800
- output_attentions=output_attentions,
801
- use_cache=use_cache,
802
- padding_mask=padding_mask,
803
- )
804
- hidden_states = residual + hidden_states
805
-
806
- # Fully Connected
807
- residual = hidden_states
808
- hidden_states = self.post_attention_layernorm(hidden_states)
809
- hidden_states = self.mlp(hidden_states)
810
- hidden_states = residual + hidden_states
811
-
812
- outputs = (hidden_states,)
813
-
814
- if output_attentions:
815
- outputs += (self_attn_weights,)
816
-
817
- if use_cache:
818
- outputs += (present_key_value,)
819
-
820
- return outputs
821
-
822
-
823
- LLAMA_START_DOCSTRING = r"""
824
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
825
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
826
- etc.)
827
-
828
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
829
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
830
- and behavior.
831
-
832
- Parameters:
833
- config ([`LlamaConfig`]):
834
- Model configuration class with all the parameters of the model. Initializing with a config file does not
835
- load the weights associated with the model, only the configuration. Check out the
836
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
837
- """
838
-
839
-
840
- @add_start_docstrings(
841
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
842
- LLAMA_START_DOCSTRING,
843
- )
844
- class LlamaPreTrainedModel(PreTrainedModel):
845
- config_class = LlamaConfig
846
- base_model_prefix = "model"
847
- supports_gradient_checkpointing = True
848
- _no_split_modules = ["LlamaDecoderLayer"]
849
- _skip_keys_device_placement = "past_key_values"
850
- _supports_flash_attn_2 = True
851
-
852
- def _init_weights(self, module):
853
- std = self.config.initializer_range
854
- if isinstance(module, nn.Linear):
855
- module.weight.data.normal_(mean=0.0, std=std)
856
- if module.bias is not None:
857
- module.bias.data.zero_()
858
- elif isinstance(module, nn.Embedding):
859
- module.weight.data.normal_(mean=0.0, std=std)
860
- if module.padding_idx is not None:
861
- module.weight.data[module.padding_idx].zero_()
862
-
863
- def _set_gradient_checkpointing(self, module, value=False):
864
- if isinstance(module, LlamaModel):
865
- module.gradient_checkpointing = value
866
-
867
-
868
- LLAMA_INPUTS_DOCSTRING = r"""
869
- Args:
870
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
871
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
872
- it.
873
-
874
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
875
- [`PreTrainedTokenizer.__call__`] for details.
876
-
877
- [What are input IDs?](../glossary#input-ids)
878
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
879
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
880
-
881
- - 1 for tokens that are **not masked**,
882
- - 0 for tokens that are **masked**.
883
-
884
- [What are attention masks?](../glossary#attention-mask)
885
-
886
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
887
- [`PreTrainedTokenizer.__call__`] for details.
888
-
889
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
890
- `past_key_values`).
891
-
892
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
893
- and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
894
- information on the default strategy.
895
-
896
- - 1 indicates the head is **not masked**,
897
- - 0 indicates the head is **masked**.
898
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
899
- Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
900
- config.n_positions - 1]`.
901
-
902
- [What are position IDs?](../glossary#position-ids)
903
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
904
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
905
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
906
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
907
-
908
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
909
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
910
-
911
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
912
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
913
- of shape `(batch_size, sequence_length)`.
914
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
915
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
916
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
917
- model's internal embedding lookup matrix.
918
- use_cache (`bool`, *optional*):
919
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
920
- `past_key_values`).
921
- output_attentions (`bool`, *optional*):
922
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
923
- tensors for more detail.
924
- output_hidden_states (`bool`, *optional*):
925
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
926
- more detail.
927
- return_dict (`bool`, *optional*):
928
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
929
- """
930
-
931
-
932
- @add_start_docstrings(
933
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
934
- LLAMA_START_DOCSTRING,
935
- )
936
- class LlamaModel(LlamaPreTrainedModel):
937
- """
938
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
939
-
940
- Args:
941
- config: LlamaConfig
942
- """
943
-
944
- def __init__(self, config: LlamaConfig):
945
- super().__init__(config)
946
- self.padding_idx = config.pad_token_id
947
- self.vocab_size = config.vocab_size
948
-
949
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
950
- self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
951
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
952
-
953
- self.gradient_checkpointing = False
954
- # Initialize weights and apply final processing
955
- self.post_init()
956
-
957
- def get_input_embeddings(self):
958
- return self.embed_tokens
959
-
960
- def set_input_embeddings(self, value):
961
- self.embed_tokens = value
962
-
963
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
964
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
965
- # create causal mask
966
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
967
- combined_attention_mask = None
968
- if input_shape[-1] > 1:
969
- combined_attention_mask = _make_causal_mask(
970
- input_shape,
971
- inputs_embeds.dtype,
972
- device=inputs_embeds.device,
973
- past_key_values_length=past_key_values_length,
974
- )
975
-
976
- if attention_mask is not None:
977
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
978
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
979
- inputs_embeds.device
980
- )
981
- combined_attention_mask = (
982
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
983
- )
984
-
985
- return combined_attention_mask
986
-
987
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
988
- def forward(
989
- self,
990
- input_ids: torch.LongTensor = None,
991
- attention_mask: Optional[torch.Tensor] = None,
992
- position_ids: Optional[torch.LongTensor] = None,
993
- past_key_values: Optional[List[torch.FloatTensor]] = None,
994
- inputs_embeds: Optional[torch.FloatTensor] = None,
995
- use_cache: Optional[bool] = None,
996
- output_attentions: Optional[bool] = None,
997
- output_hidden_states: Optional[bool] = None,
998
- return_dict: Optional[bool] = None,
999
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1000
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1001
- output_hidden_states = (
1002
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1003
- )
1004
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1005
-
1006
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
-
1008
- # retrieve input_ids and inputs_embeds
1009
- if input_ids is not None and inputs_embeds is not None:
1010
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1011
- elif input_ids is not None:
1012
- batch_size, seq_length = input_ids.shape
1013
- elif inputs_embeds is not None:
1014
- batch_size, seq_length, _ = inputs_embeds.shape
1015
- else:
1016
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1017
-
1018
- seq_length_with_past = seq_length
1019
- past_key_values_length = 0
1020
-
1021
- if past_key_values is not None:
1022
- past_key_values_length = past_key_values[0][0].shape[2]
1023
- seq_length_with_past = seq_length_with_past + past_key_values_length
1024
-
1025
- if position_ids is None:
1026
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1027
- position_ids = torch.arange(
1028
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1029
- )
1030
- position_ids = position_ids.unsqueeze(0)
1031
-
1032
- if inputs_embeds is None:
1033
- inputs_embeds = self.embed_tokens(input_ids)
1034
- # embed positions
1035
- if attention_mask is None:
1036
- attention_mask = torch.ones(
1037
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1038
- )
1039
- padding_mask = None
1040
- else:
1041
- if 0 in attention_mask:
1042
- padding_mask = attention_mask
1043
- else:
1044
- padding_mask = None
1045
-
1046
- attention_mask = self._prepare_decoder_attention_mask(
1047
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1048
- )
1049
-
1050
- hidden_states = inputs_embeds
1051
-
1052
- if self.gradient_checkpointing and self.training:
1053
- if use_cache:
1054
- logger.warning_once(
1055
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1056
- )
1057
- use_cache = False
1058
-
1059
- # decoder layers
1060
- all_hidden_states = () if output_hidden_states else None
1061
- all_self_attns = () if output_attentions else None
1062
- next_decoder_cache = () if use_cache else None
1063
-
1064
- for idx, decoder_layer in enumerate(self.layers):
1065
- if output_hidden_states:
1066
- all_hidden_states += (hidden_states,)
1067
-
1068
- past_key_value = past_key_values[idx] if past_key_values is not None else None
1069
-
1070
- if self.gradient_checkpointing and self.training:
1071
-
1072
- def create_custom_forward(module):
1073
- def custom_forward(*inputs):
1074
- # None for past_key_value
1075
- return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
1076
-
1077
- return custom_forward
1078
-
1079
- layer_outputs = torch.utils.checkpoint.checkpoint(
1080
- create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
1081
- )
1082
- else:
1083
- layer_outputs = decoder_layer(
1084
- hidden_states,
1085
- attention_mask=attention_mask,
1086
- position_ids=position_ids,
1087
- past_key_value=past_key_value,
1088
- output_attentions=output_attentions,
1089
- use_cache=use_cache,
1090
- padding_mask=padding_mask,
1091
- )
1092
-
1093
- hidden_states = layer_outputs[0]
1094
-
1095
- if use_cache:
1096
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1097
-
1098
- if output_attentions:
1099
- all_self_attns += (layer_outputs[1],)
1100
-
1101
- hidden_states = self.norm(hidden_states)
1102
-
1103
- # add hidden states from the last decoder layer
1104
- if output_hidden_states:
1105
- all_hidden_states += (hidden_states,)
1106
-
1107
- next_cache = next_decoder_cache if use_cache else None
1108
- if not return_dict:
1109
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1110
- return BaseModelOutputWithPast(
1111
- last_hidden_state=hidden_states,
1112
- past_key_values=next_cache,
1113
- hidden_states=all_hidden_states,
1114
- attentions=all_self_attns,
1115
- )
1116
-
1117
-
1118
- class LlamaForCausalLM(LlamaPreTrainedModel):
1119
- _tied_weights_keys = ["lm_head.weight"]
1120
-
1121
- def __init__(self, config):
1122
- super().__init__(config)
1123
- self.model = LlamaModel(config)
1124
- self.vocab_size = config.vocab_size
1125
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1126
-
1127
- # Initialize weights and apply final processing
1128
- self.post_init()
1129
-
1130
- def get_input_embeddings(self):
1131
- return self.model.embed_tokens
1132
-
1133
- def set_input_embeddings(self, value):
1134
- self.model.embed_tokens = value
1135
-
1136
- def get_output_embeddings(self):
1137
- return self.lm_head
1138
-
1139
- def set_output_embeddings(self, new_embeddings):
1140
- self.lm_head = new_embeddings
1141
-
1142
- def set_decoder(self, decoder):
1143
- self.model = decoder
1144
-
1145
- def get_decoder(self):
1146
- return self.model
1147
-
1148
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1149
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1150
- def forward(
1151
- self,
1152
- input_ids: torch.LongTensor = None,
1153
- attention_mask: Optional[torch.Tensor] = None,
1154
- position_ids: Optional[torch.LongTensor] = None,
1155
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1156
- inputs_embeds: Optional[torch.FloatTensor] = None,
1157
- labels: Optional[torch.LongTensor] = None,
1158
- use_cache: Optional[bool] = None,
1159
- output_attentions: Optional[bool] = None,
1160
- output_hidden_states: Optional[bool] = None,
1161
- return_dict: Optional[bool] = None,
1162
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1163
- r"""
1164
- Args:
1165
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1166
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1167
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1168
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1169
-
1170
- Returns:
1171
-
1172
- Example:
1173
-
1174
- ```python
1175
- >>> from transformers import AutoTokenizer, LlamaForCausalLM
1176
-
1177
- >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1178
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1179
-
1180
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1181
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1182
-
1183
- >>> # Generate
1184
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1185
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1186
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1187
- ```"""
1188
-
1189
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1190
- output_hidden_states = (
1191
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1192
- )
1193
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1194
-
1195
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1196
- outputs = self.model(
1197
- input_ids=input_ids,
1198
- attention_mask=attention_mask,
1199
- position_ids=position_ids,
1200
- past_key_values=past_key_values,
1201
- inputs_embeds=inputs_embeds,
1202
- use_cache=use_cache,
1203
- output_attentions=output_attentions,
1204
- output_hidden_states=output_hidden_states,
1205
- return_dict=return_dict,
1206
- )
1207
-
1208
- hidden_states = outputs[0]
1209
- if self.config.pretraining_tp > 1:
1210
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1211
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1212
- logits = torch.cat(logits, dim=-1)
1213
- else:
1214
- logits = self.lm_head(hidden_states)
1215
- logits = logits.float()
1216
-
1217
- loss = None
1218
- if labels is not None:
1219
- # Shift so that tokens < n predict n
1220
- shift_logits = logits[..., :-1, :].contiguous()
1221
- shift_labels = labels[..., 1:].contiguous()
1222
- # Flatten the tokens
1223
- loss_fct = CrossEntropyLoss()
1224
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1225
- shift_labels = shift_labels.view(-1)
1226
- # Enable model parallelism
1227
- shift_labels = shift_labels.to(shift_logits.device)
1228
- loss = loss_fct(shift_logits, shift_labels)
1229
-
1230
- if not return_dict:
1231
- output = (logits,) + outputs[1:]
1232
- return (loss,) + output if loss is not None else output
1233
-
1234
- return CausalLMOutputWithPast(
1235
- loss=loss,
1236
- logits=logits,
1237
- past_key_values=outputs.past_key_values,
1238
- hidden_states=outputs.hidden_states,
1239
- attentions=outputs.attentions,
1240
- )
1241
-
1242
- def prepare_inputs_for_generation(
1243
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1244
- ):
1245
- if past_key_values is not None:
1246
- past_length = past_key_values[0][0].shape[2]
1247
-
1248
- # Some generation methods already pass only the last input ID
1249
- if input_ids.shape[1] > past_length:
1250
- remove_prefix_length = past_length
1251
- else:
1252
- # Default to old behavior: keep only final ID
1253
- remove_prefix_length = input_ids.shape[1] - 1
1254
-
1255
- input_ids = input_ids[:, remove_prefix_length:]
1256
-
1257
- position_ids = kwargs.get("position_ids", None)
1258
- if attention_mask is not None and position_ids is None:
1259
- # create position_ids on the fly for batch generation
1260
- position_ids = attention_mask.long().cumsum(-1) - 1
1261
- position_ids.masked_fill_(attention_mask == 0, 1)
1262
- if past_key_values:
1263
- position_ids = position_ids[:, -input_ids.shape[1] :]
1264
-
1265
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1266
- if inputs_embeds is not None and past_key_values is None:
1267
- model_inputs = {"inputs_embeds": inputs_embeds}
1268
- else:
1269
- model_inputs = {"input_ids": input_ids}
1270
-
1271
- model_inputs.update(
1272
- {
1273
- "position_ids": position_ids,
1274
- "past_key_values": past_key_values,
1275
- "use_cache": kwargs.get("use_cache"),
1276
- "attention_mask": attention_mask,
1277
- }
1278
- )
1279
- return model_inputs
1280
-
1281
- @staticmethod
1282
- def _reorder_cache(past_key_values, beam_idx):
1283
- reordered_past = ()
1284
- for layer_past in past_key_values:
1285
- reordered_past += (
1286
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1287
- )
1288
- return reordered_past
1289
-
1290
-
1291
- @add_start_docstrings(
1292
- """
1293
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
1294
-
1295
- [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1296
- (e.g. GPT-2) do.
1297
-
1298
- Since it does classification on the last token, it needs to know the position of the last token. If a
1299
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1300
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1301
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1302
- each row of the batch).
1303
- """,
1304
- LLAMA_START_DOCSTRING,
1305
- )
1306
- class LlamaForSequenceClassification(LlamaPreTrainedModel):
1307
- def __init__(self, config):
1308
- super().__init__(config)
1309
- self.num_labels = config.num_labels
1310
- self.model = LlamaModel(config)
1311
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1312
-
1313
- # Initialize weights and apply final processing
1314
- self.post_init()
1315
-
1316
- def get_input_embeddings(self):
1317
- return self.model.embed_tokens
1318
-
1319
- def set_input_embeddings(self, value):
1320
- self.model.embed_tokens = value
1321
-
1322
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1323
- def forward(
1324
- self,
1325
- input_ids: torch.LongTensor = None,
1326
- attention_mask: Optional[torch.Tensor] = None,
1327
- position_ids: Optional[torch.LongTensor] = None,
1328
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1329
- inputs_embeds: Optional[torch.FloatTensor] = None,
1330
- labels: Optional[torch.LongTensor] = None,
1331
- use_cache: Optional[bool] = None,
1332
- output_attentions: Optional[bool] = None,
1333
- output_hidden_states: Optional[bool] = None,
1334
- return_dict: Optional[bool] = None,
1335
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1336
- r"""
1337
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1338
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1339
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1340
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1341
- """
1342
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1343
-
1344
- transformer_outputs = self.model(
1345
- input_ids,
1346
- attention_mask=attention_mask,
1347
- position_ids=position_ids,
1348
- past_key_values=past_key_values,
1349
- inputs_embeds=inputs_embeds,
1350
- use_cache=use_cache,
1351
- output_attentions=output_attentions,
1352
- output_hidden_states=output_hidden_states,
1353
- return_dict=return_dict,
1354
- )
1355
- hidden_states = transformer_outputs[0]
1356
- logits = self.score(hidden_states)
1357
-
1358
- if input_ids is not None:
1359
- batch_size = input_ids.shape[0]
1360
- else:
1361
- batch_size = inputs_embeds.shape[0]
1362
-
1363
- if self.config.pad_token_id is None and batch_size != 1:
1364
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1365
- if self.config.pad_token_id is None:
1366
- sequence_lengths = -1
1367
- else:
1368
- if input_ids is not None:
1369
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1370
- logits.device
1371
- )
1372
- else:
1373
- sequence_lengths = -1
1374
-
1375
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1376
-
1377
- loss = None
1378
- if labels is not None:
1379
- labels = labels.to(logits.device)
1380
- if self.config.problem_type is None:
1381
- if self.num_labels == 1:
1382
- self.config.problem_type = "regression"
1383
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1384
- self.config.problem_type = "single_label_classification"
1385
- else:
1386
- self.config.problem_type = "multi_label_classification"
1387
-
1388
- if self.config.problem_type == "regression":
1389
- loss_fct = MSELoss()
1390
- if self.num_labels == 1:
1391
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1392
- else:
1393
- loss = loss_fct(pooled_logits, labels)
1394
- elif self.config.problem_type == "single_label_classification":
1395
- loss_fct = CrossEntropyLoss()
1396
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1397
- elif self.config.problem_type == "multi_label_classification":
1398
- loss_fct = BCEWithLogitsLoss()
1399
- loss = loss_fct(pooled_logits, labels)
1400
- if not return_dict:
1401
- output = (pooled_logits,) + transformer_outputs[1:]
1402
- return ((loss,) + output) if loss is not None else output
1403
-
1404
- return SequenceClassifierOutputWithPast(
1405
- loss=loss,
1406
- logits=pooled_logits,
1407
- past_key_values=transformer_outputs.past_key_values,
1408
- hidden_states=transformer_outputs.hidden_states,
1409
- attentions=transformer_outputs.attentions,
1410
- )