Andreas Köpf committed
Commit 0d0ff25
1 Parent(s): adaadf5

add falcon landmark code (incomplete)

code/configuration_RW.py ADDED
@@ -0,0 +1,86 @@
+ # coding=utf-8
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Bloom configuration"""
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RWConfig(PretrainedConfig):
+     model_type = "RefinedWebModel"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {
+         "num_hidden_layers": "n_layer",
+         "num_attention_heads": "n_head",
+     }
+
+     def __init__(
+         self,
+         vocab_size=250880,
+         hidden_size=64,
+         n_layer=2,
+         n_head=8,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         use_cache=True,
+         bos_token_id=1,
+         eos_token_id=2,
+         apply_residual_connection_post_layernorm=False,
+         hidden_dropout=0.0,
+         attention_dropout=0.0,
+         multi_query=False,
+         alibi=False,
+         bias=False,
+         parallel_attn=False,
+         mem_id=None,
+         mem_freq=50,
+         train_context_length=512,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         # Backward compatibility with n_embed kwarg
+         n_embed = kwargs.pop("n_embed", None)
+         self.hidden_size = hidden_size if n_embed is None else n_embed
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.use_cache = use_cache
+         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.multi_query = multi_query
+         self.alibi = alibi
+         self.bias = bias
+         self.parallel_attn = parallel_attn
+
+         self.mem_id = mem_id
+         self.mem_freq = mem_freq
+         self.train_context_length = train_context_length
+
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+     @property
+     def head_dim(self):
+         return self.hidden_size // self.n_head
+
+     @property
+     def rotary(self):
+         return not self.alibi
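
For orientation, a minimal sketch of how this configuration could be instantiated with the landmark-specific fields (`mem_id`, `mem_freq`, `train_context_length`). The Falcon-7B-like sizes and the landmark token id below are assumptions for illustration, not values taken from this commit:

```python
# Usage sketch, assuming configuration_RW.py is importable and that a dedicated
# landmark ("memory") token id has been reserved in the tokenizer.
from configuration_RW import RWConfig

config = RWConfig(
    vocab_size=65024,          # assumed Falcon-7B-like sizes
    hidden_size=4544,
    n_layer=32,
    n_head=71,
    multi_query=True,
    parallel_attn=True,
    bias=False,
    mem_id=65023,              # hypothetical landmark token id
    mem_freq=50,               # one landmark every 50 regular tokens
    train_context_length=512,  # context length used during fine-tuning
)

print(config.head_dim)  # 4544 // 71 == 64
print(config.rotary)    # True, since alibi defaults to False
```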
code/install_deps.sh ADDED
@@ -0,0 +1,4 @@
+ #!/bin/bash
+
+ pip install -r requirements.txt
+ pip install "git+https://github.com/openai/triton.git#subdirectory=python"
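
The from-source Triton install appears to be there for the fused landmark attention kernel that llama_mem.py imports from `ltriton.flash_landmark_attention`. A quick sanity check after running the script, assuming the repository's `code/` directory is on `PYTHONPATH`:

```python
# Sanity check after install_deps.sh (assumption: code/ is on PYTHONPATH so the
# ltriton package from this repository resolves).
import triton
from ltriton.flash_landmark_attention import fused_landmark_attention

print("triton version:", triton.__version__)
print("fused kernel available:", callable(fused_landmark_attention))
```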
code/llama_landmark_config.py ADDED
@@ -0,0 +1,16 @@
+ from transformers.models.llama.configuration_llama import LlamaConfig
+
+ class LlamaLandmarkConfig(LlamaConfig):
+     model_type = "llama_with_landmark"
+
+     def __init__(
+         self,
+         mem_id=32001,
+         mem_freq=50,
+         train_context_length=512,
+         **kwargs,
+     ):
+         self.mem_id = mem_id
+         self.mem_freq = mem_freq
+         self.train_context_length = train_context_length
+         super().__init__(**kwargs)
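
A minimal sketch of how this config might be wired up with the `LlamaForCausalLM` defined in llama_mem.py below. The checkpoint path, the `<landmark>` special token, and the call order are assumptions for illustration:

```python
# Loading sketch (assumptions: a LLaMA checkpoint path, a "<landmark>" token added
# so that its id becomes mem_id, and llama_mem.py importable from code/).
from transformers import AutoTokenizer
from llama_landmark_config import LlamaLandmarkConfig
from llama_mem import LlamaForCausalLM

checkpoint = "path/to/llama-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.add_special_tokens({"additional_special_tokens": ["<landmark>"]})

config = LlamaLandmarkConfig.from_pretrained(
    checkpoint,
    mem_id=tokenizer.convert_tokens_to_ids("<landmark>"),
    mem_freq=50,
    train_context_length=512,
)
model = LlamaForCausalLM.from_pretrained(checkpoint, config=config)
model.resize_token_embeddings(len(tokenizer))
model.set_mem_id(config.mem_id)  # helper defined on LlamaForCausalLM in llama_mem.py
```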
code/llama_mem.py ADDED
@@ -0,0 +1,1295 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from llama_landmark_config import LlamaLandmarkConfig
34
+ from ltriton.flash_landmark_attention import fused_landmark_attention
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _CONFIG_FOR_DOC = "LlamaLandmarkConfig"
40
+
41
+
42
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
43
+ def _make_causal_mask(
44
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
45
+ ):
46
+ """
47
+ Make causal mask used for bi-directional self-attention.
48
+ """
49
+ bsz, tgt_len = input_ids_shape
50
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
51
+ mask_cond = torch.arange(mask.size(-1), device=device)
52
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
53
+ mask = mask.to(dtype)
54
+
55
+ if past_key_values_length > 0:
56
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
57
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
58
+
59
+
60
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
61
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
62
+ """
63
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
64
+ """
65
+ bsz, src_len = mask.size()
66
+ tgt_len = tgt_len if tgt_len is not None else src_len
67
+
68
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
69
+
70
+ inverted_mask = 1.0 - expanded_mask
71
+
72
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
73
+
74
+
75
+ class LlamaRMSNorm(nn.Module):
76
+ def __init__(self, hidden_size, eps=1e-6):
77
+ """
78
+ LlamaRMSNorm is equivalent to T5LayerNorm
79
+ """
80
+ super().__init__()
81
+ self.weight = nn.Parameter(torch.ones(hidden_size))
82
+ self.variance_epsilon = eps
83
+
84
+ def forward(self, hidden_states):
85
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
86
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
87
+
88
+ # convert into half-precision if necessary
89
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
90
+ hidden_states = hidden_states.to(self.weight.dtype)
91
+
92
+ return self.weight * hidden_states
93
+
94
+
95
+ class LlamaRotaryEmbedding(torch.nn.Module):
96
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97
+ super().__init__()
98
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
99
+ self.register_buffer("inv_freq", inv_freq)
100
+
101
+ # Build here to make `torch.jit.trace` work.
102
+ self.max_seq_len_cached = max_position_embeddings
103
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
104
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
105
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
106
+ emb = torch.cat((freqs, freqs), dim=-1)
107
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
108
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
109
+
110
+ def forward(self, x, seq_len=None):
111
+ # x: [bs, num_attention_heads, seq_len, head_size]
112
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
113
+ if seq_len > self.max_seq_len_cached:
114
+ self.max_seq_len_cached = seq_len
115
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
116
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
117
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
118
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
119
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
120
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
121
+ return (
122
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
124
+ )
125
+
126
+
127
+ def rotate_half(x):
128
+ """Rotates half the hidden dims of the input."""
129
+ x1 = x[..., : x.shape[-1] // 2]
130
+ x2 = x[..., x.shape[-1] // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+
134
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
135
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
136
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
137
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
138
+ if position_ids.ndim == 2:
139
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
140
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
141
+ else:
142
+ cos = cos[position_ids]
143
+ sin = sin[position_ids]
144
+ if q is None:
145
+ q_embed = None
146
+ else:
147
+ q_embed = (q * cos) + (rotate_half(q) * sin)
148
+ k_embed = (k * cos) + (rotate_half(k) * sin)
149
+ return q_embed, k_embed
150
+
151
+
152
+ class LlamaMLP(nn.Module):
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ intermediate_size: int,
157
+ hidden_act: str,
158
+ ):
159
+ super().__init__()
160
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
161
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
162
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
163
+ self.act_fn = ACT2FN[hidden_act]
164
+
165
+ def forward(self, x):
166
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
167
+
168
+ class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
169
+
170
+ # Note that forward, setup_context, and backward are @staticmethods
171
+ @staticmethod
172
+ def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
173
+ new_shape = list(x.shape)
174
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
175
+ max_by_group = x.new_zeros((*new_shape,))
176
+ max_by_group.scatter_reduce_(src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False)
177
+
178
+ maxes = torch.gather(max_by_group, dim, resp_mem_idx)
179
+ #x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes))
180
+ x_exp = torch.exp((x - maxes).to(torch.float32))
181
+
182
+ cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype)
183
+
184
+ cumsum_by_group.scatter_add_(dim, resp_mem_idx, x_exp, )
185
+ denom = torch.gather(cumsum_by_group, dim, resp_mem_idx)
186
+
187
+ #probs = torch.where(denom < 0.5, 0, x_exp / denom)
188
+ probs = x_exp / denom
189
+
190
+
191
+ ctx.mem_cnt = mem_cnt
192
+ ctx.dim = dim
193
+ ctx.save_for_backward(resp_mem_idx, probs)
194
+
195
+ return probs
196
+
197
+ @staticmethod
198
+ def backward(ctx, grad_probs):
199
+ mem_cnt = ctx.mem_cnt
200
+ dim = ctx.dim
201
+ resp_mem_idx, probs = ctx.saved_tensors
202
+ grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None
203
+
204
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[4]:
205
+ grad_pair = grad_probs * probs
206
+
207
+ new_shape = list(probs.shape)
208
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
209
+ cumsum_by_group = grad_pair.new_zeros((*new_shape,))
210
+ cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair)
211
+
212
+
213
+ if ctx.needs_input_grad[0]:
214
+ grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx)
215
+ grad_x = grad_pair - probs * grad_sum
216
+ assert not ctx.needs_input_grad[1]
217
+ assert not ctx.needs_input_grad[2]
218
+ assert not ctx.needs_input_grad[3]
219
+
220
+ return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx
221
+
222
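+ # Grouped softmax used by landmark attention: every key is bucketed by the
+ # landmark ("memory") token that closes its block, a per-bucket softmax is
+ # computed via LandmarkGroupedSoftmaxFunction, and each block's probabilities
+ # are then rescaled by the attention weight its landmark token received.
+ # Keys in the current (last) section are attended to directly.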
+ def landmark_grouped_softmax(x, dim, is_mem, last_section_mask):
223
+
224
+ last_and_rest_mask = last_section_mask # | mask
225
+
226
+ full_access_mask = is_mem | last_and_rest_mask
227
+
228
+ max_mem_cnt = 64
229
+ mem_group_idx = torch.cumsum(is_mem, dim=dim)
230
+ mem_bucket_id = max_mem_cnt - 1
231
+ resp_mem_idx = torch.where(last_and_rest_mask,
232
+ max_mem_cnt - 1,
233
+ torch.where(is_mem, mem_bucket_id, mem_group_idx))
234
+ probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx)
235
+
236
+ new_shape = list(x.shape)
237
+ new_shape[dim] = max_mem_cnt
238
+ group_prob = probs.new_zeros((*new_shape, ))
239
+ group_prob.scatter_(dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs)
240
+ probs = probs.mul(torch.where(full_access_mask, last_section_mask, torch.gather(group_prob, dim, resp_mem_idx)))
241
+
242
+
243
+ return probs
244
+
245
+ class LlamaAttention(nn.Module):
246
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
247
+
248
+ def __init__(self, config: LlamaLandmarkConfig):
249
+ super().__init__()
250
+ self.config = config
251
+ self.hidden_size = config.hidden_size
252
+ self.num_heads = config.num_attention_heads
253
+ self.head_dim = self.hidden_size // self.num_heads
254
+ self.max_position_embeddings = config.max_position_embeddings
255
+
256
+ if (self.head_dim * self.num_heads) != self.hidden_size:
257
+ raise ValueError(
258
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
259
+ f" and `num_heads`: {self.num_heads})."
260
+ )
261
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
262
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
263
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
264
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
265
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
266
+
267
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
268
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
269
+
270
+ def forward(
271
+ self,
272
+ hidden_states: torch.Tensor,
273
+ attention_mask: Optional[torch.Tensor] = None,
274
+ position_ids: Optional[torch.LongTensor] = None,
275
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
276
+ output_attentions: bool = False,
277
+ use_cache: bool = False,
278
+ is_mem: Optional[torch.Tensor] = None,
279
+ last_section_mask: Optional[torch.Tensor] = None,
280
+ offload_cache_to_cpu: bool = False,
281
+ use_flash: bool = False,
282
+ cache_top_k: Optional[int] = None,
283
+ mem_freq: Optional[int] = None
284
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
285
+ bsz, q_len, _ = hidden_states.size()
286
+
287
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
288
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
289
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
290
+
291
+ kv_seq_len = key_states.shape[-2]
292
+ if past_key_value is not None:
293
+ kv_seq_len += past_key_value[0].shape[-2]
294
+ if len(past_key_value) > 2:
295
+ kv_seq_len += past_key_value[3].shape[2] * past_key_value[3].shape[3]
296
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
297
+ key_states_before_pos = key_states
298
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
299
+ # [bsz, nh, t, hd]
300
+
301
+ attn_prefix = None
302
+ if past_key_value is not None:
303
+ # reuse k, v, self_attention
304
+ if mem_freq is None:
305
+ cache_len = past_key_value[0].shape[2]
306
+ if is_mem is not None:
307
+ if use_flash:
308
+ is_mem = torch.cat((is_mem.new_zeros((1, cache_len)), is_mem), dim=-1)
309
+ else:
310
+ is_mem = torch.cat((is_mem.new_zeros((1, 1, q_len, cache_len)), is_mem), dim=-1)
311
+ last_section_mask = torch.cat((last_section_mask.new_ones((1, 1, q_len, cache_len)), last_section_mask), dim=-1)
312
+
313
+ past_key_states = torch.cat([past_key_value[0], key_states], dim=2)
314
+ past_value_states = torch.cat([past_key_value[1], value_states], dim=2)
315
+ key_states = past_key_states[:, :, -(q_len + cache_len):]
316
+ value_states = past_value_states[:, :, -(q_len + cache_len):]
317
+ expected_att_size = (bsz, self.num_heads, q_len, cache_len + q_len)
318
+ else:
319
+ orig_value_states = value_states
320
+
321
+ incomplete_len = past_key_value[0].shape[2] % (mem_freq + 1)
322
+ full_len = past_key_value[0].shape[2] - incomplete_len
323
+ past_key_mem, past_key_incomplete = torch.split(past_key_value[0], (full_len, incomplete_len), dim=2)
324
+ past_value_mem, past_value_incomplete = torch.split(past_key_value[1], (full_len, incomplete_len), dim=2)
325
+
326
+ if offload_cache_to_cpu:
327
+ past_key_value = (past_key_incomplete, past_value_incomplete, *past_key_value[2:])
328
+
329
+ if incomplete_len > 0:
330
+ assert q_len + incomplete_len <= (mem_freq + 1)
331
+ if use_flash:
332
+ is_mem = torch.cat((is_mem.new_zeros((1, incomplete_len)), is_mem), dim=-1)
333
+ else:
334
+ is_mem = torch.cat((is_mem.new_zeros((1, 1, q_len, incomplete_len)), is_mem), dim=-1)
335
+ last_section_mask = torch.cat((last_section_mask.new_ones((1, 1, q_len, incomplete_len)), last_section_mask), dim=-1)
336
+
337
+ if len(past_key_value) > 2:
338
+ full_len += past_key_value[3].shape[2] * past_key_value[3].shape[3]
339
+ past_key_incomplete_pos = torch.arange(full_len, full_len + incomplete_len, dtype=torch.long, device=position_ids.device).unsqueeze(0)
340
+ _, past_key_incomplete = apply_rotary_pos_emb(None, past_key_incomplete, cos, sin, past_key_incomplete_pos)
341
+ key_states = torch.cat((past_key_incomplete, key_states), dim=2)
342
+ value_states = torch.cat((past_value_incomplete, value_states), dim=2)
343
+
344
+ past_key_mem = past_key_mem.view(bsz, self.num_heads, -1, mem_freq + 1, self.head_dim)
345
+ past_value_mem = past_value_mem.view(bsz, self.num_heads, -1, mem_freq + 1, self.head_dim)
346
+
347
+ if len(past_key_value) > 2:
348
+ mem_key_nopos = torch.cat((
349
+ past_key_value[2],
350
+ past_key_mem.select(dim=3, index=mem_freq)), dim=2)
351
+ past_key_mem_offload = past_key_value[3]
352
+ past_key_mem = torch.cat((
353
+ past_key_mem_offload,
354
+ past_key_mem.to(past_key_mem_offload.device)), dim=2)
355
+ past_value_mem = torch.cat((past_key_value[4], past_value_mem.to(past_key_mem_offload.device)), dim=2)
356
+ else:
357
+ mem_key_nopos = past_key_mem.select(dim=3, index=mem_freq)
358
+
359
+ num_mems = past_key_mem.shape[2]
360
+ top_k = min(cache_top_k, num_mems)
361
+ prefix_len = full_len - (top_k + 1) * (mem_freq + 1)
362
+ mem_indices = torch.cat(
363
+ (position_ids.new_zeros((max(0, num_mems - top_k), )),
364
+ torch.arange(1, top_k + 1, device=query_states.device, dtype=position_ids.dtype)), dim=0)
365
+ mem_pos = (mem_indices * (mem_freq + 1) + mem_freq).unsqueeze(0).expand(bsz, -1) + prefix_len
366
+ _, mem_key = apply_rotary_pos_emb(None, mem_key_nopos, cos, sin, mem_pos)
367
+ mem_attn_weights = torch.matmul(query_states, mem_key.transpose(2, 3)) / math.sqrt(self.head_dim)
368
+
369
+ if offload_cache_to_cpu:
370
+ aggregate = "max_over_tokens"
371
+ else:
372
+ aggregate = None
373
+ if aggregate == "max_over_tokens":
374
+ token_retrievers = 1
375
+ head_retrievers = self.num_heads
376
+ mem_attn_weights = torch.nn.functional.softmax(mem_attn_weights, dim=-1,dtype=torch.float32).to(query_states.dtype)
377
+ mem_attn_weights = mem_attn_weights.amax(dim=2, keepdim=True)
378
+ elif aggregate is None:
379
+ token_retrievers = q_len
380
+ head_retrievers = self.num_heads
381
+ else:
382
+ raise NotImplementedError()
383
+
384
+ mem_selected_idx = mem_attn_weights.topk(dim=-1,k=top_k)[1].sort(dim=-1)[0].view(bsz, head_retrievers, token_retrievers, top_k)
385
+
386
+ selected_indices = torch.arange(0, top_k * (mem_freq + 1), device=query_states.device, dtype=position_ids.dtype)
387
+ selected_indices = torch.where(mem_selected_idx >= num_mems - top_k, mem_freq + 1, 0).unsqueeze(-1) + selected_indices.view(1, 1, 1, top_k, mem_freq + 1)
388
+ selected_indices = selected_indices.view(bsz, head_retrievers, token_retrievers, -1) + prefix_len
389
+
390
+
391
+
392
+
393
+ mem_selected_idx = mem_selected_idx.to(past_key_mem.device)
394
+
395
+ mem_selected_idx = mem_selected_idx.view(bsz, self.num_heads, token_retrievers, top_k, 1, 1).expand(bsz, self.num_heads, token_retrievers, top_k, mem_freq + 1, self.head_dim)
396
+ selected_keys = past_key_mem.unsqueeze(2).expand(bsz, self.num_heads, token_retrievers, -1, mem_freq + 1, self.head_dim)
397
+ selected_keys = selected_keys.take_along_dim(mem_selected_idx, dim=3).to(query_states.device)
398
+ selected_values = past_value_mem.unsqueeze(2).expand(bsz, self.num_heads, token_retrievers, -1, mem_freq + 1, self.head_dim).take_along_dim(mem_selected_idx, dim=3).to(query_states.device)
399
+
400
+ if aggregate == "max_over_tokens":
401
+ selected_indices = selected_indices.squeeze(2)
402
+ selected_keys = selected_keys.view(bsz, self.num_heads, -1, self.head_dim)
403
+ selected_keys = apply_rotary_pos_emb(None, selected_keys, cos, sin, selected_indices)[1]
404
+ key_states = torch.cat((selected_keys, key_states), dim=2)
405
+ value_states = torch.cat((selected_values.view(bsz, self.num_heads, -1, self.head_dim), value_states), dim=2)
406
+ expected_att_size = (bsz, self.num_heads, q_len, key_states.shape[2])
407
+ else:
408
+ selected_indices = selected_indices.expand(bsz, self.num_heads, q_len, -1)
409
+ selected_keys = selected_keys.view(bsz, self.num_heads, token_retrievers, -1, self.head_dim).expand(bsz, self.num_heads, q_len, -1, self.head_dim)
410
+ selected_keys = apply_rotary_pos_emb(None, selected_keys, cos, sin, selected_indices)[1]
411
+ selected_values = selected_values.view(bsz, self.num_heads, token_retrievers, -1, self.head_dim).expand(bsz, self.num_heads, q_len, -1, self.head_dim)
412
+ attn_prefix = torch.matmul(query_states.unsqueeze(3), selected_keys.transpose(3, 4)).squeeze(3) / math.sqrt(self.head_dim)
413
+ expected_att_size = (bsz, self.num_heads, q_len, q_len + incomplete_len)
414
+
415
+ is_mem_prefix = torch.cat((is_mem.new_zeros((mem_freq, )), is_mem.new_ones((1, )))).unsqueeze(0).repeat((top_k, 1))
416
+ if use_flash:
417
+ is_mem_prefix = is_mem_prefix.view(1, -1)
418
+ else:
419
+ is_mem_prefix = is_mem_prefix.view(1, 1, 1, -1).expand(1, 1, q_len, -1)
420
+ last_section_mask = torch.cat((last_section_mask.new_zeros((1, 1, q_len, top_k * (mem_freq + 1))), last_section_mask), dim=-1)
421
+ is_mem = torch.cat((is_mem_prefix, is_mem), dim=-1)
422
+
423
+
424
+ past_key_states = torch.cat([past_key_value[0], key_states_before_pos], dim=2)
425
+ past_value_states = torch.cat([past_key_value[1], orig_value_states], dim=2)
426
+
427
+ if offload_cache_to_cpu:
428
+ past_key_value = (past_key_states, past_value_states, mem_key_nopos, past_key_mem.to("cpu"), past_value_mem.to("cpu"), *past_key_value[5:]) if use_cache else None
429
+ else:
430
+ past_key_value = (past_key_states, past_value_states) if use_cache else None
431
+
432
+ else:
433
+ if mem_freq is None:
434
+ past_key_states = key_states
435
+ else:
436
+ past_key_states = key_states_before_pos
437
+ past_value_states = value_states
438
+ expected_att_size = (bsz, self.num_heads, q_len, kv_seq_len)
439
+ past_key_value = (past_key_states, past_value_states) if use_cache else None
440
+
441
+ if use_flash:
442
+ assert attn_prefix is None
443
+ assert not output_attentions
444
+ assert mem_freq is not None
445
+ attn_output = fused_landmark_attention(query_states, key_states, value_states, is_mem, block_size=mem_freq+1)
446
+ attn_weights = None
447
+ else:
448
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
449
+ if attn_weights.size() != expected_att_size:
450
+ raise ValueError(
451
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
452
+ f" {attn_weights.size()}"
453
+ )
454
+
455
+ if attention_mask is not None:
456
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
457
+ raise ValueError(
458
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
459
+ )
460
+ attn_weights = attn_weights + attention_mask[...,-attn_weights.shape[-1]:]
461
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
462
+ if attn_prefix is not None:
463
+ attn_weights = torch.cat((attn_prefix, attn_weights), dim=-1)
464
+ # upcast attention to fp32
465
+ if is_mem is None:
466
+ raise ValueError("Don't use this without landmarks")
467
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
468
+ else:
469
+ attn_weights = landmark_grouped_softmax(attn_weights, dim=-1, is_mem=is_mem.expand(-1, self.num_heads, -1, -1), last_section_mask=last_section_mask).to(query_states.dtype)
470
+ if attn_prefix is not None:
471
+ attn_prefix, attn_weights = torch.split(attn_weights, (attn_prefix.shape[-1], attn_weights.shape[-1] - attn_prefix.shape[-1]), dim=-1)
472
+ attn_output = torch.matmul(attn_weights, value_states)
473
+ if attn_prefix is not None:
474
+ attn_output += torch.matmul(attn_prefix.unsqueeze(3), selected_values).squeeze(3)
475
+
476
+ if not output_attentions:
477
+ attn_weights = None
478
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
479
+ raise ValueError(
480
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
481
+ f" {attn_output.size()}"
482
+ )
483
+
484
+ attn_output = attn_output.transpose(1, 2)
485
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
486
+ attn_output = self.o_proj(attn_output)
487
+
488
+ return attn_output, attn_weights, past_key_value
489
+
490
+
491
+ class LlamaDecoderLayer(nn.Module):
492
+ def __init__(self, config: LlamaLandmarkConfig):
493
+ super().__init__()
494
+ self.hidden_size = config.hidden_size
495
+ self.self_attn = LlamaAttention(config=config)
496
+ self.mlp = LlamaMLP(
497
+ hidden_size=self.hidden_size,
498
+ intermediate_size=config.intermediate_size,
499
+ hidden_act=config.hidden_act,
500
+ )
501
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
502
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
503
+
504
+ def forward(
505
+ self,
506
+ hidden_states: torch.Tensor,
507
+ attention_mask: Optional[torch.Tensor] = None,
508
+ position_ids: Optional[torch.LongTensor] = None,
509
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
510
+ output_attentions: Optional[bool] = False,
511
+ use_cache: Optional[bool] = False,
512
+ is_mem: Optional[torch.Tensor] = None,
513
+ last_section_mask: Optional[torch.Tensor] = None,
514
+ offload_cache_to_cpu: bool = False,
515
+ use_flash: bool = False,
516
+ cache_top_k: Optional[int] = None,
517
+ mem_freq: Optional[int] = None
518
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
519
+ """
520
+ Args:
521
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
522
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
523
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
524
+ output_attentions (`bool`, *optional*):
525
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
526
+ returned tensors for more detail.
527
+ use_cache (`bool`, *optional*):
528
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
529
+ (see `past_key_values`).
530
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
531
+ """
532
+
533
+ residual = hidden_states
534
+
535
+ hidden_states = self.input_layernorm(hidden_states)
536
+
537
+ # Self Attention
538
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
539
+ hidden_states=hidden_states,
540
+ attention_mask=attention_mask,
541
+ position_ids=position_ids,
542
+ past_key_value=past_key_value,
543
+ output_attentions=output_attentions,
544
+ use_cache=use_cache,
545
+ is_mem=is_mem,
546
+ last_section_mask=last_section_mask,
547
+ offload_cache_to_cpu=offload_cache_to_cpu,
548
+ use_flash=use_flash,
549
+ cache_top_k=cache_top_k,
550
+ mem_freq=mem_freq
551
+ )
552
+ hidden_states = residual + hidden_states
553
+
554
+ # Fully Connected
555
+ residual = hidden_states
556
+ hidden_states = self.post_attention_layernorm(hidden_states)
557
+ hidden_states = self.mlp(hidden_states)
558
+ hidden_states = residual + hidden_states
559
+
560
+ outputs = (hidden_states,)
561
+
562
+ if output_attentions:
563
+ outputs += (self_attn_weights,)
564
+
565
+ if use_cache:
566
+ outputs += (present_key_value,)
567
+
568
+ return outputs
569
+
570
+
571
+ LLAMA_START_DOCSTRING = r"""
572
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
573
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
574
+ etc.)
575
+
576
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
577
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
578
+ and behavior.
579
+
580
+ Parameters:
581
+ config ([`LlamaLandmarkConfig`]):
582
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
583
+ load the weights associated with the model, only the configuration. Check out the
584
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
585
+ """
586
+
587
+
588
+ @add_start_docstrings(
589
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
590
+ LLAMA_START_DOCSTRING,
591
+ )
592
+ class LlamaPreTrainedModel(PreTrainedModel):
593
+ config_class = LlamaLandmarkConfig
594
+ base_model_prefix = "model"
595
+ supports_gradient_checkpointing = True
596
+ _no_split_modules = ["LlamaDecoderLayer"]
597
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
598
+
599
+ def _init_weights(self, module):
600
+ std = self.config.initializer_range
601
+ if isinstance(module, nn.Linear):
602
+ module.weight.data.normal_(mean=0.0, std=std)
603
+ if module.bias is not None:
604
+ module.bias.data.zero_()
605
+ elif isinstance(module, nn.Embedding):
606
+ module.weight.data.normal_(mean=0.0, std=std)
607
+ if module.padding_idx is not None:
608
+ module.weight.data[module.padding_idx].zero_()
609
+
610
+ def _set_gradient_checkpointing(self, module, value=False):
611
+ if isinstance(module, LlamaModel):
612
+ module.gradient_checkpointing = value
613
+
614
+
615
+ LLAMA_INPUTS_DOCSTRING = r"""
616
+ Args:
617
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
618
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
619
+ it.
620
+
621
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
622
+ [`PreTrainedTokenizer.__call__`] for details.
623
+
624
+ [What are input IDs?](../glossary#input-ids)
625
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
626
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
627
+
628
+ - 1 for tokens that are **not masked**,
629
+ - 0 for tokens that are **masked**.
630
+
631
+ [What are attention masks?](../glossary#attention-mask)
632
+
633
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
634
+ [`PreTrainedTokenizer.__call__`] for details.
635
+
636
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
637
+ `past_key_values`).
638
+
639
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
640
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
641
+ information on the default strategy.
642
+
643
+ - 1 indicates the head is **not masked**,
644
+ - 0 indicates the head is **masked**.
645
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
646
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
647
+ config.n_positions - 1]`.
648
+
649
+ [What are position IDs?](../glossary#position-ids)
650
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
651
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
652
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
653
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
654
+
655
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
656
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
657
+
658
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
659
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
660
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
661
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
662
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
663
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
664
+ model's internal embedding lookup matrix.
665
+ use_cache (`bool`, *optional*):
666
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
667
+ `past_key_values`).
668
+ output_attentions (`bool`, *optional*):
669
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
670
+ tensors for more detail.
671
+ output_hidden_states (`bool`, *optional*):
672
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
673
+ more detail.
674
+ return_dict (`bool`, *optional*):
675
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
676
+ """
677
+
678
+
679
+ @add_start_docstrings(
680
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
681
+ LLAMA_START_DOCSTRING,
682
+ )
683
+ class LlamaModel(LlamaPreTrainedModel):
684
+ """
685
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
686
+
687
+ Args:
688
+ config: LlamaLandmarkConfig
689
+ """
690
+
691
+ def __init__(self, config: LlamaLandmarkConfig):
692
+ super().__init__(config)
693
+ self.padding_idx = config.pad_token_id
694
+ self.vocab_size = config.vocab_size
695
+
696
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
697
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
698
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
699
+
700
+ self.gradient_checkpointing = False
701
+ # Initialize weights and apply final processing
702
+ self.post_init()
703
+
704
+ def get_input_embeddings(self):
705
+ return self.embed_tokens
706
+
707
+ def set_input_embeddings(self, value):
708
+ self.embed_tokens = value
709
+
710
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
711
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
712
+ # create causal mask
713
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
714
+ combined_attention_mask = None
715
+ if input_shape[-1] > 1:
716
+ combined_attention_mask = _make_causal_mask(
717
+ input_shape,
718
+ inputs_embeds.dtype,
719
+ device=inputs_embeds.device,
720
+ past_key_values_length=past_key_values_length,
721
+ )
722
+
723
+ if attention_mask is not None:
724
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
725
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
726
+ inputs_embeds.device
727
+ )
728
+ combined_attention_mask = (
729
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
730
+ )
731
+
732
+ return combined_attention_mask
733
+
734
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
735
+ def forward(
736
+ self,
737
+ input_ids: torch.LongTensor = None,
738
+ attention_mask: Optional[torch.Tensor] = None,
739
+ position_ids: Optional[torch.LongTensor] = None,
740
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
741
+ inputs_embeds: Optional[torch.FloatTensor] = None,
742
+ use_cache: Optional[bool] = None,
743
+ output_attentions: Optional[bool] = None,
744
+ output_hidden_states: Optional[bool] = None,
745
+ return_dict: Optional[bool] = None,
746
+ offload_cache_to_cpu: Optional[bool] = None,
747
+ use_flash: Optional[bool] = None,
748
+ cache_top_k: Optional[int] = None,
749
+ mem_freq: Optional[int] = None
750
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
751
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
752
+ output_hidden_states = (
753
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
754
+ )
755
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
756
+
757
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
758
+
759
+ # retrieve input_ids and inputs_embeds
760
+ is_mem = None
761
+ if input_ids is not None and inputs_embeds is not None:
762
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
763
+ elif input_ids is not None:
764
+ batch_size, seq_length = input_ids.shape
765
+ if self.config.mem_id is not None:
766
+ with torch.no_grad():
767
+ is_mem = input_ids == self.config.mem_id
768
+ elif inputs_embeds is not None:
769
+ batch_size, seq_length, _ = inputs_embeds.shape
770
+ if self.config.mem_id is not None:
771
+ raise NotImplementedError
772
+ else:
773
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
774
+
775
+ seq_length_with_past = seq_length
776
+ past_key_values_length = 0
777
+
778
+ if past_key_values is not None:
779
+ if is_mem is not None:
780
+ pass
781
+ #raise NotImplementedError
782
+ past_key_values_length = past_key_values[0][0].shape[2]
783
+ if len(past_key_values[0]) > 2:
784
+ past_key_values_length += past_key_values[0][3].shape[2] * past_key_values[0][3].shape[3]
785
+ seq_length_with_past = seq_length_with_past + past_key_values_length
786
+
787
+ if position_ids is None:
788
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
789
+ position_ids = torch.arange(
790
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
791
+ )
792
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
793
+ else:
794
+ position_ids = position_ids.view(-1, seq_length).long()
795
+
796
+ if inputs_embeds is None:
797
+ inputs_embeds = self.embed_tokens(input_ids)
798
+ # embed positions
799
+ if attention_mask is None:
800
+ attention_mask = torch.ones(
801
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
802
+ )
803
+ attention_mask = self._prepare_decoder_attention_mask(
804
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
805
+ )
806
+
807
+ last_section_mask = None
808
+ if is_mem is not None and not use_flash:
809
+ is_mem = is_mem.unsqueeze(1).unsqueeze(2)
810
+ current_len = input_ids.shape[1]
811
+ mem_ids = torch.where(attention_mask[..., -current_len:] < -1, 0, torch.cumsum(is_mem, -1) - is_mem.int())
812
+ last_section_mask = torch.amax(mem_ids, -1, keepdim=True) == mem_ids
813
+ attention_mask[..., -current_len:].masked_fill_(last_section_mask & is_mem, torch.tensor(torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device))
814
+ last_section_mask.logical_and_(attention_mask[..., -current_len:] > -1)
815
+ is_mem = is_mem.logical_and(attention_mask[..., -current_len:] > -1)
816
+
817
+
818
+ hidden_states = inputs_embeds
819
+
820
+ if self.gradient_checkpointing and self.training:
821
+ if use_cache:
822
+ logger.warning_once(
823
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
824
+ )
825
+ use_cache = False
826
+
827
+ # decoder layers
828
+ all_hidden_states = () if output_hidden_states else None
829
+ all_self_attns = () if output_attentions else None
830
+ next_decoder_cache = () if use_cache else None
831
+
832
+ for idx, decoder_layer in enumerate(self.layers):
833
+ if output_hidden_states:
834
+ all_hidden_states += (hidden_states,)
835
+
836
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
837
+
838
+ if self.gradient_checkpointing and self.training:
839
+
840
+ def create_custom_forward(module):
841
+ def custom_forward(*inputs):
842
+ # None for past_key_value
843
+ return module(*inputs, output_attentions, None)
844
+
845
+ return custom_forward
846
+
847
+ layer_outputs = torch.utils.checkpoint.checkpoint(
848
+ create_custom_forward(decoder_layer),
849
+ hidden_states,
850
+ attention_mask,
851
+ position_ids,
852
+ None,
853
+ is_mem,
854
+ last_section_mask,
855
+ offload_cache_to_cpu,
856
+ use_flash,
857
+ cache_top_k,
858
+ mem_freq
859
+ )
860
+ else:
861
+ layer_outputs = decoder_layer(
862
+ hidden_states,
863
+ attention_mask=attention_mask,
864
+ position_ids=position_ids,
865
+ past_key_value=past_key_value,
866
+ output_attentions=output_attentions,
867
+ use_cache=use_cache,
868
+ is_mem=is_mem,
869
+ last_section_mask=last_section_mask,
870
+ offload_cache_to_cpu=offload_cache_to_cpu,
871
+ use_flash=use_flash,
872
+ cache_top_k=cache_top_k,
873
+ mem_freq=mem_freq,
874
+ )
875
+
876
+ hidden_states = layer_outputs[0]
877
+
878
+ if use_cache:
879
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
880
+
881
+ if output_attentions:
882
+ all_self_attns += (layer_outputs[1],)
883
+
884
+ hidden_states = self.norm(hidden_states)
885
+
886
+ # add hidden states from the last decoder layer
887
+ if output_hidden_states:
888
+ all_hidden_states += (hidden_states,)
889
+
890
+ next_cache = next_decoder_cache if use_cache else None
891
+ if not return_dict:
892
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
893
+ return BaseModelOutputWithPast(
894
+ last_hidden_state=hidden_states,
895
+ past_key_values=next_cache,
896
+ hidden_states=all_hidden_states,
897
+ attentions=all_self_attns,
898
+ )
899
+
900
+
901
+ class LlamaForCausalLM(LlamaPreTrainedModel):
902
+ def __init__(self, config):
903
+ super().__init__(config)
904
+ self.model = LlamaModel(config)
905
+
906
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
907
+
908
+ self.auto_insert_landmarks = False
909
+ self.always_use_flash = False
910
+
911
+ # Initialize weights and apply final processing
912
+ self.post_init()
913
+
914
+ def get_input_embeddings(self):
915
+ return self.model.embed_tokens
916
+
917
+ def set_input_embeddings(self, value):
918
+ self.model.embed_tokens = value
919
+
920
+ def get_output_embeddings(self):
921
+ return self.lm_head
922
+
923
+ def set_output_embeddings(self, new_embeddings):
924
+ self.lm_head = new_embeddings
925
+
926
+ def set_decoder(self, decoder):
927
+ self.model = decoder
928
+
929
+ def get_decoder(self):
930
+ return self.model
931
+
932
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
933
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
934
+ def forward(
935
+ self,
936
+ input_ids: torch.LongTensor = None,
937
+ attention_mask: Optional[torch.Tensor] = None,
938
+ position_ids: Optional[torch.LongTensor] = None,
939
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
940
+ inputs_embeds: Optional[torch.FloatTensor] = None,
941
+ labels: Optional[torch.LongTensor] = None,
942
+ use_cache: Optional[bool] = None,
943
+ output_attentions: Optional[bool] = None,
944
+ output_hidden_states: Optional[bool] = None,
945
+ return_dict: Optional[bool] = None,
946
+ offload_cache_to_cpu: Optional[bool] = None,
947
+ use_flash: Optional[bool] = None,
948
+ cache_top_k: Optional[int] = None,
949
+ max_chunk_length: Optional[int] = 0,
950
+ mem_freq: Optional[int] = None,
951
+ drop_last_logit_if_mem: Optional[bool] = False,
952
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
953
+ r"""
954
+ Args:
955
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
956
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
957
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
958
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
959
+
960
+ Returns:
961
+
962
+ Example:
963
+
964
+ ```python
965
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
966
+
967
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
968
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
969
+
970
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
971
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
972
+
973
+ >>> # Generate
974
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
975
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
976
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
977
+ ```"""
978
+
979
+ use_flash = use_flash if use_flash is not None else self.always_use_flash
980
+
981
+ if self.auto_insert_landmarks:
982
+ mem_freq = self.config.mem_freq
983
+ assert self.config.mem_freq is not None
984
+ block_size = self.config.mem_freq + 1
985
+ input_ids = input_ids.view(input_ids.shape[0], -1, block_size - 1)
986
+ input_ids = torch.cat((input_ids, input_ids.new_full((input_ids.shape[0], input_ids.shape[1], 1), self.config.mem_id)), dim=-1)
987
+ input_ids = input_ids.view(input_ids.shape[0], -1)
988
+ if attention_mask is not None:
989
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1, block_size - 1)
990
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((attention_mask.shape[0], attention_mask.shape[1], 1))), dim=-1)
991
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1)
992
+
993
+
994
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
995
+ output_hidden_states = (
996
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
997
+ )
998
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
999
+
1000
+ if max_chunk_length == 0:
1001
+ if cache_top_k is not None:
1002
+ max_chunk_length = self.config.train_context_length - self.config.train_context_length % (mem_freq + 1) - (cache_top_k + 1) * (mem_freq + 1)
1003
+ if max_chunk_length <= 0:
1004
+ raise ValueError("K is too large for this model.")
1005
+ else:
1006
+ max_chunk_length = None
1007
+
1008
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1009
+ window_len = max_chunk_length or input_ids.shape[1]
1010
+ if use_flash:
1011
+ assert window_len % (mem_freq + 1) == 0
1012
+ last_logits = None
1013
+ for step, idx in enumerate(range(0, input_ids.shape[1], window_len)):
1014
+ if idx >= 1:
1015
+ if output_attentions or output_hidden_states:
1016
+ raise NotImplementedError
1017
+ if not use_cache:
1018
+ raise NotImplementedError
1019
+ outputs = self.model(
1020
+ input_ids=input_ids[:, idx:idx + window_len],
1021
+ attention_mask=attention_mask[:, :idx + window_len + attention_mask.shape[1] - input_ids.shape[1]] if attention_mask is not None else None,
1022
+ position_ids=position_ids[:, idx:idx + window_len] if position_ids is not None else None,
1023
+ past_key_values=past_key_values,
1024
+ inputs_embeds=inputs_embeds[:, idx:idx + window_len] if inputs_embeds is not None else None,
1025
+ use_cache=use_cache,
1026
+ output_attentions=output_attentions,
1027
+ output_hidden_states=output_hidden_states,
1028
+ return_dict=return_dict,
1029
+ offload_cache_to_cpu=offload_cache_to_cpu,
1030
+ use_flash=(use_flash or self.auto_insert_landmarks),
1031
+ cache_top_k=cache_top_k,
1032
+ mem_freq=mem_freq,
1033
+ )
1034
+ past_key_values = outputs[1]
1035
+ if last_logits is not None:
1036
+ last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
1037
+ last_logits = outputs[0]
1038
+
1039
+ hidden_states = last_logits
1040
+ if self.auto_insert_landmarks:
1041
+ block_size = self.config.mem_freq + 1
1042
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], hidden_states.shape[1] // block_size, block_size, hidden_states.shape[2])
1043
+ hidden_states = hidden_states[:, :, :block_size - 1]
1044
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[3])
1045
+ if drop_last_logit_if_mem:
1046
+ is_any_mem = (input_ids[:, -1] == self.config.mem_id).any()
1047
+ are_all_mem = (input_ids[:, -1] == self.config.mem_id).all()
1048
+ assert is_any_mem == are_all_mem
1049
+ if is_any_mem:
1050
+ hidden_states = hidden_states[:, :-1]
1051
+ logits = self.lm_head(hidden_states)
1052
+
1053
+ loss = None
1054
+ if labels is not None:
1055
+ # Shift so that tokens < n predict n
1056
+ shift_logits = logits[..., :-1, :].contiguous()
1057
+ shift_labels = labels[..., 1:].contiguous()
1058
+ # Flatten the tokens
1059
+ loss_fct = CrossEntropyLoss()
1060
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1061
+ shift_labels = shift_labels.view(-1)
1062
+ # Enable model parallelism
1063
+ shift_labels = shift_labels.to(shift_logits.device)
1064
+ loss = loss_fct(shift_logits, shift_labels)
1065
+
1066
+ if not return_dict:
1067
+ output = (logits,) + outputs[1:]
1068
+ return (loss,) + output if loss is not None else output
1069
+
1070
+ return CausalLMOutputWithPast(
1071
+ loss=loss,
1072
+ logits=logits,
1073
+ past_key_values=outputs.past_key_values,
1074
+ hidden_states=outputs.hidden_states,
1075
+ attentions=outputs.attentions,
1076
+ )
1077
+
1078
+ def set_mem_id(self, mem_id):
1079
+ if self.config.mem_id is not None:
1080
+ assert mem_id == self.config.mem_id, "Changing mem_id can break the model. If you really intend to do this, manually disable this check"
1081
+ self.config.mem_id = mem_id
1082
+
1083
+ def enable_landmark_insertion(self):
1084
+ self.auto_insert_landmarks = True
1085
+
1086
+ def enable_flash(self):
1087
+ self.always_use_flash = True
1088
+
1089
+ def prepare_inputs_for_generation(
1090
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1091
+ ):
1092
+ total_len = input_ids.shape[1]
1093
+ if past_key_values:
1094
+ prev_len = input_ids.shape[1] - 1
1095
+ use_flash = False if kwargs.get("use_flash") is not None else None
1096
+ else:
1097
+ prev_len = 0
1098
+ use_flash = kwargs.get("use_flash")
1099
+
1100
+ position_ids = kwargs.get("position_ids", None)
1101
+
1102
+ mem_freq = kwargs.get("mem_freq") or self.config.mem_freq
1103
+
1104
+ if mem_freq is not None:
1105
+ if position_ids is not None:
1106
+ raise NotImplementedError
1107
+ T = input_ids.shape[1]
1108
+
1109
+ prev_incomplete_len = prev_len % mem_freq
1110
+ prev_complete_len = prev_len - prev_incomplete_len
1111
+ incomplete_len = total_len % mem_freq
1112
+ new_full_len = total_len - prev_complete_len - incomplete_len
1113
+
1114
+ prev_input, input_ids_with_mem, input_ids_without_mem = torch.split(input_ids, (prev_complete_len, new_full_len, incomplete_len), dim=-1)
1115
+
1116
+ bsz, q_len = input_ids.size()
1117
+ input_ids_with_mem = input_ids_with_mem.view(bsz, -1, mem_freq)
1118
+ input_ids_with_mem = torch.cat(
1119
+ (
1120
+ input_ids_with_mem,
1121
+ input_ids_with_mem.new_full((bsz, input_ids_with_mem.shape[1], 1), self.config.mem_id)
1122
+ ),
1123
+ dim=-1
1124
+ ).view(bsz, -1)
1125
+ input_ids = torch.cat((prev_input, input_ids_with_mem, input_ids_without_mem), dim=-1)
1126
+ if attention_mask is not None:
1127
+ attention_mask_with_mem, attention_mask_without_mem = torch.split(attention_mask, (prev_complete_len + new_full_len, incomplete_len), dim=-1)
1128
+ attention_mask_with_mem = attention_mask_with_mem.view(bsz, -1, mem_freq)
1129
+ attention_mask_with_mem = torch.cat(
1130
+ (
1131
+ attention_mask_with_mem,
1132
+ attention_mask_with_mem.new_ones((bsz, attention_mask_with_mem.shape[1], 1))
1133
+ ),
1134
+ dim=-1
1135
+ ).view(bsz, -1)
1136
+ attention_mask = torch.cat((attention_mask_with_mem, attention_mask_without_mem), dim=-1)
1137
+
1138
+
1139
+ input_ids = input_ids[:, prev_len:]
1140
+ if attention_mask is not None and position_ids is None:
1141
+ # create position_ids on the fly for batch generation
1142
+ position_ids = attention_mask.long().cumsum(-1) - 1
1143
+ position_ids.masked_fill_(attention_mask == 0, 1)
1144
+ position_ids = position_ids[:, -input_ids.shape[1]:].unsqueeze(-1)
1145
+
1146
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1147
+ if inputs_embeds is not None and past_key_values is None and mem_freq is None:
1148
+ model_inputs = {"inputs_embeds": inputs_embeds}
1149
+ else:
1150
+ model_inputs = {"input_ids": input_ids}
1151
+
1152
+ model_inputs.update(
1153
+ {
1154
+ "position_ids": position_ids,
1155
+ "past_key_values": past_key_values,
1156
+ "use_cache": kwargs.get("use_cache"),
1157
+ "attention_mask": attention_mask,
1158
+ "offload_cache_to_cpu": kwargs.get("offload_cache_to_cpu"),
1159
+ "use_flash": use_flash,
1160
+ "cache_top_k": kwargs.get("cache_top_k"),
1161
+ "max_chunk_length": kwargs.get("max_chunk_length", 0),
1162
+ "mem_freq": mem_freq,
1163
+ "drop_last_logit_if_mem": True,
1164
+ }
1165
+ )
1166
+ return model_inputs
1167
+
1168
+ @staticmethod
1169
+ def _reorder_cache(past_key_values, beam_idx):
1170
+ reordered_past = ()
1171
+ for layer_past in past_key_values:
1172
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1173
+ return reordered_past
1174
+
1175
+
1176
+ @add_start_docstrings(
1177
+ """
1178
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1179
+
1180
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1181
+ (e.g. GPT-2) do.
1182
+
1183
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1184
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1185
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1186
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1187
+ each row of the batch).
1188
+ """,
1189
+ LLAMA_START_DOCSTRING,
1190
+ )
1191
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1192
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1193
+
1194
+ def __init__(self, config):
1195
+ super().__init__(config)
1196
+ self.num_labels = config.num_labels
1197
+ self.model = LlamaModel(config)
1198
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1199
+
1200
+ # Initialize weights and apply final processing
1201
+ self.post_init()
1202
+
1203
+ def get_input_embeddings(self):
1204
+ return self.model.embed_tokens
1205
+
1206
+ def set_input_embeddings(self, value):
1207
+ self.model.embed_tokens = value
1208
+
1209
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1210
+ def forward(
1211
+ self,
1212
+ input_ids: torch.LongTensor = None,
1213
+ attention_mask: Optional[torch.Tensor] = None,
1214
+ position_ids: Optional[torch.LongTensor] = None,
1215
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1216
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1217
+ labels: Optional[torch.LongTensor] = None,
1218
+ use_cache: Optional[bool] = None,
1219
+ output_attentions: Optional[bool] = None,
1220
+ output_hidden_states: Optional[bool] = None,
1221
+ return_dict: Optional[bool] = None,
1222
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1223
+ r"""
1224
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1225
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1226
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1227
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1228
+ """
1229
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1230
+
1231
+ transformer_outputs = self.model(
1232
+ input_ids,
1233
+ attention_mask=attention_mask,
1234
+ position_ids=position_ids,
1235
+ past_key_values=past_key_values,
1236
+ inputs_embeds=inputs_embeds,
1237
+ use_cache=use_cache,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ )
1242
+ hidden_states = transformer_outputs[0]
1243
+ logits = self.score(hidden_states)
1244
+
1245
+ if input_ids is not None:
1246
+ batch_size = input_ids.shape[0]
1247
+ else:
1248
+ batch_size = inputs_embeds.shape[0]
1249
+
1250
+ if self.config.pad_token_id is None and batch_size != 1:
1251
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1252
+ if self.config.pad_token_id is None:
1253
+ sequence_lengths = -1
1254
+ else:
1255
+ if input_ids is not None:
1256
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
1257
+ else:
1258
+ sequence_lengths = -1
1259
+
1260
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1261
+
1262
+ loss = None
1263
+ if labels is not None:
1264
+ labels = labels.to(logits.device)
1265
+ if self.config.problem_type is None:
1266
+ if self.num_labels == 1:
1267
+ self.config.problem_type = "regression"
1268
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1269
+ self.config.problem_type = "single_label_classification"
1270
+ else:
1271
+ self.config.problem_type = "multi_label_classification"
1272
+
1273
+ if self.config.problem_type == "regression":
1274
+ loss_fct = MSELoss()
1275
+ if self.num_labels == 1:
1276
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1277
+ else:
1278
+ loss = loss_fct(pooled_logits, labels)
1279
+ elif self.config.problem_type == "single_label_classification":
1280
+ loss_fct = CrossEntropyLoss()
1281
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1282
+ elif self.config.problem_type == "multi_label_classification":
1283
+ loss_fct = BCEWithLogitsLoss()
1284
+ loss = loss_fct(pooled_logits, labels)
1285
+ if not return_dict:
1286
+ output = (pooled_logits,) + transformer_outputs[1:]
1287
+ return ((loss,) + output) if loss is not None else output
1288
+
1289
+ return SequenceClassifierOutputWithPast(
1290
+ loss=loss,
1291
+ logits=pooled_logits,
1292
+ past_key_values=transformer_outputs.past_key_values,
1293
+ hidden_states=transformer_outputs.hidden_states,
1294
+ attentions=transformer_outputs.attentions,
1295
+ )
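The block-splitting logic in `prepare_inputs_for_generation` above can be hard to follow inside the diff: it appends the landmark token `mem_id` after every complete run of `mem_freq` ordinary tokens in the newly arrived input, while the already-cached prefix and the incomplete tail are left untouched. The following standalone sketch is not part of the commit and uses a made-up helper name; it shows the same transformation for the simple no-cache case.

import torch

def insert_landmarks(input_ids: torch.Tensor, mem_id: int, mem_freq: int) -> torch.Tensor:
    # Append one landmark token after every full block of `mem_freq` tokens;
    # the incomplete tail at the end of the sequence is passed through unchanged.
    bsz, seq_len = input_ids.shape
    n_full = seq_len // mem_freq
    full, tail = torch.split(input_ids, (n_full * mem_freq, seq_len - n_full * mem_freq), dim=-1)
    full = full.view(bsz, n_full, mem_freq)
    landmarks = full.new_full((bsz, n_full, 1), mem_id)
    full = torch.cat((full, landmarks), dim=-1).view(bsz, -1)
    return torch.cat((full, tail), dim=-1)

# Example: with mem_freq=3 and mem_id=99, [[1, 2, 3, 4, 5, 6, 7]] becomes [[1, 2, 3, 99, 4, 5, 6, 99, 7]].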
code/llama_orig.py ADDED
@@ -0,0 +1,888 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from ...activations import ACT2FN
30
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from .configuration_llama import LlamaConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "LlamaConfig"
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ input_dtype = hidden_states.dtype
85
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
86
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
87
+
88
+ return (self.weight * hidden_states).to(input_dtype)
89
+
90
+
91
+ class LlamaRotaryEmbedding(torch.nn.Module):
92
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
93
+ super().__init__()
94
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
95
+ self.register_buffer("inv_freq", inv_freq)
96
+
97
+ # Build here to make `torch.jit.trace` work.
98
+ self.max_seq_len_cached = max_position_embeddings
99
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
100
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
101
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
102
+ emb = torch.cat((freqs, freqs), dim=-1)
103
+ dtype = torch.get_default_dtype()
104
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
105
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
106
+
107
+ def forward(self, x, seq_len=None):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
110
+ if seq_len > self.max_seq_len_cached:
111
+ self.max_seq_len_cached = seq_len
112
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
113
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
114
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
115
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
116
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
117
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
118
+ return (
119
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
120
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
121
+ )
122
+
123
+
124
+ def rotate_half(x):
125
+ """Rotates half the hidden dims of the input."""
126
+ x1 = x[..., : x.shape[-1] // 2]
127
+ x2 = x[..., x.shape[-1] // 2 :]
128
+ return torch.cat((-x2, x1), dim=-1)
129
+
130
+
131
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
132
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
133
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
134
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
135
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
136
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
137
+ q_embed = (q * cos) + (rotate_half(q) * sin)
138
+ k_embed = (k * cos) + (rotate_half(k) * sin)
139
+ return q_embed, k_embed
140
+
141
+
142
+ class LlamaMLP(nn.Module):
143
+ def __init__(
144
+ self,
145
+ hidden_size: int,
146
+ intermediate_size: int,
147
+ hidden_act: str,
148
+ ):
149
+ super().__init__()
150
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
151
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
152
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
153
+ self.act_fn = ACT2FN[hidden_act]
154
+
155
+ def forward(self, x):
156
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
157
+
158
+
159
+ class LlamaAttention(nn.Module):
160
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
161
+
162
+ def __init__(self, config: LlamaConfig):
163
+ super().__init__()
164
+ self.config = config
165
+ self.hidden_size = config.hidden_size
166
+ self.num_heads = config.num_attention_heads
167
+ self.head_dim = self.hidden_size // self.num_heads
168
+ self.max_position_embeddings = config.max_position_embeddings
169
+
170
+ if (self.head_dim * self.num_heads) != self.hidden_size:
171
+ raise ValueError(
172
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
173
+ f" and `num_heads`: {self.num_heads})."
174
+ )
175
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
176
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
177
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
178
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
179
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
180
+
181
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
182
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
183
+
184
+ def forward(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ position_ids: Optional[torch.LongTensor] = None,
189
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
190
+ output_attentions: bool = False,
191
+ use_cache: bool = False,
192
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
193
+ bsz, q_len, _ = hidden_states.size()
194
+
195
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
196
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198
+
199
+ kv_seq_len = key_states.shape[-2]
200
+ if past_key_value is not None:
201
+ kv_seq_len += past_key_value[0].shape[-2]
202
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
203
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
204
+ # [bsz, nh, t, hd]
205
+
206
+ if past_key_value is not None:
207
+ # reuse k, v, self_attention
208
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
209
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
210
+
211
+ past_key_value = (key_states, value_states) if use_cache else None
212
+
213
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
214
+
215
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
216
+ raise ValueError(
217
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
218
+ f" {attn_weights.size()}"
219
+ )
220
+
221
+ if attention_mask is not None:
222
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
223
+ raise ValueError(
224
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
225
+ )
226
+ attn_weights = attn_weights + attention_mask
227
+ attn_weights = torch.max(
228
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
229
+ )
230
+
231
+ # upcast attention to fp32
232
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
233
+ attn_output = torch.matmul(attn_weights, value_states)
234
+
235
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
236
+ raise ValueError(
237
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
238
+ f" {attn_output.size()}"
239
+ )
240
+
241
+ attn_output = attn_output.transpose(1, 2)
242
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
243
+
244
+ attn_output = self.o_proj(attn_output)
245
+
246
+ if not output_attentions:
247
+ attn_weights = None
248
+
249
+ return attn_output, attn_weights, past_key_value
250
+
251
+
252
+ class LlamaDecoderLayer(nn.Module):
253
+ def __init__(self, config: LlamaConfig):
254
+ super().__init__()
255
+ self.hidden_size = config.hidden_size
256
+ self.self_attn = LlamaAttention(config=config)
257
+ self.mlp = LlamaMLP(
258
+ hidden_size=self.hidden_size,
259
+ intermediate_size=config.intermediate_size,
260
+ hidden_act=config.hidden_act,
261
+ )
262
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
264
+
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ attention_mask: Optional[torch.Tensor] = None,
269
+ position_ids: Optional[torch.LongTensor] = None,
270
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
271
+ output_attentions: Optional[bool] = False,
272
+ use_cache: Optional[bool] = False,
273
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
274
+ """
275
+ Args:
276
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
277
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
278
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
279
+ output_attentions (`bool`, *optional*):
280
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
281
+ returned tensors for more detail.
282
+ use_cache (`bool`, *optional*):
283
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
284
+ (see `past_key_values`).
285
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
286
+ """
287
+
288
+ residual = hidden_states
289
+
290
+ hidden_states = self.input_layernorm(hidden_states)
291
+
292
+ # Self Attention
293
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
294
+ hidden_states=hidden_states,
295
+ attention_mask=attention_mask,
296
+ position_ids=position_ids,
297
+ past_key_value=past_key_value,
298
+ output_attentions=output_attentions,
299
+ use_cache=use_cache,
300
+ )
301
+ hidden_states = residual + hidden_states
302
+
303
+ # Fully Connected
304
+ residual = hidden_states
305
+ hidden_states = self.post_attention_layernorm(hidden_states)
306
+ hidden_states = self.mlp(hidden_states)
307
+ hidden_states = residual + hidden_states
308
+
309
+ outputs = (hidden_states,)
310
+
311
+ if output_attentions:
312
+ outputs += (self_attn_weights,)
313
+
314
+ if use_cache:
315
+ outputs += (present_key_value,)
316
+
317
+ return outputs
318
+
319
+
320
+ LLAMA_START_DOCSTRING = r"""
321
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
322
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
323
+ etc.)
324
+
325
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
326
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
327
+ and behavior.
328
+
329
+ Parameters:
330
+ config ([`LlamaConfig`]):
331
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
332
+ load the weights associated with the model, only the configuration. Check out the
333
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
334
+ """
335
+
336
+
337
+ @add_start_docstrings(
338
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
339
+ LLAMA_START_DOCSTRING,
340
+ )
341
+ class LlamaPreTrainedModel(PreTrainedModel):
342
+ config_class = LlamaConfig
343
+ base_model_prefix = "model"
344
+ supports_gradient_checkpointing = True
345
+ _no_split_modules = ["LlamaDecoderLayer"]
346
+ _skip_keys_device_placement = "past_key_values"
347
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
348
+
349
+ def _init_weights(self, module):
350
+ std = self.config.initializer_range
351
+ if isinstance(module, nn.Linear):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.bias is not None:
354
+ module.bias.data.zero_()
355
+ elif isinstance(module, nn.Embedding):
356
+ module.weight.data.normal_(mean=0.0, std=std)
357
+ if module.padding_idx is not None:
358
+ module.weight.data[module.padding_idx].zero_()
359
+
360
+ def _set_gradient_checkpointing(self, module, value=False):
361
+ if isinstance(module, LlamaModel):
362
+ module.gradient_checkpointing = value
363
+
364
+
365
+ LLAMA_INPUTS_DOCSTRING = r"""
366
+ Args:
367
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
368
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
369
+ it.
370
+
371
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
372
+ [`PreTrainedTokenizer.__call__`] for details.
373
+
374
+ [What are input IDs?](../glossary#input-ids)
375
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
376
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
377
+
378
+ - 1 for tokens that are **not masked**,
379
+ - 0 for tokens that are **masked**.
380
+
381
+ [What are attention masks?](../glossary#attention-mask)
382
+
383
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
384
+ [`PreTrainedTokenizer.__call__`] for details.
385
+
386
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
387
+ `past_key_values`).
388
+
389
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
390
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
391
+ information on the default strategy.
392
+
393
+ - 1 indicates the head is **not masked**,
394
+ - 0 indicates the head is **masked**.
395
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
396
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
397
+ config.n_positions - 1]`.
398
+
399
+ [What are position IDs?](../glossary#position-ids)
400
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
401
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
402
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
403
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
404
+
405
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
406
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
407
+
408
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
409
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
410
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
411
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
412
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
413
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
414
+ model's internal embedding lookup matrix.
415
+ use_cache (`bool`, *optional*):
416
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
417
+ `past_key_values`).
418
+ output_attentions (`bool`, *optional*):
419
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
420
+ tensors for more detail.
421
+ output_hidden_states (`bool`, *optional*):
422
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
423
+ more detail.
424
+ return_dict (`bool`, *optional*):
425
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
426
+ """
427
+
428
+
429
+ @add_start_docstrings(
430
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
431
+ LLAMA_START_DOCSTRING,
432
+ )
433
+ class LlamaModel(LlamaPreTrainedModel):
434
+ """
435
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
436
+
437
+ Args:
438
+ config: LlamaConfig
439
+ """
440
+
441
+ def __init__(self, config: LlamaConfig):
442
+ super().__init__(config)
443
+ self.padding_idx = config.pad_token_id
444
+ self.vocab_size = config.vocab_size
445
+
446
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
447
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
448
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
449
+
450
+ self.gradient_checkpointing = False
451
+ # Initialize weights and apply final processing
452
+ self.post_init()
453
+
454
+ def get_input_embeddings(self):
455
+ return self.embed_tokens
456
+
457
+ def set_input_embeddings(self, value):
458
+ self.embed_tokens = value
459
+
460
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
461
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
462
+ # create causal mask
463
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
464
+ combined_attention_mask = None
465
+ if input_shape[-1] > 1:
466
+ combined_attention_mask = _make_causal_mask(
467
+ input_shape,
468
+ inputs_embeds.dtype,
469
+ device=inputs_embeds.device,
470
+ past_key_values_length=past_key_values_length,
471
+ )
472
+
473
+ if attention_mask is not None:
474
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
475
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
476
+ inputs_embeds.device
477
+ )
478
+ combined_attention_mask = (
479
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
480
+ )
481
+
482
+ return combined_attention_mask
483
+
484
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
485
+ def forward(
486
+ self,
487
+ input_ids: torch.LongTensor = None,
488
+ attention_mask: Optional[torch.Tensor] = None,
489
+ position_ids: Optional[torch.LongTensor] = None,
490
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
491
+ inputs_embeds: Optional[torch.FloatTensor] = None,
492
+ use_cache: Optional[bool] = None,
493
+ output_attentions: Optional[bool] = None,
494
+ output_hidden_states: Optional[bool] = None,
495
+ return_dict: Optional[bool] = None,
496
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
497
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
498
+ output_hidden_states = (
499
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
500
+ )
501
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
502
+
503
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
504
+
505
+ # retrieve input_ids and inputs_embeds
506
+ if input_ids is not None and inputs_embeds is not None:
507
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
508
+ elif input_ids is not None:
509
+ batch_size, seq_length = input_ids.shape
510
+ elif inputs_embeds is not None:
511
+ batch_size, seq_length, _ = inputs_embeds.shape
512
+ else:
513
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
514
+
515
+ seq_length_with_past = seq_length
516
+ past_key_values_length = 0
517
+
518
+ if past_key_values is not None:
519
+ past_key_values_length = past_key_values[0][0].shape[2]
520
+ seq_length_with_past = seq_length_with_past + past_key_values_length
521
+
522
+ if position_ids is None:
523
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
524
+ position_ids = torch.arange(
525
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
526
+ )
527
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
528
+ else:
529
+ position_ids = position_ids.view(-1, seq_length).long()
530
+
531
+ if inputs_embeds is None:
532
+ inputs_embeds = self.embed_tokens(input_ids)
533
+ # embed positions
534
+ if attention_mask is None:
535
+ attention_mask = torch.ones(
536
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
537
+ )
538
+ attention_mask = self._prepare_decoder_attention_mask(
539
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
540
+ )
541
+
542
+ hidden_states = inputs_embeds
543
+
544
+ if self.gradient_checkpointing and self.training:
545
+ if use_cache:
546
+ logger.warning_once(
547
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
548
+ )
549
+ use_cache = False
550
+
551
+ # decoder layers
552
+ all_hidden_states = () if output_hidden_states else None
553
+ all_self_attns = () if output_attentions else None
554
+ next_decoder_cache = () if use_cache else None
555
+
556
+ for idx, decoder_layer in enumerate(self.layers):
557
+ if output_hidden_states:
558
+ all_hidden_states += (hidden_states,)
559
+
560
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+
564
+ def create_custom_forward(module):
565
+ def custom_forward(*inputs):
566
+ # None for past_key_value
567
+ return module(*inputs, output_attentions, None)
568
+
569
+ return custom_forward
570
+
571
+ layer_outputs = torch.utils.checkpoint.checkpoint(
572
+ create_custom_forward(decoder_layer),
573
+ hidden_states,
574
+ attention_mask,
575
+ position_ids,
576
+ None,
577
+ )
578
+ else:
579
+ layer_outputs = decoder_layer(
580
+ hidden_states,
581
+ attention_mask=attention_mask,
582
+ position_ids=position_ids,
583
+ past_key_value=past_key_value,
584
+ output_attentions=output_attentions,
585
+ use_cache=use_cache,
586
+ )
587
+
588
+ hidden_states = layer_outputs[0]
589
+
590
+ if use_cache:
591
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
592
+
593
+ if output_attentions:
594
+ all_self_attns += (layer_outputs[1],)
595
+
596
+ hidden_states = self.norm(hidden_states)
597
+
598
+ # add hidden states from the last decoder layer
599
+ if output_hidden_states:
600
+ all_hidden_states += (hidden_states,)
601
+
602
+ next_cache = next_decoder_cache if use_cache else None
603
+ if not return_dict:
604
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
605
+ return BaseModelOutputWithPast(
606
+ last_hidden_state=hidden_states,
607
+ past_key_values=next_cache,
608
+ hidden_states=all_hidden_states,
609
+ attentions=all_self_attns,
610
+ )
611
+
612
+
613
+ class LlamaForCausalLM(LlamaPreTrainedModel):
614
+ _tied_weights_keys = ["lm_head.weight"]
615
+
616
+ def __init__(self, config):
617
+ super().__init__(config)
618
+ self.model = LlamaModel(config)
619
+
620
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
621
+
622
+ # Initialize weights and apply final processing
623
+ self.post_init()
624
+
625
+ def get_input_embeddings(self):
626
+ return self.model.embed_tokens
627
+
628
+ def set_input_embeddings(self, value):
629
+ self.model.embed_tokens = value
630
+
631
+ def get_output_embeddings(self):
632
+ return self.lm_head
633
+
634
+ def set_output_embeddings(self, new_embeddings):
635
+ self.lm_head = new_embeddings
636
+
637
+ def set_decoder(self, decoder):
638
+ self.model = decoder
639
+
640
+ def get_decoder(self):
641
+ return self.model
642
+
643
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
644
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
645
+ def forward(
646
+ self,
647
+ input_ids: torch.LongTensor = None,
648
+ attention_mask: Optional[torch.Tensor] = None,
649
+ position_ids: Optional[torch.LongTensor] = None,
650
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
651
+ inputs_embeds: Optional[torch.FloatTensor] = None,
652
+ labels: Optional[torch.LongTensor] = None,
653
+ use_cache: Optional[bool] = None,
654
+ output_attentions: Optional[bool] = None,
655
+ output_hidden_states: Optional[bool] = None,
656
+ return_dict: Optional[bool] = None,
657
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
658
+ r"""
659
+ Args:
660
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
661
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
662
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
663
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
664
+
665
+ Returns:
666
+
667
+ Example:
668
+
669
+ ```python
670
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
671
+
672
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
673
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
674
+
675
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
676
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
677
+
678
+ >>> # Generate
679
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
680
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
681
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
682
+ ```"""
683
+
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (
686
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
687
+ )
688
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
689
+
690
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
691
+ outputs = self.model(
692
+ input_ids=input_ids,
693
+ attention_mask=attention_mask,
694
+ position_ids=position_ids,
695
+ past_key_values=past_key_values,
696
+ inputs_embeds=inputs_embeds,
697
+ use_cache=use_cache,
698
+ output_attentions=output_attentions,
699
+ output_hidden_states=output_hidden_states,
700
+ return_dict=return_dict,
701
+ )
702
+
703
+ hidden_states = outputs[0]
704
+ logits = self.lm_head(hidden_states)
705
+
706
+ loss = None
707
+ if labels is not None:
708
+ # Shift so that tokens < n predict n
709
+ shift_logits = logits[..., :-1, :].contiguous()
710
+ shift_labels = labels[..., 1:].contiguous()
711
+ # Flatten the tokens
712
+ loss_fct = CrossEntropyLoss()
713
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
714
+ shift_labels = shift_labels.view(-1)
715
+ # Enable model parallelism
716
+ shift_labels = shift_labels.to(shift_logits.device)
717
+ loss = loss_fct(shift_logits, shift_labels)
718
+
719
+ if not return_dict:
720
+ output = (logits,) + outputs[1:]
721
+ return (loss,) + output if loss is not None else output
722
+
723
+ return CausalLMOutputWithPast(
724
+ loss=loss,
725
+ logits=logits,
726
+ past_key_values=outputs.past_key_values,
727
+ hidden_states=outputs.hidden_states,
728
+ attentions=outputs.attentions,
729
+ )
730
+
731
+ def prepare_inputs_for_generation(
732
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
733
+ ):
734
+ if past_key_values:
735
+ input_ids = input_ids[:, -1:]
736
+
737
+ position_ids = kwargs.get("position_ids", None)
738
+ if attention_mask is not None and position_ids is None:
739
+ # create position_ids on the fly for batch generation
740
+ position_ids = attention_mask.long().cumsum(-1) - 1
741
+ position_ids.masked_fill_(attention_mask == 0, 1)
742
+ if past_key_values:
743
+ position_ids = position_ids[:, -1].unsqueeze(-1)
744
+
745
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
746
+ if inputs_embeds is not None and past_key_values is None:
747
+ model_inputs = {"inputs_embeds": inputs_embeds}
748
+ else:
749
+ model_inputs = {"input_ids": input_ids}
750
+
751
+ model_inputs.update(
752
+ {
753
+ "position_ids": position_ids,
754
+ "past_key_values": past_key_values,
755
+ "use_cache": kwargs.get("use_cache"),
756
+ "attention_mask": attention_mask,
757
+ }
758
+ )
759
+ return model_inputs
760
+
761
+ @staticmethod
762
+ def _reorder_cache(past_key_values, beam_idx):
763
+ reordered_past = ()
764
+ for layer_past in past_key_values:
765
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
766
+ return reordered_past
767
+
768
+
769
+ @add_start_docstrings(
770
+ """
771
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
772
+
773
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
774
+ (e.g. GPT-2) do.
775
+
776
+ Since it does classification on the last token, it requires to know the position of the last token. If a
777
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
778
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
779
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
780
+ each row of the batch).
781
+ """,
782
+ LLAMA_START_DOCSTRING,
783
+ )
784
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
785
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.num_labels = config.num_labels
790
+ self.model = LlamaModel(config)
791
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
792
+
793
+ # Initialize weights and apply final processing
794
+ self.post_init()
795
+
796
+ def get_input_embeddings(self):
797
+ return self.model.embed_tokens
798
+
799
+ def set_input_embeddings(self, value):
800
+ self.model.embed_tokens = value
801
+
802
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
803
+ def forward(
804
+ self,
805
+ input_ids: torch.LongTensor = None,
806
+ attention_mask: Optional[torch.Tensor] = None,
807
+ position_ids: Optional[torch.LongTensor] = None,
808
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
809
+ inputs_embeds: Optional[torch.FloatTensor] = None,
810
+ labels: Optional[torch.LongTensor] = None,
811
+ use_cache: Optional[bool] = None,
812
+ output_attentions: Optional[bool] = None,
813
+ output_hidden_states: Optional[bool] = None,
814
+ return_dict: Optional[bool] = None,
815
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
816
+ r"""
817
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
818
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
819
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
820
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
821
+ """
822
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
823
+
824
+ transformer_outputs = self.model(
825
+ input_ids,
826
+ attention_mask=attention_mask,
827
+ position_ids=position_ids,
828
+ past_key_values=past_key_values,
829
+ inputs_embeds=inputs_embeds,
830
+ use_cache=use_cache,
831
+ output_attentions=output_attentions,
832
+ output_hidden_states=output_hidden_states,
833
+ return_dict=return_dict,
834
+ )
835
+ hidden_states = transformer_outputs[0]
836
+ logits = self.score(hidden_states)
837
+
838
+ if input_ids is not None:
839
+ batch_size = input_ids.shape[0]
840
+ else:
841
+ batch_size = inputs_embeds.shape[0]
842
+
843
+ if self.config.pad_token_id is None and batch_size != 1:
844
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
845
+ if self.config.pad_token_id is None:
846
+ sequence_lengths = -1
847
+ else:
848
+ if input_ids is not None:
849
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
850
+ else:
851
+ sequence_lengths = -1
852
+
853
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
854
+
855
+ loss = None
856
+ if labels is not None:
857
+ labels = labels.to(logits.device)
858
+ if self.config.problem_type is None:
859
+ if self.num_labels == 1:
860
+ self.config.problem_type = "regression"
861
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
862
+ self.config.problem_type = "single_label_classification"
863
+ else:
864
+ self.config.problem_type = "multi_label_classification"
865
+
866
+ if self.config.problem_type == "regression":
867
+ loss_fct = MSELoss()
868
+ if self.num_labels == 1:
869
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
870
+ else:
871
+ loss = loss_fct(pooled_logits, labels)
872
+ elif self.config.problem_type == "single_label_classification":
873
+ loss_fct = CrossEntropyLoss()
874
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
875
+ elif self.config.problem_type == "multi_label_classification":
876
+ loss_fct = BCEWithLogitsLoss()
877
+ loss = loss_fct(pooled_logits, labels)
878
+ if not return_dict:
879
+ output = (pooled_logits,) + transformer_outputs[1:]
880
+ return ((loss,) + output) if loss is not None else output
881
+
882
+ return SequenceClassifierOutputWithPast(
883
+ loss=loss,
884
+ logits=pooled_logits,
885
+ past_key_values=transformer_outputs.past_key_values,
886
+ hidden_states=transformer_outputs.hidden_states,
887
+ attentions=transformer_outputs.attentions,
888
+ )
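`llama_orig.py` above appears to be the unmodified upstream LLaMA implementation, kept for reference next to the landmark variant. As a quick orientation aid (not part of the commit), its rotary-embedding helpers can be exercised in isolation as below, assuming `LlamaRotaryEmbedding` and `apply_rotary_pos_emb` are importable from the file.

import torch

batch, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

rotary = LlamaRotaryEmbedding(head_dim, max_position_embeddings=32)
cos, sin = rotary(q, seq_len=seq_len)              # each of shape [1, 1, seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape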
code/modelling_RW.py ADDED
@@ -0,0 +1,1362 @@
1
+ # Port of the models described in RW
2
+ # We use the bloom model as a starting point for these models.
3
+ # Please refer to the bloom models for usage instructions.
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
13
+ from torch.nn import functional as F
14
+
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPastAndCrossAttentions,
17
+ CausalLMOutputWithCrossAttentions,
18
+ QuestionAnsweringModelOutput,
19
+ SequenceClassifierOutputWithPast,
20
+ TokenClassifierOutput,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import logging
24
+ from configuration_RW import RWConfig
25
+ from ltriton.flash_landmark_attention import fused_landmark_attention
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, which means that there is one additional quantization to bfloat16 between the operations.
30
+ # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
31
+ class Linear(nn.Linear):
32
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
33
+ ret = input @ self.weight.T
34
+ if self.bias is None:
35
+ return ret
36
+ else:
37
+ return ret + self.bias
38
+
39
+
40
+ from einops import rearrange
41
+
42
+ # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
43
+ def rotate_half(x):
44
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
45
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
46
+
47
+
48
+ class RotaryEmbedding(torch.nn.Module):
49
+ """Implementation of RotaryEmbedding from GPT-NeoX.
50
+ This implementation is design to operate on queries and keys that are compatible with
51
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ head_dim: int,
57
+ base=10000,
58
+ ):
59
+ super().__init__()
60
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
61
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
62
+ self.head_dim = head_dim
63
+ self.seq_len_cached = None
64
+ self.batch_size_cached = None
65
+ self.cos_cached: torch.Tensor | None = None
66
+ self.sin_cached: torch.Tensor | None = None
67
+
68
+ def cos_sin(
69
+ self,
70
+ seq_len: int,
71
+ device="cuda",
72
+ dtype=torch.bfloat16,
73
+ ) -> torch.Tensor:
74
+ if seq_len != self.seq_len_cached:
75
+ self.seq_len_cached = seq_len
76
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
77
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
78
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
79
+
80
+ if dtype in [torch.float16, torch.bfloat16]:
81
+ emb = emb.float()
82
+
83
+ self.cos_cached = emb.cos()[None, :, :]
84
+ self.sin_cached = emb.sin()[None, :, :]
85
+
86
+ self.cos_cached = self.cos_cached.type(dtype)
87
+ self.sin_cached = self.sin_cached.type(dtype)
88
+
89
+ return self.cos_cached, self.sin_cached
90
+
91
+ def forward(self, q, k):
92
+ batch, seq_len, head_dim = q.shape
93
+ cos, sin = self.cos_sin(seq_len)
94
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
95
+
96
+
97
+ def _make_causal_mask(
98
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
99
+ ) -> torch.BoolTensor:
100
+ batch_size, target_length = input_ids_shape
101
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
102
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
103
+ seq_ids = torch.arange(target_length, device=device)
104
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
105
+
106
+ if past_key_values_length > 0:
107
+ mask[:, :past_key_values_length] = False
108
+
109
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
110
+ return expanded_mask
111
+
112
+
113
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
114
+ batch_size, src_length = mask.shape
115
+ tgt_length = tgt_length if tgt_length is not None else src_length
116
+
117
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
118
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
119
+
120
+
121
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
122
+ batch_size, seq_length = attention_mask.shape
123
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
124
+ base = torch.tensor(
125
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
126
+ )
127
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
128
+ slopes = torch.pow(base, powers)
129
+
130
+ if closest_power_of_2 != num_heads:
131
+ extra_base = torch.tensor(
132
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
133
+ )
134
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
135
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
136
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
137
+
138
+ # Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
139
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
140
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
141
+ # => the query_length dimension will then be broadcasted correctly
142
+ # This is more or less identical to T5's relative position bias:
143
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
144
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
145
+ alibi = slopes[..., None].bfloat16() * arange_tensor
146
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
147
+
148
+
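+ # For example, with num_heads = 4 the code above gives closest_power_of_2 = 4,
+ # base = 2 ** -2 and per-head slopes [2**-2, 2**-4, 2**-6, 2**-8]; each head's bias
+ # then grows linearly with the key's position among the non-padded tokens, scaled by
+ # that head's slope.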
149
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
150
+ out = F.dropout(x, p=prob, training=training)
151
+ out = residual + out
152
+ return out
153
+
154
+
155
+ class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
156
+
157
+ # Note that forward, setup_context, and backward are @staticmethods
158
+ @staticmethod
159
+ def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
160
+ new_shape = list(x.shape)
161
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
162
+ max_by_group = x.new_zeros((*new_shape,))
163
+ max_by_group.scatter_reduce_(src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False)
164
+
165
+ maxes = torch.gather(max_by_group, dim, resp_mem_idx)
166
+ #x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes))
167
+ x_exp = torch.exp((x - maxes).to(torch.float32))
168
+
169
+ cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype)
170
+
171
+ cumsum_by_group.scatter_add_(dim, resp_mem_idx, x_exp, )
172
+ denom = torch.gather(cumsum_by_group, dim, resp_mem_idx)
173
+
174
+ #probs = torch.where(denom < 0.5, 0, x_exp / denom)
175
+ probs = x_exp / denom
176
+
177
+
178
+ ctx.mem_cnt = mem_cnt
179
+ ctx.dim = dim
180
+ ctx.save_for_backward(resp_mem_idx, probs)
181
+
182
+ return probs
183
+
184
+ @staticmethod
185
+ def backward(ctx, grad_probs):
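+ # Per-group softmax backward: grad_x = p * (dL/dp - sum over the group of p * dL/dp).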
186
+ mem_cnt = ctx.mem_cnt
187
+ dim = ctx.dim
188
+ resp_mem_idx, probs = ctx.saved_tensors
189
+ grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None
190
+
191
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:  # forward takes 4 inputs (x, dim, mem_cnt, resp_mem_idx), so 3 is the last valid index
192
+ grad_pair = grad_probs * probs
193
+
194
+ new_shape = list(probs.shape)
195
+ new_shape[dim] = mem_cnt # max_mem_cnt.item()
196
+ cumsum_by_group = grad_pair.new_zeros((*new_shape,))
197
+ cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair)
198
+
199
+ if ctx.needs_input_grad[0]:
200
+ grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx)
201
+ grad_x = grad_pair - probs * grad_sum
202
+ assert not ctx.needs_input_grad[1]
203
+ assert not ctx.needs_input_grad[2]
204
+ assert not ctx.needs_input_grad[3]
205
+
206
+ return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx
207
+
208
+
209
+ def landmark_grouped_softmax(x, dim, is_mem, last_section_mask):
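+ # Two-level landmark softmax: scores are first normalized within each
+ # landmark-delimited block, then each block is re-weighted by the probability its
+ # landmark token received, while tokens of the last (still open) section keep
+ # their probability directly.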
210
+
211
+ last_and_rest_mask = last_section_mask # | mask
212
+
213
+ full_access_mask = is_mem | last_and_rest_mask
214
+
215
+ max_mem_cnt = 64
216
+ mem_group_idx = torch.cumsum(is_mem, dim=dim)
217
+ mem_bucket_id = max_mem_cnt - 1
218
+ resp_mem_idx = torch.where(last_and_rest_mask,
219
+ max_mem_cnt - 1,
220
+ torch.where(is_mem, mem_bucket_id, mem_group_idx))
221
+ probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx)
222
+
223
+ new_shape = list(x.shape)
224
+ new_shape[dim] = max_mem_cnt
225
+ group_prob = probs.new_zeros((*new_shape, ))
226
+ group_prob.scatter_(dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs)
227
+ probs = probs.mul(torch.where(full_access_mask, last_section_mask, torch.gather(group_prob, dim, resp_mem_idx)))
228
+
229
+ return probs
230
+
231
+
232
+ class Attention(nn.Module):
233
+ def __init__(self, config: RWConfig):
234
+ super().__init__()
235
+
236
+ self.hidden_size = config.hidden_size
237
+ self.num_heads = config.n_head
238
+ self.head_dim = self.hidden_size // self.num_heads
239
+ self.split_size = self.hidden_size
240
+ self.hidden_dropout = config.hidden_dropout
241
+
242
+ if self.head_dim * self.num_heads != self.hidden_size:
243
+ raise ValueError(
244
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
245
+ f" {self.num_heads})."
246
+ )
247
+
248
+ self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
249
+
250
+ # Layer-wise attention scaling
251
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
252
+ self.beta = self.inv_norm_factor
253
+
254
+ self.query_key_value = Linear(
255
+ self.hidden_size,
256
+ 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
257
+ bias=config.bias,
258
+ )
259
+ self.multi_query = config.multi_query
260
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
261
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
262
+ self.num_kv = config.n_head if not self.multi_query else 1
263
+
264
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
265
+ """
266
+ Split the last dimension into (num_heads, head_dim) without making any copies; the results share the same memory
267
+ storage as `fused_qkv`
268
+
269
+ Args:
270
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
271
+
272
+ Returns:
273
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
274
+ value: [batch_size, seq_length, num_heads, head_dim]
275
+ """
276
+ if not self.multi_query:
277
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
278
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
279
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
280
+ else:
281
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
282
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
283
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
284
+
285
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
286
+ """
287
+ Merge heads together over the last dimension
288
+
289
+ Args:
290
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
291
+
292
+ Returns:
293
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
294
+ """
295
+ # What we want to achieve is:
296
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
297
+ batch_size_and_num_heads, seq_length, _ = x.shape
298
+ batch_size = batch_size_and_num_heads // self.num_heads
299
+
300
+ # First view to decompose the batch size
301
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
302
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
303
+
304
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
305
+ x = x.permute(0, 2, 1, 3)
306
+
307
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
308
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
309
+
310
+ def forward(
311
+ self,
312
+ hidden_states: torch.Tensor,
313
+ alibi: torch.Tensor,
314
+ attention_mask: torch.Tensor,
315
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
316
+ head_mask: Optional[torch.Tensor] = None,
317
+ use_cache: bool = False,
318
+ output_attentions: bool = False,
319
+ is_mem: Optional[torch.Tensor] = None,
320
+ last_section_mask: Optional[torch.Tensor] = None,
321
+ offload_cache_to_cpu: bool = False,
322
+ use_flash: bool = False,
323
+ cache_top_k: Optional[int] = None,
324
+ mem_freq: Optional[int] = None
325
+ ):
326
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
327
+
328
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
329
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
330
+
331
+ batch_size, q_length, _, _ = query_layer.shape
332
+
333
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
334
+ key_layer = key_layer.transpose(1, 2).reshape(
335
+ batch_size * self.num_kv,
336
+ q_length,
337
+ self.head_dim,
338
+ )
339
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
340
+
341
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
342
+
343
+ if layer_past is not None:
344
+ past_key, past_value = layer_past
345
+ # concatenate along seq_length dimension:
346
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
347
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
348
+ key_layer = torch.cat((past_key, key_layer), dim=1)
349
+ value_layer = torch.cat((past_value, value_layer), dim=1)
350
+
351
+ _, kv_length, _ = key_layer.shape
352
+
353
+ if use_cache is True:
354
+ present = (key_layer, value_layer)
355
+ else:
356
+ present = None
357
+
358
+ if alibi is None:
359
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
360
+ key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
361
+ value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
362
+
363
+ assert self.num_kv == 1
364
+
365
+ key_layer_ = key_layer_.expand(query_layer_.shape)
366
+ value_layer_ = value_layer_.expand(query_layer_.shape)
367
+
368
+ assert not output_attentions # not supported.
369
+ assert mem_freq is not None
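+ # Fused landmark attention over fixed-size blocks: one landmark every `mem_freq`
+ # tokens, so each block (tokens plus trailing landmark) spans mem_freq + 1 positions.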
370
+ attn_output = fused_landmark_attention(query_layer_, key_layer_, value_layer_, is_mem, block_size=mem_freq+1)
371
+
372
+ # attn_output = F.scaled_dot_product_attention(
373
+ # query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
374
+ # )
375
+
376
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
377
+ x = x.permute(0, 2, 1, 3)
378
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
379
+
380
+ output_tensor = self.dense(attn_output)
381
+
382
+ outputs = (output_tensor, present)
383
+ return outputs
384
+ else:
385
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
386
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
387
+
388
+ # change view to [batch_size, num_heads, q_length, kv_length]
389
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
390
+
391
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
392
+ input_dtype = attention_scores.dtype
393
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
394
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
395
+ attention_scores = attention_scores.to(torch.float32)
396
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
397
+ attention_probs = F.softmax(
398
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
399
+ dim=-1,
400
+ dtype=hidden_states.dtype,
401
+ )
402
+ # [batch_size, num_heads, q_length, kv_length]
403
+ attention_probs = self.attention_dropout(attention_probs)
404
+
405
+ if head_mask is not None:
406
+ attention_probs = attention_probs * head_mask
407
+
408
+ # change view [batch_size x num_heads, q_length, kv_length]
409
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
410
+
411
+ # matmul: [batch_size * num_heads, q_length, head_dim]
412
+ context_layer = attention_probs_reshaped @ value_layer
413
+
414
+ # change view [batch_size, num_heads, q_length, head_dim]
415
+ context_layer = self._merge_heads(context_layer)
416
+
417
+ output_tensor = self.dense(context_layer)
418
+
419
+ outputs = (output_tensor, present)
420
+ if output_attentions:
421
+ outputs += (attention_probs,)
422
+
423
+ return outputs
424
+
425
+
426
+ class MLP(nn.Module):
427
+ def __init__(self, config: RWConfig):
428
+ super().__init__()
429
+ hidden_size = config.hidden_size
430
+
431
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
432
+ self.act = nn.GELU()
433
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
434
+ self.hidden_dropout = config.hidden_dropout
435
+
436
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
437
+ x = self.act(self.dense_h_to_4h(x))
438
+ x = self.dense_4h_to_h(x)
439
+ return x
440
+
441
+
442
+ class DecoderLayer(nn.Module):
443
+ def __init__(self, config: RWConfig):
444
+ super().__init__()
445
+ hidden_size = config.hidden_size
446
+
447
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
448
+ self.num_heads = config.n_head
449
+ self.self_attention = Attention(config)
450
+
451
+ if not config.parallel_attn:
452
+ # unused if parallel attn
453
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
454
+
455
+ self.mlp = MLP(config)
456
+
457
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
458
+ self.hidden_dropout = config.hidden_dropout
459
+
460
+ self.config = config
461
+
462
+ def forward(
463
+ self,
464
+ hidden_states: torch.Tensor,
465
+ alibi: torch.Tensor,
466
+ attention_mask: torch.Tensor,
467
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
468
+ head_mask: Optional[torch.Tensor] = None,
469
+ use_cache: bool = False,
470
+ output_attentions: bool = False,
471
+ is_mem: Optional[torch.Tensor] = None,
472
+ last_section_mask: Optional[torch.Tensor] = None,
473
+ offload_cache_to_cpu: bool = False,
474
+ use_flash: bool = False,
475
+ cache_top_k: Optional[int] = None,
476
+ mem_freq: Optional[int] = None,
477
+ ):
478
+
479
+ layernorm_output = self.input_layernorm(hidden_states)
480
+ residual = hidden_states
481
+
482
+ # Self attention.
483
+ attn_outputs = self.self_attention(
484
+ layernorm_output,
485
+ layer_past=layer_past,
486
+ attention_mask=attention_mask,
487
+ alibi=alibi,
488
+ head_mask=head_mask,
489
+ use_cache=use_cache,
490
+ output_attentions=output_attentions,
491
+ is_mem=is_mem,
492
+ last_section_mask=last_section_mask,
493
+ offload_cache_to_cpu=offload_cache_to_cpu,
494
+ use_flash=use_flash,
495
+ cache_top_k=cache_top_k,
496
+ mem_freq=mem_freq
497
+ )
498
+
499
+ attention_output = attn_outputs[0]
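+ # Falcon-style parallel attention: when `parallel_attn` is set, the MLP reuses the
+ # same layernorm output as attention and both branches are summed into the residual;
+ # otherwise the classic post-attention layernorm path below is taken.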
500
+
501
+ if not self.config.parallel_attn:
502
+ residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
503
+ layernorm_output = self.post_attention_layernorm(residual)
504
+
505
+ outputs = attn_outputs[1:]
506
+
507
+ # MLP.
508
+ mlp_output = self.mlp(layernorm_output)
509
+
510
+ if self.config.parallel_attn:
511
+ mlp_output += attention_output
512
+
513
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
514
+
515
+ if use_cache:
516
+ outputs = (output,) + outputs
517
+ else:
518
+ outputs = (output,) + outputs[1:]
519
+
520
+ return outputs # hidden_states, present, attentions
521
+
522
+
523
+ class RWPreTrainedModel(PreTrainedModel):
524
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
525
+ """
526
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
527
+ models.
528
+ """
529
+
530
+ config_class = RWConfig
531
+ base_model_prefix = "transformer"
532
+ supports_gradient_checkpointing = True
533
+ _no_split_modules = ["DecoderLayer"]
534
+
535
+ def __init__(self, *inputs, **kwargs):
536
+ super().__init__(*inputs, **kwargs)
537
+
538
+ def _init_weights(self, module: nn.Module):
539
+ """Initialize the weights."""
540
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
541
+ # Slightly different from the TF version which uses truncated_normal for initialization
542
+ # cf https://github.com/pytorch/pytorch/pull/5617
543
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
544
+ if module.bias is not None:
545
+ module.bias.data.zero_()
546
+ elif isinstance(module, nn.Embedding):
547
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
548
+ if module.padding_idx is not None:
549
+ module.weight.data[module.padding_idx].zero_()
550
+ elif isinstance(module, LayerNorm):
551
+ module.bias.data.zero_()
552
+ module.weight.data.fill_(1.0)
553
+
554
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
555
+ if isinstance(module, RWModel):
556
+ module.gradient_checkpointing = value
557
+
558
+ @staticmethod
559
+ def _convert_to_standard_cache(
560
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
561
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
562
+ """
563
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
564
+ num_heads, ...]))
565
+ """
566
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
567
+ num_heads = batch_size_times_num_heads // batch_size
568
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
569
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
570
+ return tuple(
571
+ (
572
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
573
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
574
+ )
575
+ for layer_past in past_key_value
576
+ )
577
+
578
+ @staticmethod
579
+ def _convert_to_rw_cache(
580
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
581
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
582
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
583
+ batch_size_times_num_heads = batch_size * num_heads
584
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
585
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
586
+ return tuple(
587
+ (
588
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
589
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
590
+ )
591
+ for layer_past in past_key_value
592
+ )
593
+
594
+
595
+ class RWModel(RWPreTrainedModel):
596
+ def __init__(self, config: RWConfig):
597
+ super().__init__(config)
598
+
599
+ self.embed_dim = config.hidden_size
600
+ self.num_heads = config.n_head
601
+ self.alibi = config.alibi
602
+
603
+ # Embedding + LN Embedding
604
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
605
+
606
+ # Transformer blocks
607
+ self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
608
+
609
+ # Final Layer Norm
610
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
611
+
612
+ self.gradient_checkpointing = False
613
+
614
+ # Initialize weights and apply final processing
615
+ self.post_init()
616
+
617
+ def get_input_embeddings(self):
618
+ return self.word_embeddings
619
+
620
+ def _prepare_attn_mask(
621
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
622
+ ) -> torch.BoolTensor:
623
+ # create causal mask
624
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
625
+ combined_attention_mask = None
626
+ device = attention_mask.device
627
+ _, src_length = input_shape
628
+
629
+ if src_length > 1:
630
+ combined_attention_mask = _make_causal_mask(
631
+ input_shape, device=device, past_key_values_length=past_key_values_length
632
+ )
633
+
634
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
635
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
636
+ combined_attention_mask = (
637
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
638
+ )
639
+
640
+ return combined_attention_mask
641
+
642
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
643
+ self.word_embeddings = new_embeddings
644
+
645
+ def forward(
646
+ self,
647
+ input_ids: Optional[torch.LongTensor] = None,
648
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ head_mask: Optional[torch.LongTensor] = None,
651
+ inputs_embeds: Optional[torch.LongTensor] = None,
652
+ use_cache: Optional[bool] = None,
653
+ output_attentions: Optional[bool] = None,
654
+ output_hidden_states: Optional[bool] = None,
655
+ return_dict: Optional[bool] = None,
656
+ offload_cache_to_cpu: Optional[bool] = None,
657
+ use_flash: Optional[bool] = None,
658
+ cache_top_k: Optional[int] = None,
659
+ mem_freq: Optional[int] = None,
660
+ **deprecated_arguments,
661
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
662
+ if deprecated_arguments.pop("position_ids", False) is not False:
663
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
664
+ warnings.warn(
665
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
666
+ " passing `position_ids`.",
667
+ FutureWarning,
668
+ )
669
+ if len(deprecated_arguments) > 0:
670
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
671
+
672
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
673
+ output_hidden_states = (
674
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
675
+ )
676
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
677
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
678
+
679
+ is_mem = None
680
+ if input_ids is not None and inputs_embeds is not None:
681
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
682
+ elif input_ids is not None:
683
+ batch_size, seq_length = input_ids.shape
684
+ if self.config.mem_id is not None:
685
+ with torch.no_grad():
686
+ is_mem = input_ids == self.config.mem_id
687
+ elif inputs_embeds is not None:
688
+ batch_size, seq_length, _ = inputs_embeds.shape
689
+ if self.config.mem_id is not None:
690
+ raise NotImplementedError
691
+ else:
692
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
693
+
694
+ if past_key_values is None:
695
+ past_key_values = tuple([None] * len(self.h))
696
+
697
+ # Prepare head mask if needed
698
+ # 1.0 in head_mask indicate we keep the head
699
+ # attention_probs has shape batch_size x num_heads x N x N
700
+ # head_mask has shape n_layer x batch x num_heads x N x N
701
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
702
+
703
+ if inputs_embeds is None:
704
+ inputs_embeds = self.word_embeddings(input_ids)
705
+
706
+ last_section_mask = None
707
+ if is_mem is not None and not use_flash:
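+ # Landmark bookkeeping: `mem_ids` numbers the landmark-delimited sections,
+ # `last_section_mask` marks tokens of the most recent (still open) section, and
+ # landmark tokens of that section are masked out of the attention scores.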
708
+ is_mem = is_mem.unsqueeze(1).unsqueeze(2)
709
+ current_len = input_ids.shape[1]
710
+ mem_ids = torch.where(attention_mask[..., -current_len:] < -1, 0, torch.cumsum(is_mem, -1) - is_mem.int())
711
+ last_section_mask = torch.amax(mem_ids, -1, keepdim=True) == mem_ids
712
+ attention_mask[..., -current_len:].masked_fill_(last_section_mask & is_mem, torch.tensor(torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device))
713
+ last_section_mask.logical_and_(attention_mask[..., -current_len:] > -1)
714
+ is_mem = is_mem.logical_and(attention_mask[..., -current_len:] > -1)
715
+
716
+ hidden_states = inputs_embeds
717
+
718
+ presents = () if use_cache else None
719
+ all_self_attentions = () if output_attentions else None
720
+ all_hidden_states = () if output_hidden_states else None
721
+
722
+ # Compute alibi tensor: check build_alibi_tensor documentation
723
+ seq_length_with_past = seq_length
724
+ past_key_values_length = 0
725
+ if past_key_values[0] is not None:
726
+ past_key_values_length = past_key_values[0][0].shape[2]
727
+ if len(past_key_values[0]) > 2:
728
+ past_key_values_length += past_key_values[0][3].shape[2] * past_key_values[0][3].shape[3]
729
+ seq_length_with_past = seq_length_with_past + past_key_values_length
730
+ if attention_mask is None:
731
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
732
+ else:
733
+ attention_mask = attention_mask.to(hidden_states.device)
734
+
735
+ if self.alibi:
736
+ alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
737
+ else:
738
+ alibi = None
739
+
740
+ causal_mask = self._prepare_attn_mask(
741
+ attention_mask,
742
+ input_shape=(batch_size, seq_length),
743
+ past_key_values_length=past_key_values_length,
744
+ )
745
+
746
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
747
+
748
+ if output_hidden_states:
749
+ all_hidden_states = all_hidden_states + (hidden_states,)
750
+
751
+ if self.gradient_checkpointing and self.training:
752
+
753
+ if use_cache:
754
+ logger.warning(
755
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
756
+ )
757
+ use_cache = False
758
+
759
+ def create_custom_forward(module):
760
+ def custom_forward(*inputs, **kwargs):
761
+ # None for past_key_value
762
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions, **kwargs)
763
+
764
+ return custom_forward
765
+
766
+ outputs = torch.utils.checkpoint.checkpoint(
767
+ create_custom_forward(block),
768
+ hidden_states,
769
+ alibi,
770
+ causal_mask,
771
+ head_mask[i],
772
+ use_reentrant=False,
773
+ is_mem=is_mem,
774
+ last_section_mask=last_section_mask,
775
+ offload_cache_to_cpu=offload_cache_to_cpu,
776
+ use_flash=use_flash,
777
+ cache_top_k=cache_top_k,
778
+ mem_freq=mem_freq,
779
+ )
780
+ else:
781
+ outputs = block(
782
+ hidden_states,
783
+ layer_past=layer_past,
784
+ attention_mask=causal_mask,
785
+ head_mask=head_mask[i],
786
+ use_cache=use_cache,
787
+ output_attentions=output_attentions,
788
+ alibi=alibi,
789
+ is_mem=is_mem,
790
+ last_section_mask=last_section_mask,
791
+ offload_cache_to_cpu=offload_cache_to_cpu,
792
+ use_flash=use_flash,
793
+ cache_top_k=cache_top_k,
794
+ mem_freq=mem_freq,
795
+ )
796
+
797
+ hidden_states = outputs[0]
798
+ if use_cache is True:
799
+ presents = presents + (outputs[1],)
800
+
801
+ if output_attentions:
802
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
803
+
804
+ # Add last hidden state
805
+ hidden_states = self.ln_f(hidden_states)
806
+
807
+ if output_hidden_states:
808
+ all_hidden_states = all_hidden_states + (hidden_states,)
809
+
810
+ if not return_dict:
811
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
812
+
813
+ return BaseModelOutputWithPastAndCrossAttentions(
814
+ last_hidden_state=hidden_states,
815
+ past_key_values=presents,
816
+ hidden_states=all_hidden_states,
817
+ attentions=all_self_attentions,
818
+ )
819
+
820
+
821
+ class RWForCausalLM(RWPreTrainedModel):
822
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
823
+
824
+ def __init__(self, config: RWConfig):
825
+ super().__init__(config)
826
+ self.transformer = RWModel(config)
827
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
828
+ self.auto_insert_landmarks = False
829
+ self.always_use_flash = False
830
+
831
+ # Initialize weights and apply final processing
832
+ self.post_init()
833
+
834
+ def get_output_embeddings(self):
835
+ return self.lm_head
836
+
837
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
838
+ self.lm_head = new_embeddings
839
+
840
+ def set_mem_id(self, mem_id):
841
+ if self.config.mem_id is not None:
842
+ assert mem_id == self.config.mem_id, "Changing mem_id can break the model. If you really intend to do this, manually disable this check"
843
+ self.config.mem_id = mem_id
844
+
845
+ def enable_landmark_insertion(self):
846
+ self.auto_insert_landmarks = True
847
+
848
+ def enable_flash(self):
849
+ self.always_use_flash = True
850
+
851
+ def prepare_inputs_for_generation(
852
+ self,
853
+ input_ids: torch.LongTensor,
854
+ past: Optional[torch.Tensor] = None,
855
+ attention_mask: Optional[torch.Tensor] = None,
856
+ **kwargs,
857
+ ) -> dict:
858
+ total_len = input_ids.shape[1]
859
+
860
+ # only last token for input_ids if past is not None
861
+ if past:
862
+ input_ids = input_ids[:, -1].unsqueeze(-1)
863
+
864
+ # the cache may be in the standard format (e.g. in contrastive search); convert to our format if needed
865
+ if past[0][0].shape[0] == input_ids.shape[0]:
866
+ past = self._convert_to_rw_cache(past)
867
+
868
+ use_flash = False if kwargs.get("use_flash") is not None else None
869
+ else:
870
+ prev_len = 0
871
+ use_flash = kwargs.get("use_flash")
872
+
873
+ mem_freq = kwargs.get("mem_freq") or self.config.mem_freq
874
+
875
+ if mem_freq is not None:
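+ # Insert a landmark (`mem_id`) token after every `mem_freq` newly arrived tokens
+ # and extend the attention mask accordingly; tokens already covered by the cache
+ # (`prev_complete_len`) are left untouched.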
876
+ # if position_ids is not None:
877
+ # raise NotImplementedError
878
+ T = input_ids.shape[1]
879
+
880
+ prev_incomplete_len = prev_len % mem_freq
881
+ prev_complete_len = prev_len - prev_incomplete_len
882
+ incomplete_len = total_len % mem_freq
883
+ new_full_len = total_len - prev_complete_len - incomplete_len
884
+
885
+ prev_input, input_ids_with_mem, input_ids_without_mem = torch.split(input_ids, (prev_complete_len, new_full_len, incomplete_len), dim=-1)
886
+
887
+ bsz, q_len = input_ids.size()
888
+ input_ids_with_mem = input_ids_with_mem.view(bsz, -1, mem_freq)
889
+ input_ids_with_mem = torch.cat(
890
+ (
891
+ input_ids_with_mem,
892
+ input_ids_with_mem.new_full((bsz, input_ids_with_mem.shape[1], 1), self.config.mem_id)
893
+ ),
894
+ dim=-1
895
+ ).view(bsz, -1)
896
+ input_ids = torch.cat((prev_input, input_ids_with_mem, input_ids_without_mem), dim=-1)
897
+ if attention_mask is not None:
898
+ attention_mask_with_mem, attention_mask_without_mem = torch.split(attention_mask, (prev_complete_len + new_full_len, incomplete_len), dim=-1)
899
+ attention_mask_with_mem = attention_mask_with_mem.view(bsz, -1, mem_freq)
900
+ attention_mask_with_mem = torch.cat(
901
+ (
902
+ attention_mask_with_mem,
903
+ attention_mask_with_mem.new_ones((bsz, attention_mask_with_mem.shape[1], 1))
904
+ ),
905
+ dim=-1
906
+ ).view(bsz, -1)
907
+ attention_mask = torch.cat((attention_mask_with_mem, attention_mask_without_mem), dim=-1)
908
+
909
+ input_ids = input_ids[:, prev_len:]
910
+
911
+ return {
912
+ "input_ids": input_ids,
913
+ "past_key_values": past,
914
+ "use_cache": kwargs.get("use_cache"),
915
+ "attention_mask": attention_mask,
916
+ "offload_cache_to_cpu": kwargs.get("offload_cache_to_cpu"),
917
+ "use_flash": use_flash,
918
+ "cache_top_k": kwargs.get("cache_top_k"),
919
+ "max_chunk_length": kwargs.get("max_chunk_length", 0),
920
+ "mem_freq": mem_freq,
921
+ "drop_last_logit_if_mem": True,
922
+ }
923
+
924
+ def forward(
925
+ self,
926
+ input_ids: Optional[torch.LongTensor] = None,
927
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
928
+ attention_mask: Optional[torch.Tensor] = None,
929
+ head_mask: Optional[torch.Tensor] = None,
930
+ inputs_embeds: Optional[torch.Tensor] = None,
931
+ labels: Optional[torch.Tensor] = None,
932
+ use_cache: Optional[bool] = None,
933
+ output_attentions: Optional[bool] = None,
934
+ output_hidden_states: Optional[bool] = None,
935
+ return_dict: Optional[bool] = None,
936
+ offload_cache_to_cpu: Optional[bool] = None,
937
+ use_flash: Optional[bool] = None,
938
+ cache_top_k: Optional[int] = None,
939
+ max_chunk_length: Optional[int] = 0,
940
+ mem_freq: Optional[int] = None,
941
+ drop_last_logit_if_mem: Optional[bool] = False,
942
+ **deprecated_arguments,
943
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
944
+ r"""
945
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
946
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
947
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
948
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
949
+ """
950
+ if deprecated_arguments.pop("position_ids", False) is not False:
951
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
952
+ warnings.warn(
953
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
954
+ " passing `position_ids`.",
955
+ FutureWarning,
956
+ )
957
+ if len(deprecated_arguments) > 0:
958
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
959
+
960
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
961
+
962
+ use_flash = use_flash if use_flash is not None else self.always_use_flash
963
+
964
+ if self.auto_insert_landmarks:
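+ # Automatic landmark insertion: reshape the input into blocks of `mem_freq` tokens
+ # and append a landmark (`mem_id`) token to every block, mirroring the change in
+ # the attention mask.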
965
+ mem_freq = self.config.mem_freq
966
+ assert self.config.mem_freq is not None
967
+ block_size = self.config.mem_freq + 1
968
+ input_ids = input_ids.view(input_ids.shape[0], -1, block_size - 1)
969
+ input_ids = torch.cat((input_ids, input_ids.new_full((input_ids.shape[0], input_ids.shape[1], 1), self.config.mem_id)), dim=-1)
970
+ input_ids = input_ids.view(input_ids.shape[0], -1)
971
+ if attention_mask is not None:
972
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1, block_size - 1)
973
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((attention_mask.shape[0], attention_mask.shape[1], 1))), dim=-1)
974
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1)
975
+
976
+ if max_chunk_length == 0:
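+ # Default chunk length: the trained context length rounded down to a multiple of
+ # (mem_freq + 1), minus (cache_top_k + 1) blocks reserved for retrieved landmark blocks.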
977
+ if cache_top_k is not None:
978
+ max_chunk_length = self.config.train_context_length - self.config.train_context_length % (mem_freq + 1) - (cache_top_k + 1) * (mem_freq + 1)
979
+ if max_chunk_length <= 0:
980
+ raise ValueError("K is too large for this model.")
981
+ else:
982
+ max_chunk_length = None
983
+
984
+ window_len = max_chunk_length or input_ids.shape[1]
985
+ if use_flash:
986
+ assert window_len % (mem_freq + 1) == 0
987
+ last_logits = None
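+ # Process the input in windows of `window_len` tokens, carrying `past_key_values`
+ # across windows so earlier landmark blocks remain available to later queries.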
988
+ for step, idx in enumerate(range(0, input_ids.shape[1], window_len)):
989
+ if idx >= 1:
990
+ if output_attentions or output_hidden_states:
991
+ raise NotImplementedError
992
+ if not use_cache:
993
+ raise NotImplementedError
994
+
995
+ outputs = self.transformer(
996
+ input_ids[:, idx:idx + window_len],
997
+ past_key_values=past_key_values,
998
+ attention_mask=attention_mask[:, :idx + window_len + attention_mask.shape[1] - input_ids.shape[1]] if attention_mask is not None else None,
999
+ head_mask=head_mask, ## ??
1000
+ inputs_embeds=inputs_embeds[:, idx:idx + window_len] if inputs_embeds is not None else None,
1001
+ use_cache=use_cache,
1002
+ output_attentions=output_attentions,
1003
+ output_hidden_states=output_hidden_states,
1004
+ return_dict=return_dict,
1005
+ offload_cache_to_cpu=offload_cache_to_cpu,
1006
+ use_flash=(use_flash or self.auto_insert_landmarks),
1007
+ cache_top_k=cache_top_k,
1008
+ mem_freq=mem_freq,
1009
+ )
1010
+ past_key_values = outputs[1]
1011
+ if last_logits is not None:
1012
+ last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
1013
+ last_logits = outputs[0]
1014
+
1015
+ hidden_states = last_logits
1016
+ if self.auto_insert_landmarks:
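+ # Strip the hidden states at the auto-inserted landmark positions so the logits
+ # line up with the original (landmark-free) input sequence again.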
1017
+ block_size = self.config.mem_freq + 1
1018
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], hidden_states.shape[1] // block_size, block_size, hidden_states.shape[2])
1019
+ hidden_states = hidden_states[:, :, :block_size - 1]
1020
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[3])
1021
+ if drop_last_logit_if_mem:
1022
+ is_any_mem = (input_ids[:, -1] == self.config.mem_id).any()
1023
+ are_all_mem = (input_ids[:, -1] == self.config.mem_id).all()
1024
+ assert is_any_mem == are_all_mem
1025
+ if is_any_mem:
1026
+ hidden_states = hidden_states[:, :-1]
1027
+
1028
+ lm_logits = self.lm_head(hidden_states)
1029
+
1030
+ loss = None
1031
+ if labels is not None:
1032
+ # Shift so that tokens < n predict n
1033
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1034
+ shift_labels = labels[..., 1:].contiguous()
1035
+ batch_size, seq_length, vocab_size = shift_logits.shape
1036
+ # Flatten the tokens
1037
+ loss_fct = CrossEntropyLoss()
1038
+ loss = loss_fct(
1039
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
1040
+ )
1041
+
1042
+ if not return_dict:
1043
+ output = (lm_logits,) + outputs[1:]
1044
+ return ((loss,) + output) if loss is not None else output
1045
+
1046
+ return CausalLMOutputWithCrossAttentions(
1047
+ loss=loss,
1048
+ logits=lm_logits,
1049
+ past_key_values=outputs.past_key_values,
1050
+ hidden_states=outputs.hidden_states,
1051
+ attentions=outputs.attentions,
1052
+ )
1053
+
1054
+ def _reorder_cache(
1055
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1056
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1057
+ """
1058
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1059
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1060
+ beam_idx at every generation step.
1061
+
1062
+ Output shares the same memory storage as `past`.
1063
+ """
1064
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
1065
+
1066
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
1067
+ device_to_beam_idx = {
1068
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
1069
+ }
1070
+ reordered_past = tuple(
1071
+ (
1072
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
1073
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
1074
+ )
1075
+ for layer_past in standardized_past
1076
+ )
1077
+ return self._convert_to_rw_cache(reordered_past)
1078
+
1079
+
1080
+ class RWForSequenceClassification(RWPreTrainedModel):
1081
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1082
+
1083
+ def __init__(self, config: RWConfig):
1084
+ super().__init__(config)
1085
+ self.num_labels = config.num_labels
1086
+ self.transformer = RWModel(config)
1087
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
1088
+
1089
+ # Initialize weights and apply final processing
1090
+ self.post_init()
1091
+
1092
+ def forward(
1093
+ self,
1094
+ input_ids: Optional[torch.LongTensor] = None,
1095
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1096
+ attention_mask: Optional[torch.Tensor] = None,
1097
+ head_mask: Optional[torch.Tensor] = None,
1098
+ inputs_embeds: Optional[torch.Tensor] = None,
1099
+ labels: Optional[torch.Tensor] = None,
1100
+ use_cache: Optional[bool] = None,
1101
+ output_attentions: Optional[bool] = None,
1102
+ output_hidden_states: Optional[bool] = None,
1103
+ return_dict: Optional[bool] = None,
1104
+ **deprecated_arguments,
1105
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1106
+ r"""
1107
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1108
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1109
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1110
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1111
+ """
1112
+ if deprecated_arguments.pop("position_ids", False) is not False:
1113
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
1114
+ warnings.warn(
1115
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1116
+ " passing `position_ids`.",
1117
+ FutureWarning,
1118
+ )
1119
+ if len(deprecated_arguments) > 0:
1120
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1121
+
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ transformer_outputs = self.transformer(
1125
+ input_ids,
1126
+ past_key_values=past_key_values,
1127
+ attention_mask=attention_mask,
1128
+ head_mask=head_mask,
1129
+ inputs_embeds=inputs_embeds,
1130
+ use_cache=use_cache,
1131
+ output_attentions=output_attentions,
1132
+ output_hidden_states=output_hidden_states,
1133
+ return_dict=return_dict,
1134
+ )
1135
+
1136
+ hidden_states = transformer_outputs[0]
1137
+ logits = self.score(hidden_states)
1138
+
1139
+ if input_ids is not None:
1140
+ batch_size = input_ids.shape[0]
1141
+ else:
1142
+ batch_size = inputs_embeds.shape[0]
1143
+
1144
+ if self.config.pad_token_id is None and batch_size != 1:
1145
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1146
+ if self.config.pad_token_id is None:
1147
+ sequence_lengths = -1
1148
+ else:
1149
+ if input_ids is not None:
1150
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1151
+ else:
1152
+ sequence_lengths = -1
1153
+ logger.warning(
1154
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1155
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1156
+ )
1157
+
1158
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1159
+
1160
+ loss = None
1161
+ if labels is not None:
1162
+ if self.config.problem_type is None:
1163
+ if self.num_labels == 1:
1164
+ self.config.problem_type = "regression"
1165
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1166
+ self.config.problem_type = "single_label_classification"
1167
+ else:
1168
+ self.config.problem_type = "multi_label_classification"
1169
+
1170
+ if self.config.problem_type == "regression":
1171
+ loss_fct = MSELoss()
1172
+ if self.num_labels == 1:
1173
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1174
+ else:
1175
+ loss = loss_fct(pooled_logits, labels)
1176
+ elif self.config.problem_type == "single_label_classification":
1177
+ loss_fct = CrossEntropyLoss()
1178
+ loss = loss_fct(pooled_logits, labels)
1179
+ elif self.config.problem_type == "multi_label_classification":
1180
+ loss_fct = BCEWithLogitsLoss()
1181
+ loss = loss_fct(pooled_logits, labels)
1182
+ if not return_dict:
1183
+ output = (pooled_logits,) + transformer_outputs[1:]
1184
+ return ((loss,) + output) if loss is not None else output
1185
+
1186
+ return SequenceClassifierOutputWithPast(
1187
+ loss=loss,
1188
+ logits=pooled_logits,
1189
+ past_key_values=transformer_outputs.past_key_values,
1190
+ hidden_states=transformer_outputs.hidden_states,
1191
+ attentions=transformer_outputs.attentions,
1192
+ )
1193
+
1194
+
1195
+ class RWForTokenClassification(RWPreTrainedModel):
1196
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1197
+
1198
+ def __init__(self, config: RWConfig):
1199
+ super().__init__(config)
1200
+ self.num_labels = config.num_labels
1201
+
1202
+ self.transformer = RWModel(config)
1203
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1204
+ classifier_dropout = config.classifier_dropout
1205
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1206
+ classifier_dropout = config.hidden_dropout
1207
+ else:
1208
+ classifier_dropout = 0.1
1209
+ self.dropout = nn.Dropout(classifier_dropout)
1210
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1211
+
1212
+ # Initialize weights and apply final processing
1213
+ self.post_init()
1214
+
1215
+ def forward(
1216
+ self,
1217
+ input_ids: Optional[torch.LongTensor] = None,
1218
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1219
+ attention_mask: Optional[torch.Tensor] = None,
1220
+ head_mask: Optional[torch.Tensor] = None,
1221
+ inputs_embeds: Optional[torch.Tensor] = None,
1222
+ labels: Optional[torch.Tensor] = None,
1223
+ use_cache: Optional[bool] = None,
1224
+ output_attentions: Optional[bool] = None,
1225
+ output_hidden_states: Optional[bool] = None,
1226
+ return_dict: Optional[bool] = None,
1227
+ **deprecated_arguments,
1228
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1229
+ r"""
1230
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1234
+ """
1235
+ if deprecated_arguments.pop("position_ids", False) is not False:
1236
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
1237
+ warnings.warn(
1238
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1239
+ " passing `position_ids`.",
1240
+ FutureWarning,
1241
+ )
1242
+ if len(deprecated_arguments) > 0:
1243
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1244
+
1245
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1246
+
1247
+ transformer_outputs = self.transformer(
1248
+ input_ids,
1249
+ past_key_values=past_key_values,
1250
+ attention_mask=attention_mask,
1251
+ head_mask=head_mask,
1252
+ inputs_embeds=inputs_embeds,
1253
+ use_cache=use_cache,
1254
+ output_attentions=output_attentions,
1255
+ output_hidden_states=output_hidden_states,
1256
+ return_dict=return_dict,
1257
+ )
1258
+
1259
+ hidden_states = transformer_outputs[0]
1260
+ hidden_states = self.dropout(hidden_states)
1261
+ logits = self.classifier(hidden_states)
1262
+
1263
+ loss = None
1264
+ if labels is not None:
1265
+ batch_size, seq_length = labels.shape
1266
+ loss_fct = CrossEntropyLoss()
1267
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
1268
+
1269
+ if not return_dict:
1270
+ output = (logits,) + transformer_outputs[2:]
1271
+ return ((loss,) + output) if loss is not None else output
1272
+
1273
+ return TokenClassifierOutput(
1274
+ loss=loss,
1275
+ logits=logits,
1276
+ hidden_states=transformer_outputs.hidden_states,
1277
+ attentions=transformer_outputs.attentions,
1278
+ )
1279
+
1280
+
1281
+ class RWForQuestionAnswering(RWPreTrainedModel):
1282
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1283
+
1284
+ def __init__(self, config):
1285
+ super().__init__(config)
1286
+ self.transformer = RWModel(config)
1287
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1288
+
1289
+ # Initialize weights and apply final processing
1290
+ self.post_init()
1291
+
1292
+ def forward(
1293
+ self,
1294
+ input_ids: Optional[torch.LongTensor] = None,
1295
+ attention_mask: Optional[torch.FloatTensor] = None,
1296
+ position_ids: Optional[torch.LongTensor] = None,
1297
+ head_mask: Optional[torch.FloatTensor] = None,
1298
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1299
+ start_positions: Optional[torch.LongTensor] = None,
1300
+ end_positions: Optional[torch.LongTensor] = None,
1301
+ output_attentions: Optional[bool] = None,
1302
+ output_hidden_states: Optional[bool] = None,
1303
+ return_dict: Optional[bool] = None,
1304
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1305
+ r"""
1306
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1307
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1308
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1309
+ are not taken into account for computing the loss.
1310
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1311
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1312
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1313
+ are not taken into account for computing the loss.
1314
+ """
1315
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1316
+
1317
+ outputs = self.transformer(
1318
+ input_ids,
1319
+ attention_mask=attention_mask,
1320
+ position_ids=position_ids,
1321
+ head_mask=head_mask,
1322
+ inputs_embeds=inputs_embeds,
1323
+ output_attentions=output_attentions,
1324
+ output_hidden_states=output_hidden_states,
1325
+ return_dict=return_dict,
1326
+ )
1327
+
1328
+ sequence_output = outputs[0]
1329
+
1330
+ logits = self.qa_outputs(sequence_output)
1331
+ start_logits, end_logits = logits.split(1, dim=-1)
1332
+ start_logits = start_logits.squeeze(-1).contiguous()
1333
+ end_logits = end_logits.squeeze(-1).contiguous()
1334
+
1335
+ total_loss = None
1336
+ if start_positions is not None and end_positions is not None:
1337
+ # If we are on multi-GPU, split add a dimension
1338
+ if len(start_positions.size()) > 1:
1339
+ start_positions = start_positions.squeeze(-1)
1340
+ if len(end_positions.size()) > 1:
1341
+ end_positions = end_positions.squeeze(-1)
1342
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1343
+ ignored_index = start_logits.size(1)
1344
+ start_positions = start_positions.clamp(0, ignored_index)
1345
+ end_positions = end_positions.clamp(0, ignored_index)
1346
+
1347
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1348
+ start_loss = loss_fct(start_logits, start_positions)
1349
+ end_loss = loss_fct(end_logits, end_positions)
1350
+ total_loss = (start_loss + end_loss) / 2
1351
+
1352
+ if not return_dict:
1353
+ output = (start_logits, end_logits) + outputs[2:]
1354
+ return ((total_loss,) + output) if total_loss is not None else output
1355
+
1356
+ return QuestionAnsweringModelOutput(
1357
+ loss=total_loss,
1358
+ start_logits=start_logits,
1359
+ end_logits=end_logits,
1360
+ hidden_states=outputs.hidden_states,
1361
+ attentions=outputs.attentions,
1362
+ )
code/redpajama.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright 2023 Together Computer
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """RedPajama: An Open-Source, Clean-Room 1.2 Trillion Token Dataset."""
17
+
18
+
19
+ import json
20
+
21
+ import datasets
22
+ import traceback
23
+ import numpy as np
24
+ import math
25
+
26
+ logger = datasets.logging.get_logger(__name__)
27
+
28
+
29
+ _DESCRIPTION = """\
30
+ RedPajama is a clean-room, fully open-source implementation of the LLaMa dataset.
31
+ """
32
+
33
+ _URL_LISTS = {
34
+ "arxiv": "urls/arxiv.txt",
35
+ "book": "urls/book.txt",
36
+ "c4": "urls/c4.txt",
37
+ "common_crawl": "urls/common_crawl.txt",
38
+ "github": "urls/github.txt",
39
+ "stackexchange": "urls/stackexchange.txt",
40
+ "wikipedia": "urls/wikipedia.txt",
41
+ }
42
+
43
+
44
+ class RedPajama1TConfig(datasets.BuilderConfig):
45
+ """BuilderConfig for RedPajama sample."""
46
+
47
+ def __init__(self, *args, subsets, p_sample=None, **kwargs):
48
+ """BuilderConfig for RedPajama.
49
+ Args:
50
+ subsets: list of RedPajama subset names to load (keys of `_URL_LISTS`).
+ p_sample: optional fraction of each subset's URL list to sample.
+ **kwargs: keyword arguments forwarded to super.
51
+ """
52
+ super(RedPajama1TConfig, self).__init__(**kwargs)
53
+
54
+ self.subsets = subsets
55
+ self.p_sample = p_sample
56
+
57
+
58
+ class RedPajama1T(datasets.GeneratorBasedBuilder):
59
+ """RedPajama: Reproducing the LLaMA training dataset of over 1.2 trillion tokens. Version 1.0.0."""
60
+ BUILDER_CONFIG_CLASS = RedPajama1TConfig
61
+ BUILDER_CONFIGS = [
62
+ RedPajama1TConfig(
63
+ subsets = list(_URL_LISTS.keys()),
64
+ name="plain_text",
65
+ version=datasets.Version("1.0.0", ""),
66
+ description="Plain text",
67
+ ),
68
+ RedPajama1TConfig(
69
+ subsets = list(_URL_LISTS.keys()),
70
+ name="plain_text_tenpercent",
71
+ version=datasets.Version("1.0.0", ""),
72
+ description="Plain text",
73
+ p_sample=0.1
74
+ ),
75
+ ]
76
+
77
+ def _info(self):
78
+ return datasets.DatasetInfo(
79
+ description=_DESCRIPTION,
80
+ features=datasets.Features(
81
+ {
82
+ "text": datasets.Value("string"),
83
+ "meta": datasets.Value("string"),
84
+ "red_pajama_subset": datasets.Value("string"),
85
+ }
86
+ ),
87
+ supervised_keys=None,
88
+ )
89
+
90
+ def _split_generators(self, dl_manager):
91
+ url_lists = dl_manager.download_and_extract({
92
+ subset: _URL_LISTS[subset] for subset in self.config.subsets
93
+ })
94
+
95
+ urls = {}
96
+ rng = np.random.default_rng(seed=2)
97
+
98
+ for subset, url_list in url_lists.items():
99
+ with open(url_list, encoding="utf-8") as f:
100
+ urls[subset] = [line.strip() for line in f]
101
+ if self.config.p_sample is not None:
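+ # Optionally subsample a fraction `p_sample` of each subset's URL list, using a
+ # fixed seed so the sampled shards are reproducible across runs.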
102
+ urls[subset] = rng.choice(
103
+ urls[subset],
104
+ size=int(math.ceil(len(urls[subset]) * self.config.p_sample)), replace=False).tolist()
105
+
106
+ downloaded_files = dl_manager.download(urls)
107
+
108
+ return [
109
+ datasets.SplitGenerator(
110
+ name=datasets.Split.TRAIN,
111
+ gen_kwargs = {
112
+ "files": {
113
+ subset: downloaded_files[subset]
114
+ for subset in self.config.subsets
115
+ }
116
+ }
117
+ )
118
+ ]
119
+
120
+ def _generate_examples(self, files):
121
+ """This function returns the examples in the raw (text) form."""
122
+ key = 0
123
+ for subset in files:
124
+ if subset == "common_crawl":
125
+ import zstandard as zstd
126
+
127
+ for path in files[subset]:
128
+ with zstd.open(open(path, "rb"), "rt", encoding="utf-8") as f:
129
+ for i, row in enumerate(f):
130
+ try:
131
+ data = json.loads(row)
132
+ text = data["text"]
133
+ del data["text"]
134
+ yield key, {
135
+ "text": text,
136
+ "meta": json.dumps(data),
137
+ "red_pajama_subset": subset,
138
+ }
139
+ key += 1
140
+ except Exception as e:
141
+ print(f'Subset: {subset}')
142
+ print(f'Path: {path}')
143
+ print(f'Row: {row}')
144
+ traceback.print_exc()
145
+
146
+ raise e
147
+ else:
148
+ for path in files[subset]:
149
+ with open(path, encoding="utf-8") as f:
150
+ for i, row in enumerate(f):
151
+ try:
152
+ data = json.loads(row)
153
+ if "meta" not in data:
154
+ text = data["text"]
155
+ del data["text"]
156
+ yield key, {
157
+ "text": text,
158
+ "meta": json.dumps(data),
159
+ "red_pajama_subset": subset,
160
+ }
161
+ else:
162
+ yield key, {
163
+ "text": data["text"],
164
+ "meta": data["meta"],
165
+ "red_pajama_subset": subset,
166
+ }
167
+ key += 1
168
+ except Exception as e:
169
+ print(f'Subset: {subset}')
170
+ print(f'Path: {path}')
171
+ print(f'Row: {row}')
172
+ traceback.print_exc()
173
+
174
+ raise e
code/run_test.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ import os
18
+ import random
19
+ import re
20
+ import requests
21
+
22
+
23
+ llama_weights_7b_base = "/llama_weights/7B_hf/"
24
+ llama_weights_7b_tuned = "/llama-redpajama-mem-15000-with-mem/"
25
+ cache_path = "/hf-cache/"
26
+ use_flash = False  # using flash for inference is only implemented when offloading the kv cache to cpu
27
+ top_k = 5
28
+ dtype = torch.bfloat16
29
+
30
+ def make_llama_base_pipe():
31
+
32
+ from transformers import pipeline
33
+
34
+ from transformers.models.llama import LlamaForCausalLM
35
+
36
+ llama_base = LlamaForCausalLM.from_pretrained(
37
+ llama_weights_7b_base,
38
+ cache_dir=cache_path,
39
+ )
40
+
41
+ llama_base = llama_base.to('cuda:0')
42
+
43
+ import transformers
44
+
45
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
46
+ llama_weights_7b_base,
47
+ cache_dir=cache_path,
48
+ model_max_length=2048,
49
+ padding_side="right",
50
+ use_fast=False,
51
+ )
52
+
53
+ llama_base_pipe = pipeline("text-generation", model=llama_base, tokenizer=tokenizer, device=llama_base.device)
54
+ return llama_base_pipe
55
+
56
+
57
+
58
+ llama_base_pipe = make_llama_base_pipe()
59
+
60
+ def make_llama_mem_pipe():
61
+ from llama_mem import LlamaForCausalLM
62
+
63
+ model = LlamaForCausalLM.from_pretrained(
64
+ llama_weights_7b_tuned,
65
+ cache_dir=cache_path,
66
+ torch_dtype=dtype
67
+ )
68
+
69
+ model.to('cuda:1')
70
+
71
+ import transformers
72
+
73
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
74
+ llama_weights_7b_tuned,
75
+ cache_dir=cache_path,
76
+ model_max_length=model.config.train_context_length,
77
+ padding_side="right",
78
+ use_fast=False,
79
+ )
80
+ mem_id = tokenizer.convert_tokens_to_ids("<landmark>")
81
+ model.set_mem_id(mem_id)
82
+ from transformers import pipeline
83
+ llama_mem_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device,
84
+ offload_cache_to_cpu=use_flash, use_flash=use_flash,
85
+ cache_top_k=top_k)
86
+ return llama_mem_pipe
87
+
88
+
89
+ llama_mem_pipe = make_llama_mem_pipe()
90
+
91
+
92
+
93
+ pipes = {"base": llama_base_pipe, "mem": llama_mem_pipe}
94
+
95
+
96
+ def generate_prompt(n_garbage):
97
+ """Generates a text file and inserts an execute line at a random position."""
98
+ n_garbage_prefix = random.randint(0, n_garbage)
99
+ n_garbage_suffix = n_garbage - n_garbage_prefix
100
+
101
+ task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
102
+ garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
103
+ garbage_inf = " ".join([garbage] * 2000)
104
+ assert len(garbage_inf) >= n_garbage
105
+ garbage_prefix = garbage_inf[:n_garbage_prefix]
106
+ garbage_suffix = garbage_inf[:n_garbage_suffix]
107
+ pass_key = random.randint(1, 50000)
108
+ information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
109
+ final_question = "What is the pass key? The pass key is"
110
+ lines = [
111
+ task_description,
112
+ garbage_prefix,
113
+ information_line,
114
+ garbage_suffix,
115
+ final_question
116
+ ]
117
+ return "\n".join(lines), pass_key
118
+
119
+
120
+
121
+ def test_model(prompt_text, pass_key, model_name):
122
+ response = pipes[model_name](prompt_text, num_return_sequences=1, max_new_tokens=10)[0]["generated_text"][len(prompt_text):]
123
+ assert f"The pass key is {pass_key}" in prompt_text
124
+
125
+ try:
126
+ pass_key = int(re.search(r'\d+', response).group())
127
+ except Exception:  # fall back to the raw response if no integer can be parsed
128
+ pass_key = response[:20]
129
+
130
+ return pass_key
131
+
132
+
133
+ n_values = [0, 100, 500, 1000, 5000, 8000, 10000, 12000, 14000, 18000, 20000, 25000, 38000]
134
+ num_tests = 50
135
+ models = ["base", "mem"]
136
+ accuracies = {x: [] for x in models}
137
+ individual_results = {x: [] for x in models}
138
+
139
+ for n in n_values:
140
+
141
+ correct_count = {x: 0 for x in models}
142
+
143
+ n_results = {x: [] for x in models}
144
+ for i in range(num_tests):
145
+ print(f"\nRunning test {i + 1}/{num_tests} for n = {n}...")
146
+ prompt_text, pass_key = generate_prompt(n)
147
+
148
+
149
+
150
+ for model_name in models:
151
+ if pipes[model_name] is None:
152
+ continue
153
+ num_tokens = len(pipes[model_name].tokenizer.encode(prompt_text))
154
+
155
+ print("Number of tokens in this prompt: ", num_tokens)
156
+ model_output = test_model(prompt_text, pass_key, model_name)
157
+ print(f"Expected number in the prompt: {pass_key}, {model_name} output: {model_output}")
158
+
159
+ if pass_key == model_output:
160
+ correct_count[model_name] += 1
161
+ n_results[model_name].append(1)
162
+ print("Success!")
163
+ else:
164
+ n_results[model_name].append(0)
165
+ print("Fail.")
166
+
167
+ for model in models:
168
+ accuracy = (correct_count[model] / num_tests) * 100
169
+ print(f"Accuracy {model} for n = {n}: {accuracy}%")
170
+ accuracies[model].append(accuracy)
171
+ individual_results[model].append(n_results[model])  # store this model's per-trial results for n
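The test loop above only prints per-n accuracies. If the numbers are needed later (e.g. for plotting), a small optional addition at the end of run_test.py could dump them to disk; the file name below is arbitrary:

    import json

    # sketch: persist the pass-key retrieval results collected above
    with open("passkey_results.json", "w", encoding="utf-8") as f:
        json.dump(
            {"n_values": n_values, "accuracies": accuracies, "individual_results": individual_results},
            f,
            indent=2,
        )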
code/run_train_1x.sh ADDED
@@ -0,0 +1,23 @@
1
+ #!/bin/bash
2
+ CUDA_VISIBLE_DEVICES=0 python train.py \
3
+ --model_name_or_path tiiuae/falcon-7b \
4
+ --bf16 True \
5
+ --output_dir ./out_dir/ \
6
+ --cache_dir ./hf-cache/ \
7
+ --num_train_epochs 1 \
8
+ --per_device_train_batch_size 1 \
9
+ --per_device_eval_batch_size 1 \
10
+ --gradient_accumulation_steps 1 \
11
+ --evaluation_strategy "no" \
12
+ --save_strategy "steps" \
13
+ --save_steps 2000 \
14
+ --save_total_limit 2 \
15
+ --learning_rate 2e-5 \
16
+ --weight_decay 0.1 \
17
+ --warmup_ratio 0.03 \
18
+ --lr_scheduler_type "cosine" \
19
+ --logging_steps 1 \
20
+ --tf32 True \
21
+ --max_steps 15000 \
22
+ --model_max_length 1024 \
23
+ --mem_freq 31
code/run_train_8x.sh ADDED
@@ -0,0 +1,25 @@
1
+ #!/bin/bash
2
+ torchrun --nproc_per_node=8 train.py \
3
+ --model_name_or_path tiiuae/falcon-7b \
4
+ --bf16 True \
5
+ --output_dir ./out_dir/ \
6
+ --cache_dir ./hf-cache/ \
7
+ --num_train_epochs 1 \
8
+ --per_device_train_batch_size 2 \
9
+ --per_device_eval_batch_size 2 \
10
+ --gradient_accumulation_steps 8 \
11
+ --evaluation_strategy "no" \
12
+ --save_strategy "steps" \
13
+ --save_steps 2000 \
14
+ --save_total_limit 2 \
15
+ --learning_rate 2e-5 \
16
+ --weight_decay 0.1 \
17
+ --warmup_ratio 0.03 \
18
+ --lr_scheduler_type "cosine" \
19
+ --logging_steps 1 \
20
+ --tf32 True \
21
+ --max_steps 15000 \
22
+ --model_max_length 2048 \
23
+ --mem_freq 31 \
24
+ --fsdp "full_shard auto_wrap" \
25
+ --fsdp_transformer_layer_cls_to_wrap 'DecoderLayer'
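For comparison, the effective batch size (sequences per optimizer step) implied by the two launch scripts differs substantially; a quick check of the arithmetic:

    # run_train_8x.sh: 2 per device * 8 grad-accum steps * 8 GPUs
    print(2 * 8 * 8)  # 128 sequences per optimizer step
    # run_train_1x.sh: 1 per device * 1 grad-accum step * 1 GPU
    print(1 * 1 * 1)  # 1 sequence per optimizer step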
code/train.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #import copy
16
+ #import logging
17
+ from dataclasses import dataclass, field
18
+ from functools import partial
19
+ from typing import Dict, Optional, Sequence
20
+
21
+
22
+ import torch
23
+ import transformers
24
+ #from torch.utils.data import Dataset
25
+ from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup
26
+
27
+ from modelling_RW import RWForCausalLM
28
+ #from transformers import AutoModelForCausalLM
29
+
30
+
31
+ from torch.distributed import barrier
32
+ import os
33
+
34
+
35
+ from datasets import load_dataset
36
+
37
+ IGNORE_INDEX = -100
38
+ DEFAULT_PAD_TOKEN = "[PAD]"
39
+ DEFAULT_EOS_TOKEN = "</s>"
40
+ DEFAULT_BOS_TOKEN = "<s>"
41
+ DEFAULT_UNK_TOKEN = "<unk>"
42
+
43
+
44
+ @dataclass
45
+ class ModelArguments:
46
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
47
+
48
+
49
+ @dataclass
50
+ class TrainingArguments(transformers.TrainingArguments):
51
+ cache_dir: Optional[str] = field(default=None)
52
+ #optim: str = field(default="adamw_hf")
53
+ optim: str = field(default="adamw_torch")
54
+ model_max_length: int = field(
55
+ default=128,
56
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
57
+ )
58
+ use_flash: bool = field(default=False)
59
+ mem_freq: int = field(default=63)
60
+ #report_to: str = "none" # disable logging
61
+
62
+
63
+ class TrainerCosine(Trainer):
64
+ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
65
+ """
66
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
67
+ passed as an argument.
68
+
69
+ Args:
70
+ num_training_steps (int): The number of training steps to do.
71
+ """
72
+ if self.args.lr_scheduler_type != "cosine":
73
+ return super().create_scheduler(num_training_steps, optimizer)
74
+ if self.lr_scheduler is None:
75
+ self.lr_scheduler = get_cosine_schedule_with_warmup(
76
+ optimizer=self.optimizer if optimizer is None else optimizer,
77
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
78
+ num_training_steps=num_training_steps,
79
+ num_cycles=0.4 # ~10% of the init lr
80
+ )
81
+ return self.lr_scheduler
82
+
83
+
84
+ def smart_tokenizer_and_embedding_resize(
85
+ special_tokens_dict: Dict,
86
+ tokenizer: transformers.PreTrainedTokenizer,
87
+ model: transformers.PreTrainedModel,
88
+ ):
89
+ """Resize tokenizer and embedding.
90
+
91
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
92
+ """
93
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
94
+ model.resize_token_embeddings(len(tokenizer))
95
+
96
+ if num_new_tokens > 0:
97
+ input_embeddings = model.get_input_embeddings().weight.data
98
+ output_embeddings = model.get_output_embeddings().weight.data
99
+
100
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
101
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
102
+
103
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
104
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
105
+
106
+ def tokenize_fn(tokenizer, example):
107
+ context_length = tokenizer.model_max_length
108
+ outputs = tokenizer(
109
+ tokenizer.eos_token.join(example["text"]),
110
+ truncation=False,
111
+ return_tensors="pt",
112
+ pad_to_multiple_of=context_length,
113
+ padding=True,
114
+ )
115
+ return {"input_ids": outputs["input_ids"].view(-1, context_length)}
116
+
117
+ def train():
118
+ parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
119
+ model_args, training_args = parser.parse_args_into_dataclasses()
120
+
121
+ # ensure max length leaves room for landmark tokens
122
+ model_max_length = training_args.model_max_length - (training_args.model_max_length // training_args.mem_freq)
123
+ model_max_length = model_max_length // training_args.mem_freq * training_args.mem_freq
124
+
125
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
126
+ model_args.model_name_or_path,
127
+ cache_dir=training_args.cache_dir,
128
+ model_max_length=model_max_length,
129
+ padding_side="right",
130
+ use_fast=False,
131
+ )
132
+ special_tokens_dict = dict()
133
+ if tokenizer.pad_token is None:
134
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
135
+ if tokenizer.eos_token is None:
136
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
137
+ if tokenizer.bos_token is None:
138
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
139
+ if tokenizer.unk_token is None:
140
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
141
+ mem_token = "<landmark>"
142
+ special_tokens_dict["additional_special_tokens"] = [mem_token]
143
+
144
+ model = RWForCausalLM.from_pretrained(
145
+ model_args.model_name_or_path,
146
+ cache_dir=training_args.cache_dir,
147
+ mem_freq=training_args.mem_freq,
148
+ torch_dtype=torch.bfloat16,
149
+ )
150
+ # model = AutoModelForCausalLM.from_pretrained(
151
+ # model_args.model_name_or_path,
152
+ # cache_dir=training_args.cache_dir,
153
+ # torch_dtype=torch.bfloat16,
154
+ # trust_remote_code=True,
155
+ # )
156
+
157
+ smart_tokenizer_and_embedding_resize(
158
+ special_tokens_dict=special_tokens_dict,
159
+ tokenizer=tokenizer,
160
+ model=model,
161
+ )
162
+
163
+ mem_id = tokenizer.convert_tokens_to_ids(mem_token)
164
+ model.set_mem_id(mem_id)
165
+ print(f"Landmark token: {mem_token}: {mem_id}")
166
+
167
+ rank = int(os.environ.get('RANK', -1))
168
+ if rank > 0:
169
+ barrier()
170
+ #dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir, split='train[:100]')
171
+ dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir, split='train')
172
+
173
+ dataset = dataset.map(partial(tokenize_fn, tokenizer), batched=True, num_proc=32, remove_columns=["text", "meta"])
174
+
175
+ model.enable_landmark_insertion()
176
+ model.enable_flash()
177
+
178
+ # if training_args.use_flash:
179
+ # model.enable_landmark_insertion()
180
+ # model.enable_flash()
181
+ # else:
182
+ # dataset = dataset.map(
183
+ # partial(
184
+ # add_mem_tokens,
185
+ # mem_freq=training_args.mem_freq,
186
+ # mem_id=mem_id
187
+ # ), batched=False, num_proc=32)
188
+
189
+ if rank == 0:
190
+ barrier()
191
+ print(dataset)
192
+
193
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
194
+
195
+ trainer = TrainerCosine(
196
+ model=model, tokenizer=tokenizer, args=training_args,
197
+ train_dataset=dataset, #dataset["train"],
198
+ eval_dataset=None,
199
+ data_collator=data_collator)
200
+ trainer.train()
201
+ trainer.save_state()
202
+ trainer.save_model(output_dir=training_args.output_dir)
203
+
204
+
205
+ if __name__ == "__main__":
206
+ train()
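train() shrinks the tokenizer's model_max_length so that one landmark token can be inserted after every mem_freq regular tokens while the result stays a multiple of mem_freq. A worked example with the values used in the launch scripts:

    def effective_length(model_max_length: int, mem_freq: int) -> int:
        # mirror of the two lines in train(): reserve room for landmarks, then round down
        n = model_max_length - model_max_length // mem_freq
        return n // mem_freq * mem_freq

    print(effective_length(2048, 31))  # 1953 (run_train_8x.sh)
    print(effective_length(1024, 31))  # 961  (run_train_1x.sh)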
code/weight_diff.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This file has been changed by Amirkeivan Mohtashami
16
+ # to take into account the new token in the embedding layer
17
+
18
+ import os
19
+ from typing import Optional
20
+
21
+ import fire
22
+ import torch
23
+ import tqdm
24
+ import transformers
25
+ from train import smart_tokenizer_and_embedding_resize
26
+ import llama_mem
27
+
28
+ @torch.inference_mode()
29
+ def make_diff(
30
+ path_raw: str, path_tuned: str, path_diff: str, device="cpu", # "cuda" or "cpu"
31
+ ):
32
+ """Make the weight diff.
33
+
34
+ This function is given to present full transparency of how the weight diff was created.
35
+
36
+ Run:
37
+ python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
38
+ """
39
+ model_tuned: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained(
40
+ path_tuned,
41
+ device_map={"": torch.device(device)},
42
+ torch_dtype=torch.float32,
43
+ low_cpu_mem_usage=True,
44
+ )
45
+ model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
46
+ path_raw,
47
+ device_map={"": torch.device(device)},
48
+ torch_dtype=torch.float32,
49
+ low_cpu_mem_usage=True,
50
+ )
51
+
52
+ tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
53
+ path_tuned
54
+ )
55
+ tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
56
+ path_raw
57
+ )
58
+ smart_tokenizer_and_embedding_resize(
59
+ special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=["<landmark>"]),
60
+ model=model_raw,
61
+ tokenizer=tokenizer_raw,
62
+ )
63
+
64
+
65
+
66
+ state_dict_tuned = model_tuned.state_dict()
67
+ state_dict_raw = model_raw.state_dict()
68
+ with open(os.path.join(path_diff, "checksum_psum.txt"), "w") as f:
69
+ f.write(str(sum(state_dict_tuned[key].sum().item() for key in state_dict_tuned)))
70
+
71
+ for key in tqdm.tqdm(state_dict_tuned):
72
+ state_dict_tuned[key].add_(-state_dict_raw[key])
73
+
74
+ model_tuned.save_pretrained(path_diff)
75
+ tokenizer_tuned.save_pretrained(path_diff)
76
+
77
+
78
+ @torch.inference_mode()
79
+ def recover(
80
+ path_raw,
81
+ path_diff,
82
+ path_tuned: Optional[str] = None,
83
+ device="cpu",
84
+ test_inference=True,
85
+ check_integrity_naively=True,
86
+ ):
87
+ """Recover the original weights from the released weight diff.
88
+
89
+ This function is given for you to run.
90
+
91
+ Things to do before running this:
92
+ 1. Convert Meta's released weights into huggingface format. Follow this guide:
93
+ https://huggingface.co/docs/transformers/main/model_doc/llama
94
+ 2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
95
+ https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
96
+ 3. Run this function with the correct paths. E.g.,
97
+ python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir>
98
+
99
+ Additional notes:
100
+ - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
101
+ - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
102
+ Next time you can load the recovered weights directly from `<your_path_tuned>`.
103
+ """
104
+ model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
105
+ path_raw,
106
+ device_map={"": torch.device(device)},
107
+ torch_dtype=torch.float32,
108
+ low_cpu_mem_usage=True,
109
+ )
110
+ model_recovered: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained(
111
+ path_diff,
112
+ device_map={"": torch.device(device)},
113
+ torch_dtype=torch.float32,
114
+ low_cpu_mem_usage=True,
115
+ )
116
+
117
+ tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
118
+ path_raw
119
+ )
120
+ smart_tokenizer_and_embedding_resize(
121
+ special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=["<landmark>"]),
122
+ model=model_raw,
123
+ tokenizer=tokenizer_raw,
124
+ )
125
+ tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
126
+ path_diff
127
+ )
128
+
129
+ state_dict_recovered = model_recovered.state_dict()
130
+ state_dict_raw = model_raw.state_dict()
131
+ for key in tqdm.tqdm(state_dict_recovered):
132
+ state_dict_recovered[key].add_(state_dict_raw[key])
133
+
134
+ if check_integrity_naively:
135
+ # This is not a rigorous, cryptographically strong integrity check :)
136
+ allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
137
+ if os.path.exists(os.path.join(path_diff, "checksum_psum.txt")):
138
+ with open(os.path.join(path_diff, "checksum_psum.txt")) as f:
139
+ expected_sum = float(f.read())
140
+ else:
141
+ expected_sum = 49798.7656 # backward compatibility with the first released weights
142
+ assert torch.allclose(
143
+ allsum, torch.full_like(allsum, fill_value=expected_sum), atol=1e-2, rtol=0
144
+ ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
145
+
146
+ if path_tuned is not None:
147
+ model_recovered.save_pretrained(path_tuned)
148
+ tokenizer_recovered.save_pretrained(path_tuned)
149
+
150
+ return model_recovered, tokenizer_recovered
151
+
152
+
153
+ def main(task, **kwargs):
154
+ globals()[task](**kwargs)
155
+
156
+
157
+ if __name__ == "__main__":
158
+ fire.Fire(main)
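Since make_diff and recover are plain module-level functions, they can also be called from Python rather than through fire; the paths below are placeholders:

    # sketch: recover the tuned weights programmatically (all paths hypothetical)
    from weight_diff import recover

    model, tokenizer = recover(
        path_raw="/path/to/llama-7b-hf",   # converted base weights (step 1 in the docstring)
        path_diff="/path/to/weight-diff",  # downloaded weight diff (step 2)
        path_tuned="/path/to/recovered",   # optional: where to save the recovered checkpoint
    )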