yuhuili committed on
Commit
687d97d
·
1 Parent(s): 06c0ba9

Upload 10 files

Browse files
model/__init__.py ADDED
File without changes
model/choices.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Nested list of integer index paths, grouped below by path length (1 to 5).
# NOTE(review): presumably each entry is a root-to-node path of candidate
# indices in a draft tree; confirm against the code that consumes it.
mc_sim_7b_63 = [
    # length 1
    [0], [1], [2], [3],
    # length 2
    [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [2, 1], [3, 0],
    # length 3
    [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 0, 0],
    # length 4
    [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2],
    # length 5
    [0, 0, 0, 0, 0], [0, 0, 0, 0, 1],
]
model/cnets.py ADDED
@@ -0,0 +1,965 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import copy
22
+ import os
23
+ #os.environ["CUDA_VISIBLE_DEVICES"] = "5"
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+
33
+ from transformers.activations import ACT2FN
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
# Prefer package-relative imports; fall back to flat imports when this module
# is loaded from inside its own directory instead of as part of a package.
try:
    from .configs import EConfig
    from .utils_c import *
    from .choices import *
except ImportError:
    # Only an import failure should trigger the fallback. A bare `except:`
    # would also hide KeyboardInterrupt/SystemExit and genuine errors raised
    # while executing the package modules.
    from configs import EConfig
    from utils_c import *
    from choices import *
    from utils import prepare_logits_processor

# Module-level constant; its consumers are not visible in this part of the file.
top_k = 10
53
+
54
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
55
+ def _make_causal_mask(
56
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
57
+ ):
58
+ """
59
+ Make causal mask used for bi-directional self-attention.
60
+ """
61
+ bsz, tgt_len = input_ids_shape
62
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
63
+ mask_cond = torch.arange(mask.size(-1), device=device)
64
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
65
+ mask = mask.to(dtype)
66
+
67
+ if past_key_values_length > 0:
68
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
69
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
70
+
71
+
72
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
73
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
74
+ """
75
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
76
+ """
77
+ bsz, src_len = mask.size()
78
+ tgt_len = tgt_len if tgt_len is not None else src_len
79
+
80
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
81
+
82
+ inverted_mask = 1.0 - expanded_mask
83
+
84
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
85
+
86
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the
    input tensor is returned unchanged.
    """
    if n_rep == 1:
        return hidden_states

    batch, kv_heads, seq_len, dim = hidden_states.shape
    # Insert a repeat axis right after the head axis, then fold it into the heads.
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, dim)
96
+
97
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
102
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to the query and key tensors.

    `cos` and `sin` are tables whose first two dimensions are size 1; they are
    indexed by `position_ids` and broadcast over the head dimension.
    Returns the rotated `(q, k)` pair.
    """
    # Drop the two leading singleton dimensions: [1, 1, seq_len, dim] -> [seq_len, dim].
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    # Gather per-position rows and add a head axis: [bs, 1, seq_len, dim].
    cos_at_pos = cos_table[position_ids].unsqueeze(1)
    sin_at_pos = sin_table[position_ids].unsqueeze(1)

    rotated_q = (q * cos_at_pos) + (rotate_half(q) * sin_at_pos)
    rotated_k = (k * cos_at_pos) + (rotate_half(k) * sin_at_pos)
    return rotated_q, rotated_k
111
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with cached cos/sin tables.

    The tables are kept as non-persistent buffers (`cos_cached`, `sin_cached`)
    with two leading singleton dimensions, and are rebuilt whenever `forward`
    is asked for more positions than are currently cached.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the `(cos, sin)` tables for the first `seq_len` positions.

        Args:
            x: tensor of shape [bs, num_attention_heads, seq_len, head_size];
                only its dtype, device and (when `seq_len` is omitted) its
                second-to-last dimension are used.
            seq_len: number of positions required. When `None`, the sequence
                length is taken from `x`.

        Returns:
            Tuple of the cos and sin tables sliced to `seq_len` and cast to
            the dtype of `x`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises None as the default, but comparing None
            # with an int raises TypeError; derive the length from x instead.
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
145
+
146
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before the base constructor runs: it builds the cache
        # through `_set_cos_sin_cache`, which reads `scaling_factor`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with position indices divided by `scaling_factor`."""
        self.max_seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()[None, None, :, :].to(dtype)
        sin_table = table.sin()[None, None, :, :].to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
163
+
164
+
165
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Stored first so it is available when the base constructor builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, enlarging the base once `seq_len` exceeds `max_position_embeddings`."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Recompute the inverse frequencies from a base that grows with seq_len.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            exponent = self.dim / (self.dim - 2)
            scaled_base = self.base * growth ** exponent
            channel_fraction = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (scaled_base ** channel_fraction), persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()[None, None, :, :].to(dtype)
        sin_table = table.sin()[None, None, :, :].to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
189
+
190
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        # Projections are created in q, k, v, o order (same as before) so that
        # parameter registration and random initialisation order are unchanged.
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling`."""
        rope_scaling = self.config.rope_scaling
        if rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
            return

        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            rope_cls = LlamaLinearScalingRotaryEmbedding
        elif scaling_type == "dynamic":
            rope_cls = LlamaDynamicNTKScalingRotaryEmbedding
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

        self.rotary_emb = rope_cls(
            self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `tensor` to a contiguous (bsz, num_heads, seq_len, head_dim) layout."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional additive mask; must have shape
                `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: positions used to index the rotary tables.
            past_key_value: optional `(key, value)` pair that is prepended to
                the new keys/values along the sequence dimension.
            output_attentions: whether to return the attention weights.
            use_cache: whether to return the updated `(key, value)` pair.

        Returns:
            `(attn_output, attn_weights or None, present key/value or None)`.

        Raises:
            ValueError: if the attention weights, the mask or the attention
                output do not have the expected shape.
        """
        bsz, q_len, _ = hidden_states.size()
        tp = self.config.pretraining_tp

        def sliced_linear(weight, chunk):
            # Apply the projection slice by slice along the output rows and
            # concatenate the partial results.
            pieces = weight.split(chunk, dim=0)[:tp]
            return torch.cat([F.linear(hidden_states, piece) for piece in pieces], dim=-1)

        if tp > 1:
            kv_chunk = (self.num_key_value_heads * self.head_dim) // tp
            q_chunk = (self.num_heads * self.head_dim) // tp
            query_states = sliced_linear(self.q_proj.weight, q_chunk)
            key_states = sliced_linear(self.k_proj.weight, kv_chunk)
            value_states = sliced_linear(self.v_proj.weight, kv_chunk)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        has_past = past_key_value is not None
        kv_seq_len = key_states.shape[-2]
        if has_past:
            kv_seq_len += past_key_value[0].shape[-2]

        # Rotary embedding is applied to the new tokens only, before they are
        # joined with the cached keys.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if has_past:
            # reuse k, v, self_attention
            cached_keys, cached_values = past_key_value[0], past_key_value[1]
            key_states = torch.cat([cached_keys, key_states], dim=2)
            value_states = torch.cat([cached_values, value_states], dim=2)

        # The cache stores keys/values before the heads are repeated.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = scores / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if tp > 1:
            out_chunk = self.hidden_size // tp
            output_parts = attn_output.split(out_chunk, dim=2)[:tp]
            weight_parts = self.o_proj.weight.split(out_chunk, dim=1)[:tp]
            attn_output = sum([F.linear(part, weight) for part, weight in zip(output_parts, weight_parts)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
327
+
328
+
329
class LlamaMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Created in gate, up, down order (same as before) to keep parameter
        # registration and initialisation order unchanged.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to `x`; weights are processed in slices when `config.pretraining_tp > 1`."""
        tp = self.config.pretraining_tp
        if not tp > 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        chunk = self.intermediate_size // tp
        gate_weights = self.gate_proj.weight.split(chunk, dim=0)[:tp]
        up_weights = self.up_proj.weight.split(chunk, dim=0)[:tp]
        down_weights = self.down_proj.weight.split(chunk, dim=1)[:tp]

        gate_out = torch.cat([F.linear(x, weight) for weight in gate_weights], dim=-1)
        up_out = torch.cat([F.linear(x, weight) for weight in up_weights], dim=-1)

        # Split the gated activations to match the column slices of down_proj,
        # then add up the partial outputs.
        gated_parts = (self.act_fn(gate_out) * up_out).split(chunk, dim=2)[:tp]
        partial_outputs = [F.linear(part, weight) for part, weight in zip(gated_parts, down_weights)]
        return sum(partial_outputs)
361
+
362
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal root-mean-square of its last dimension, then by `weight`."""
        original_dtype = hidden_states.dtype

        # The statistics are computed in float32 regardless of the input dtype.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)

        # Cast back before applying the learned per-channel scale.
        return self.weight * normalized.to(original_dtype)
377
+
378
class LlamaDecoderLayer(nn.Module):
    """Decoder block: self-attention and MLP, each wrapped in a residual connection.

    The layer created with `index == 0` has no `input_layernorm`; its attention
    input is used without normalisation.
    """

    def __init__(self, config,index):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.index = index
        if self.index != 0:
            self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask of size
                `(batch, 1, tgt_len, src_len)`; padding elements are marked with very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When true, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When true, the present key/value states are appended to the returned tuple.

        Returns:
            A tuple starting with the output hidden states, followed by the attention weights
            (if `output_attentions`) and the present key/value states (if `use_cache`).
        """
        # --- Self Attention sub-block ---
        attn_input = hidden_states
        if self.index != 0:
            attn_input = self.input_layernorm(hidden_states)

        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_out

        # --- Fully Connected sub-block ---
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        layer_output = after_attention + mlp_out

        results = [layer_output]
        if output_attentions:
            results.append(self_attn_weights)
        if use_cache:
            results.append(present_key_value)
        return tuple(results)
443
+
444
class I(nn.Module):
    """Pass-through module that owns a single one-element float32 parameter, `dummy`."""

    def __init__(self):
        super().__init__()
        self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))

    def forward(self,x):
        # Add the parameter and subtract it again, so the parameter takes part
        # in the computation. (also tried x+self.dummy)
        shifted = x + self.dummy
        return shifted - self.dummy
450
+
451
def len_list(x,n):
    """Return the elements of `x` whose length is at most `n`, in their original order."""
    return list(filter(lambda item: len(item) <= n, x))
453
+
454
+ class Model(nn.Module):
455
+ def __init__(self,config,load_emb=False):
456
+ super().__init__()
457
+
458
+
459
+
460
+
461
+ self.gradient_checkpointing = True
462
+ self.padding_idx = config.pad_token_id
463
+ self.vocab_size = config.vocab_size
464
+
465
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
466
+ if load_emb:
467
+ from safetensors import safe_open
468
+ with safe_open("weights/llama2chat/13B/model-00001-of-00003.safetensors",
469
+ framework="pt",
470
+ device="cpu") as f:
471
+ tensor_slice = f.get_slice("model.embed_tokens.weight")
472
+ vocab_size, hidden_dim = tensor_slice.get_shape()
473
+ tensor = tensor_slice[:, :hidden_dim].float()
474
+ self.embed_tokens.weight.data = tensor
475
+
476
+ #self.init_tree()
477
+
478
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config,index) for index in range(config.num_hidden_layers)])
479
+ self.fc=nn.Linear(2*config.hidden_size,config.hidden_size)
480
+ self.act=ACT2FN[config.hidden_act]
481
+ for param in self.embed_tokens.parameters():
482
+ param.requires_grad = False
483
+
484
+
485
+ def init_tree(self):
486
+ self.tree = mc_sim_7b_63
487
+ self.tree_buffer=generate_tree_buffers(self.tree,self.embed_tokens.weight.device)
488
+
489
+
490
+ def reset(self):
491
+ self.tree_mask=None
492
+
493
+
494
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
495
+ # create causal mask
496
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
497
+ combined_attention_mask = None
498
+ if input_shape[-1] > 1:
499
+ combined_attention_mask = _make_causal_mask(
500
+ input_shape,
501
+ #inputs_embeds.dtype,
502
+ torch.float32, # [MODIFIED] force to cast to float32
503
+ device=inputs_embeds.device,
504
+ past_key_values_length=past_key_values_length,
505
+ )
506
+
507
+ if attention_mask is not None:
508
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
509
+ expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to(
510
+ inputs_embeds.device
511
+ )
512
+ combined_attention_mask = (
513
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
514
+ )
515
+
516
+ # [MODIFIED] add tree mask
517
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
518
+ tree_mask = self.tree_mask
519
+ tree_len = tree_mask.size(-1)
520
+ combined_attention_mask[:, :, -tree_len:, -tree_len:][
521
+ tree_mask == 0
522
+ ] = torch.finfo(torch.float32).min
523
+
524
+
525
+ return combined_attention_mask
526
+
527
+ def forward(
528
+ self,
529
+ hidden_states,
530
+ input_ids,
531
+ attention_mask: Optional[torch.Tensor] = None,
532
+ position_ids: Optional[torch.LongTensor] = None,
533
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
534
+ inputs_embeds: Optional[torch.FloatTensor] = None,
535
+ use_cache: Optional[bool] = None,
536
+ output_attentions: Optional[bool] = None,
537
+ output_hidden_states: Optional[bool] = None,
538
+ return_dict: Optional[bool] = None,
539
+ std=None
540
+ ):
541
+ batch_size, seq_length, _ = hidden_states.shape
542
+ seq_length_with_past = seq_length
543
+ past_key_values_length = 0
544
+
545
+ with torch.no_grad():
546
+ inputs_embeds = self.embed_tokens(input_ids)
547
+ #inputs_embeds = inputs_embeds.detach()
548
+
549
+ # if std is not None:
550
+ # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std
551
+ # inputs_embeds=inputs_embeds+noise
552
+
553
+ if past_key_values is not None:
554
+ past_key_values_length = past_key_values[0][0].shape[2]
555
+ seq_length_with_past = seq_length_with_past + past_key_values_length
556
+ if position_ids is None:
557
+ device = hidden_states.device if hidden_states is not None else inputs_embeds.device
558
+ position_ids = torch.arange(
559
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
560
+ )
561
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
562
+ else:
563
+ position_ids = position_ids.view(-1, seq_length).long()
564
+
565
+ if attention_mask is None:
566
+ attention_mask = torch.ones(
567
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
568
+ )
569
+ attention_mask = self._prepare_decoder_attention_mask(
570
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
571
+ )
572
+
573
+ if self.gradient_checkpointing and self.training:
574
+ if use_cache:
575
+ use_cache = False
576
+
577
+
578
+ #hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1)))
579
+ inputs_embeds=inputs_embeds.to(hidden_states.dtype)
580
+ hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
581
+
582
+
583
+ all_hidden_states = () if output_hidden_states else None
584
+ next_decoder_cache = () if use_cache else None
585
+
586
+ for idx, decoder_layer in enumerate(self.layers):
587
+ if output_hidden_states:
588
+ all_hidden_states += (hidden_states,)
589
+
590
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
591
+
592
+ if self.gradient_checkpointing and self.training:
593
+
594
+ def create_custom_forward(module):
595
+ def custom_forward(*inputs):
596
+ # None for past_key_value
597
+ return module(*inputs, past_key_value, output_attentions)
598
+
599
+ return custom_forward
600
+
601
+ layer_outputs = torch.utils.checkpoint.checkpoint(
602
+ create_custom_forward(decoder_layer),
603
+ hidden_states,
604
+ attention_mask,
605
+ position_ids,
606
+ )
607
+ else:
608
+ layer_outputs = decoder_layer(
609
+ hidden_states,
610
+ attention_mask=attention_mask,
611
+ position_ids=position_ids,
612
+ past_key_value=past_key_value,
613
+ output_attentions=output_attentions,
614
+ use_cache=use_cache,
615
+ )
616
+
617
+ hidden_states = layer_outputs[0]
618
+
619
+ if use_cache:
620
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
621
+
622
+ if use_cache:
623
+ return hidden_states,next_decoder_cache
624
+
625
+ return hidden_states
626
+
627
+ @torch.no_grad()
628
+ def generate(self,hidden_states,input_ids,head,max_length=4,use_cache=False):
629
+ return_input_ids=copy.deepcopy(input_ids[0].tolist())
630
+ input_ids=input_ids[:,1:]
631
+
632
+ #input_ids=input_ids.to(hidden_states.device)
633
+ if use_cache:
634
+ past_key_values=None
635
+ for i in range(max_length):
636
+ if past_key_values!=None:
637
+ out_hidden,past_key_values = self(out_hidden[:, -1:], input_ids=torch.tensor([[token]]).to(input_ids.device),past_key_values=past_key_values,use_cache=True)
638
+ else:
639
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids,use_cache=True)
640
+ last_hidden = out_hidden[:, -1]
641
+ last_headout = head(last_hidden)
642
+ token = torch.argmax(last_headout)
643
+ #input_ids = torch.cat((input_ids, torch.tensor([[token]]).to(input_ids.device)), dim=1)
644
+ return_input_ids.append(token.item())
645
+ if token == 2:
646
+ break
647
+ #hidden_states = torch.cat((hidden_states, out_hidden[:, -1:]), dim=1)
648
+ else:
649
+ for i in range(max_length):
650
+ out_hidden=self(hidden_states,input_ids=input_ids)
651
+ last_hidden = out_hidden[:, -1]
652
+ last_headout = head(last_hidden)
653
+ token = torch.argmax(last_headout)
654
+ return_input_ids.append(token.item())
655
+ input_ids = torch.cat((input_ids, torch.tensor([[token]]).to(input_ids.device)), dim=1)
656
+ if token==2:
657
+ break
658
+ hidden_states = torch.cat((hidden_states, out_hidden[:, -1:]), dim=1)
659
+
660
+ return return_input_ids
661
+
662
+ @torch.no_grad()
663
+ def repeat_kv(self,kv,numr):
664
+ newkv=[]
665
+ for i in kv:
666
+ newkv.append((i[0].repeat(numr,1,1,1),i[1].repeat(numr,1,1,1)))
667
+ return tuple(newkv)
668
+
669
+ @torch.no_grad()
670
+ def reduce_kv(self,kv,numr):
671
+ newkv=[]
672
+ for i in kv:
673
+ newkv.append((i[0][:numr],i[1][:numr]))
674
+ return tuple(newkv)
675
+
676
+
677
+ def reset_kv(self):
678
+ self.stable_kv=None
679
+
680
@torch.no_grad()
def topK_genrate_batch(self,hidden_states,input_ids,head,max_length=4,use_cache=True):
    """Run four hand-unrolled draft steps and return the stacked head outputs.

    Step 1 scores the given prefix, steps 2 and 3 each expand the top-3
    tokens of the previous step's first row, and step 4 follows the single
    best token. The branching factor 3 and the depth are hard-coded.

    Args:
        hidden_states: Features fed to ``self(...)`` together with the ids.
            # assumes (batch, seq, dim) with batch == 1 — TODO confirm
        input_ids: Token ids; the first column is dropped before use.
        head: Callable mapping the last hidden state to vocabulary scores.
        max_length: Unused by this method.
        use_cache: When True, reuse ``past_key_values`` between steps and
            feed only the new tokens; when False, re-run the growing
            sequence from scratch at every step.

    Returns:
        ``torch.cat`` (along dim 0) of the head outputs collected at the
        four steps.
    """
    # Drop the first token id, then move the ids next to the features.
    input_ids = input_ids[:, 1:]
    input_ids = input_ids.to(hidden_states.device)
    sslogits=[]
    # NOTE(review): `reset` is defined outside this view; presumably it
    # clears the tree attention mask — confirm.
    self.reset()
    if use_cache:
        # Step 1: score the full prefix and keep its key/value cache.
        out_hidden, past_key_values = self(hidden_states, input_ids=input_ids,use_cache=True)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)
        topk_index = torch.topk(last_headout, 3, dim=-1).indices

        # Step 2: one row per top-3 candidate; only the newest position is
        # fed because the cache already holds the prefix.
        hidden_states = out_hidden[:, -1:]
        hidden_states = hidden_states.repeat(3, 1, 1)
        # Transpose so each candidate token becomes its own row.
        input_ids = topk_index.t()
        past_key_values = self.repeat_kv(past_key_values,3)
        out_hidden,past_key_values = self(hidden_states, input_ids=input_ids,past_key_values=past_key_values,use_cache=True)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)

        # Step 3: continue from row 0 only and expand its top-3 again.
        hidden_states = out_hidden[0:1, -1:]
        topk_index = torch.topk(last_headout[:1], 3, dim=-1).indices
        hidden_states = hidden_states.repeat(3, 1, 1)
        input_ids = topk_index.t()
        out_hidden,past_key_values = self(hidden_states, input_ids=input_ids,past_key_values=past_key_values,use_cache=True)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)

        # Step 4: follow only the single best token of row 0, so the cache
        # is cut back to one row.
        topk_index = torch.topk(last_headout[:1], 3, dim=-1).indices
        hidden_states = out_hidden[0:1, -1:]
        input_ids = topk_index[:, :1]
        past_key_values=self.reduce_kv(past_key_values,1)
        out_hidden,past_key_values = self(hidden_states, input_ids=input_ids,past_key_values=past_key_values,use_cache=True)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)
    else:
        # Same four steps without a cache: the whole sequence is
        # re-encoded each time, so features and ids are concatenated
        # instead of replaced.
        out_hidden = self(hidden_states, input_ids=input_ids)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)
        topk_index=torch.topk(last_headout, 3, dim=-1).indices

        # Step 2: append the new feature, replicate 3x, add one candidate
        # token per row.
        hidden_states = torch.cat((hidden_states, out_hidden[:, -1:]), dim=1)
        hidden_states=hidden_states.repeat(3,1,1)
        input_ids=input_ids.repeat(3,1)
        input_ids=torch.cat((input_ids,topk_index.t()),dim=-1)
        out_hidden = self(hidden_states, input_ids=input_ids)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)

        # Step 3: keep row 0 and expand it again.
        hidden_states=hidden_states[:1]
        input_ids=input_ids[:1]
        topk_index = torch.topk(last_headout[:1], 3, dim=-1).indices
        hidden_states = torch.cat((hidden_states, out_hidden[0:1, -1:]), dim=1)
        hidden_states = hidden_states.repeat(3, 1, 1)
        input_ids = input_ids.repeat(3, 1)
        input_ids = torch.cat((input_ids, topk_index.t()), dim=-1)
        out_hidden = self(hidden_states, input_ids=input_ids)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)

        # Step 4: keep row 0 and append only its best token.
        hidden_states = hidden_states[:1]
        input_ids = input_ids[:1]
        topk_index = torch.topk(last_headout[:1], 3, dim=-1).indices
        hidden_states = torch.cat((hidden_states, out_hidden[0:1, -1:]), dim=1)
        input_ids = torch.cat((input_ids, topk_index[:,:1]), dim=-1)
        out_hidden = self(hidden_states, input_ids=input_ids)
        last_hidden = out_hidden[:, -1]
        last_headout = head(last_hidden)
        sslogits.append(last_headout)

    return torch.cat(sslogits)
769
+
770
@torch.no_grad()
def repeat_hidden(self, hidden_state, repeat_num):
    """Repeat each position of ``hidden_state`` along dim 1.

    Args:
        hidden_state: 3-D tensor whose dim-1 slices are repeated.
        repeat_num: Iterable of counts; entry ``p`` is how many times the
            slice ``hidden_state[:, p:p+1]`` appears in the output.

    Returns:
        The repeated slices concatenated along dim 1, in input order.
    """
    pieces = [
        hidden_state[:, pos:pos + 1].repeat(1, count, 1)
        for pos, count in enumerate(repeat_num)
    ]
    return torch.cat(pieces, dim=1)
776
+
777
@torch.no_grad()
def sample(self, tensor, k=1, replacement=True):
    """Draw ``k`` indices per row from the softmax of ``tensor``.

    Args:
        tensor: 2-D scores; softmax is taken over dim 1.
        k: Number of draws per row.
        replacement: Passed to ``torch.multinomial``; with True the same
            index may be drawn more than once.

    Returns:
        ``(indices, probs)`` where ``probs`` holds the softmax probability
        of every drawn index.
    """
    probs = tensor.softmax(dim=1)
    picks = torch.multinomial(probs, k, replacement=replacement)
    return picks, probs.gather(1, picks)
784
+
785
@torch.no_grad()
def topK_genrate(self, hidden_states, input_ids, head, logits_processor,max_length=4, use_cache=True):
    """Draft a token tree level by level and return its tokens and scores.

    The prefix is encoded once (continuing from ``self.stable_kv`` when one
    is stored), then one forward pass is made per entry of
    ``self.tree_buffer['tree_indices']``. At every level ``top_k``
    candidates are taken per row: by sampling when ``logits_processor`` is
    given, by ``torch.topk`` otherwise.

    Args:
        hidden_states: Features aligned with ``input_ids`` after its first
            column is dropped.
        input_ids: Token ids; the first column is dropped before use.
        head: Callable mapping hidden states to vocabulary scores.
        logits_processor: ``None`` selects greedy top-k drafting; anything
            else selects sampling. The processor itself is not applied here.
        max_length: Unused by this method.
        use_cache: Must be True; the cache-free path is not implemented.

    Returns:
        ``(tokens, scores)``, each the concatenation along dim 0 of the
        per-level candidates. Scores are softmax probabilities when
        sampling and raw head outputs when greedy.

    Raises:
        NotImplementedError: If ``use_cache`` is False. This is a subclass
            of ``RuntimeError``, which the empty ``torch.cat`` raised here
            before.
    """
    if not use_cache:
        raise NotImplementedError("topK_genrate only supports use_cache=True")

    def _candidates(scores):
        # Same selection rule at every level, including the last one.
        # NOTE: `top_k` is resolved from module scope, as before.
        if logits_processor is not None:
            return self.sample(scores, top_k)
        best = torch.topk(scores, top_k, dim=-1)
        return best.indices, best.values

    # Drop the first token id, then move the ids next to the features.
    input_ids = input_ids[:, 1:]
    input_ids = input_ids.to(hidden_states.device)
    ss_token, ss_prob = [], []
    len_posi = input_ids.shape[1]
    self.reset()

    stable_kv = getattr(self, "stable_kv", None)
    if stable_kv is not None:
        # Only encode what the stored cache has not seen yet.
        kv_len = stable_kv[0][0].shape[2]
        out_hidden, past_key_values = self(hidden_states[:, kv_len:], input_ids=input_ids[:, kv_len:],
                                           past_key_values=stable_kv, use_cache=True)
    else:
        out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True)
    self.stable_kv = past_key_values
    last_headout = head(out_hidden[:, -1])

    for i in range(len(self.tree_buffer['tree_indices'])):
        topk_index, topk_prob = _candidates(last_headout)
        ss_token.append(topk_index)
        ss_prob.append(topk_prob)

        # Pick the candidates that make up tree level i.
        select_index = topk_index.view(-1)[self.tree_buffer['tree_indices'][i]]
        input_ids = select_index[None, :]
        # The first level grows from the last prefix position only.
        hidden_states = out_hidden[:, -1:] if i == 0 else out_hidden
        hidden_states = self.repeat_hidden(hidden_states, self.tree_buffer["repeat_nums"][i])
        self.tree_mask = self.tree_buffer['attn_mask'][i]
        position_ids = len_posi + self.tree_buffer["position_ids"][i]
        out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, past_key_values=past_key_values,
                                           position_ids=position_ids, use_cache=True)
        len_posi += 1
        last_headout = head(out_hidden[0])

    # Candidates for the deepest level.
    topk_index, topk_prob = _candidates(last_headout)
    ss_token.append(topk_index)
    ss_prob.append(topk_prob)

    return (torch.cat(ss_token), torch.cat(ss_prob))
846
+
847
+
848
+
849
+
850
@torch.no_grad()
def acc(self, data, head, max_length=5):
    """Measure greedy draft accuracy per look-ahead depth.

    For every position with a non-zero ``loss_mask``, the model is rolled
    forward up to ``max_length`` steps on its own outputs. At depth ``k``
    the predicted token is compared with the argmax of ``head(target)`` at
    the matching position. A rollout stops at the first miss, and the
    depths it never reached are still counted as attempts.

    Args:
        data: Mapping with ``"hidden_states"``, ``"input_ids"``,
            ``"loss_mask"`` and ``"target"``.
            # assumes leading dims are (batch, seq) — TODO confirm
        head: Callable mapping hidden states to vocabulary scores.
        max_length: Number of look-ahead depths to evaluate.

    Returns:
        A list of ``max_length`` floats, ``correct / total`` per depth.
        A depth that was never attempted reports ``0.0``.
    """
    hidden_states = data["hidden_states"]
    input_ids = data["input_ids"]
    loss_mask = data["loss_mask"]
    target = data["target"]
    total = [0] * max_length
    correct = [0] * max_length
    bs, sl = hidden_states.shape[0], hidden_states.shape[1]
    target_headout = head(target)
    hidden_states_headout = head(hidden_states)

    for i in range(bs):
        for j in range(sl):
            if loss_mask[i, j] == 0:
                continue
            single_hidden_states = hidden_states[i, :j][None, :, :]
            single_input_ids = input_ids[i, :j][None, :]
            for k in range(max_length):
                pos = single_hidden_states.shape[1] - 1
                if pos >= sl:
                    # The rollout has outgrown the reference sequence.
                    break
                target_in_token = torch.argmax(hidden_states_headout[i, pos])
                target_out_token = torch.argmax(target_headout[i, pos])
                # Skip positions where the stored token and the head's own
                # prediction for it disagree.
                if target_in_token != input_ids[i, pos]:
                    break
                out_hidden = self(single_hidden_states, input_ids=single_input_ids)
                token = torch.argmax(head(out_hidden[:, -1]))
                total[k] += 1
                if token == target_out_token:
                    correct[k] += 1
                else:
                    # Deeper depths count as failed attempts. Depth k was
                    # already counted just above, so start at k + 1.
                    for kk in range(k + 1, max_length):
                        total[kk] += 1
                    break

                single_hidden_states = torch.cat((single_hidden_states, out_hidden[:, -1:]), dim=1)
                single_input_ids = torch.cat(
                    (single_input_ids, token.view(1, 1).to(single_input_ids.device)), dim=1
                )

    # Report 0.0 instead of dividing by zero for unattempted depths.
    return [hit / seen if seen else 0.0 for hit, seen in zip(correct, total)]
901
+
902
+
903
+
904
+
905
+
906
class Vhead(nn.Module):
    """A single bias-free linear projection from ``ins`` to ``outs`` features."""

    def __init__(self, ins=6566, outs=32000):
        # NOTE(review): the default ins=6566 differs from the 6656 used for
        # the head in this file's __main__ block — confirm it is intended.
        super().__init__()
        self.fc = nn.Linear(in_features=ins, out_features=outs, bias=False)

    def forward(self, x):
        """Apply the projection to ``x`` and return the result."""
        projected = self.fc(x)
        return projected
912
+
913
+
914
+
915
+ import torch
916
+
917
def count_parameters(model):
    """Return the total element count over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
919
+
920
+
921
+
922
+
923
if __name__ == '__main__':
    # Manual smoke test and timing run; needs a CUDA device, a local
    # config.json and the head checkpoint at the path below.
    import time

    config = EConfig.from_pretrained('config.json')
    model = Model(config)
    model.cuda()
    model.init_tree()
    model.eval()
    total_params = count_parameters(model)
    print(f"总参数量: {total_params}")
    head = torch.nn.Linear(6656, 32000, bias=False)
    head.load_state_dict(torch.load("/home/lyh/code/nlp/ess/transhead_embeding_long/headf32.ckpt"))
    head.cuda()
    head.eval()
    logits_processor = prepare_logits_processor()
    with torch.no_grad():
        # 499 feature rows against 500 ids: the draft methods drop the
        # first id themselves.
        ins = torch.randn(1, 499, 6656)
        input_ids = torch.tensor([[29915, 385, 299, 365, 395] * 100])
        out0 = model.generate(ins.cuda(), input_ids.cuda(), head, use_cache=True)
        out = model.generate(ins.cuda(), input_ids.cuda(), head, use_cache=False)
        outk = model.topK_genrate_batch(ins.cuda(), input_ids.cuda(), head, use_cache=False)
        model.reset_kv()
        # Greedy drafting on a 400-position prefix fills the stable cache.
        outkcache = model.topK_genrate(ins[:, :400].cuda(), input_ids[:, :401].cuda(), head, None, use_cache=True)
        # `logits_processor` is a required positional argument; it was
        # missing here, which raised a TypeError.
        outkcache1 = model.topK_genrate(ins.cuda(), input_ids.cuda(), head, logits_processor, use_cache=True)

        for _ in range(10):
            # CUDA kernels run asynchronously; synchronize so the wall
            # clock covers the actual computation.
            torch.cuda.synchronize()
            s = time.time()
            outk = model.topK_genrate_batch(ins.cuda(), input_ids.cuda(), head, use_cache=True)
            torch.cuda.synchronize()
            print(time.time() - s)

        print('---' * 10)

        for _ in range(10):
            torch.cuda.synchronize()
            s = time.time()
            outk = model.topK_genrate(ins.cuda(), input_ids.cuda(), head, None, use_cache=True)
            torch.cuda.synchronize()
            print(time.time() - s)
model/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "bos_token_id": 1,
6
+ "eos_token_id": 2,
7
+ "hidden_act": "silu",
8
+ "hidden_size": 6656,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 17920,
11
+ "max_sequence_length": 2048,
12
+ "model_type": "llama",
13
+ "num_attention_heads": 52,
14
+ "num_key_value_heads": 13,
15
+ "num_hidden_layers": 1,
16
+ "pad_token_id": 0,
17
+ "rms_norm_eps": 1e-06,
18
+ "tie_word_embeddings": false,
19
+ "torch_dtype": "float16",
20
+ "transformers_version": "4.28.0.dev0",
21
+ "use_cache": true,
22
+ "vocab_size": 32000
23
+ }
model/configs.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
class EConfig(PretrainedConfig):
    r"""
    Configuration class with the same fields as a LLaMA model configuration. It stores the arguments that define
    the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed to the model.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Id of the padding token; forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Id of the beginning-of-sequence token; forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token; forwarded to [`PretrainedConfig`].
        pretraining_tp (`int`, *optional*, defaults to `1`):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.

    Example:

    ```python
    >>> from transformers import LlamaModel, LlamaConfig

    >>> # Initializing a LLaMA llama-7b style configuration
    >>> configuration = LlamaConfig()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = LlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_scaling=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: no explicit value means plain multi-head attention
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: If `rope_scaling` is set but is not a two-entry dict, its `type` is not `"linear"` or
                `"dynamic"`, or its `factor` is not a float greater than 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
model/ea_model.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
5
+ from .utils import *
6
+ from .kv_cache import initialize_past_key_values
7
+ from .choices import mc_sim_7b_63
8
+ from transformers import AutoTokenizer
9
+ import os
10
+ from huggingface_hub import hf_hub_download
11
+ from .cnets import Model
12
+ from .configs import EConfig
13
+
14
+
15
+
16
+
17
+
18
class ResBlock(nn.Module):
    """
    Residual block computing ``x + SiLU(Linear(x))``.

    Args:
        hidden_size (int): Input and output width of the linear layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Zero the weight so that at construction the linear layer outputs
        # only its bias, whatever the input.
        nn.init.zeros_(self.linear.weight)
        # SiLU is used to stay consistent with the Llama model.
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Apply the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x`` plus the activated linear projection of ``x``.
        """
        branch = self.act(self.linear(x))
        return x + branch
48
+
49
+
50
class EaModel(nn.Module):
    """Pair a base causal LM with an auxiliary draft layer (``ea_layer``).

    The base model verifies token trees proposed by ``ea_layer``. Three
    decoding entry points are offered: ``eagenerate`` (returns the final
    ids), ``ea_generate`` (yields the ids after every step) and
    ``naive_generate`` (plain greedy decoding with the base model only).
    """

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            ea_model_path,
    ):
        """Wrap ``base_model`` and build the draft layer from ``ea_model_path``.

        Args:
            base_model: Loaded base causal LM; must expose ``config``,
                ``lm_head``, ``dtype`` and ``model.layers``.
            base_model_name_or_path: Passed to
                ``AutoTokenizer.from_pretrained`` to load the tokenizer.
            ea_model_path: Location of the draft layer's configuration.
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        config = EConfig.from_pretrained(ea_model_path)
        self.ea_layer = Model(config)

        # Place the draft layer next to the last decoder layer of the base
        # model, in the base model's dtype.
        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        self.ea_layer.to(self.base_model.dtype).to(device)
        self.ea_layer.init_tree()

    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer

    @classmethod
    def from_pretrained(
            cls,
            base_model_path=None,
            ea_model_path=None,
            **kwargs,
    ):
        """Load the base model and the draft-layer weights.

        Args:
            base_model_path: Passed to ``KVLlamaForCausalLM.from_pretrained``
                and used to load the tokenizer.
            ea_model_path: Directory holding the draft layer's configuration
                and its ``pytorch_model.bin``.
            **kwargs: Forwarded to ``KVLlamaForCausalLM.from_pretrained``.

        Returns:
            EaModel: The assembled model.
        """
        base_model = KVLlamaForCausalLM.from_pretrained(
            base_model_path, **kwargs
        )

        model = cls(
            base_model,
            base_model_path,
            ea_model_path
        )

        ea_layer_state_dict = torch.load(os.path.join(ea_model_path, "pytorch_model.bin"),
                                         map_location=base_model.device)
        # strict=False: keys that do not match are ignored, not reported.
        model.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)

        return model

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
            init=True,
            logits_processor=None
    ):
        """Run the base model and, when ``init`` is True, draft a token tree.

        Args:
            input_ids: Token ids for the base model.
            attention_mask: Forwarded to the base model.
            labels: Unused.
            past_key_values: Forwarded to the base model.
            output_orig: Also return the base model outputs and its logits.
            position_ids: Forwarded to the base model.
            init: Pick the next token from the base logits and ask the
                draft layer for a tree.
            logits_processor: When given, the next token is sampled from the
                processed logits; otherwise it is the argmax.

        Returns:
            * ``init`` and ``output_orig``:
              ``(ea_logits, outputs, orig, hidden_states, token)``
            * ``init`` only: ``(ea_logits, hidden_states, token)``
            * ``output_orig`` only: ``(outputs, orig, hidden_states)``
            * neither: ``None``
        """
        # The logits are needed to pick the next token whenever `init` is
        # set. They used to be computed only for `output_orig`, so
        # init=True with output_orig=False failed on an unbound `orig`.
        need_logits = output_orig or init
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if need_logits:
                orig = self.base_model.lm_head(outputs[0])
            # Clone the output hidden states
            hidden_states = outputs[0].clone()

        if not init:
            if output_orig:
                return outputs, orig, hidden_states
            return None

        if logits_processor is not None:
            logits = logits_processor(None, orig[:, -1])
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            token = torch.multinomial(probabilities, 1)
        else:
            token = torch.argmax(orig[:, -1])
            token = token[None, None]
        input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)

        ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head, logits_processor)
        if output_orig:
            return ea_logits, outputs, orig, hidden_states, token
        return ea_logits, hidden_states, token

    def _begin_generation(self, input_ids, temperature, top_p, top_k, tree_choices):
        """Shared setup of the three generation entry points.

        Returns ``(input_ids, logits_processor, tree_buffers,
        past_key_values, past_key_values_data, current_length_data)`` where
        ``input_ids`` is a copy of the argument.
        """
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        # Tree buffers are rebuilt only when the tree layout changes.
        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
        self.tree_buffers = tree_buffers
        self.tree_choices = tree_choices

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        reset_tree_mode(self)
        return (input_ids, logits_processor, tree_buffers,
                past_key_values, past_key_values_data, current_length_data)

    def _speculative_step(self, input_ids, tree_logits, new_token, hidden_state, sample_token,
                          tree_buffers, past_key_values, past_key_values_data, current_length_data,
                          logits_processor):
        """Build candidates, verify them with the base model, accept the best.

        Returns the updated
        ``(input_ids, tree_logits, new_token, hidden_state, sample_token)``.
        """
        candidates, cart_candidates_prob, tree_candidates = generate_candidates(
            tree_logits,
            tree_buffers["tree_indices"],
            tree_buffers["retrieve_indices"],
            sample_token,
            logits_processor
        )
        logits, hidden_state_new, outputs = tree_decoding(
            self,
            tree_candidates,
            past_key_values,
            tree_buffers["tree_position_ids"],
            input_ids,
            tree_buffers["retrieve_indices"],
        )
        best_candidate, accept_length, sample_p = evaluate_posterior(
            logits, candidates, logits_processor, cart_candidates_prob
        )
        return update_inference_inputs(
            input_ids,
            candidates,
            best_candidate,
            accept_length,
            tree_buffers["retrieve_indices"],
            logits_processor,
            logits,
            tree_logits,
            new_token,
            past_key_values_data,
            current_length_data,
            self,
            hidden_state,
            hidden_state_new,
            sample_p
        )

    @torch.no_grad()
    def eagenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            tree_choices=mc_sim_7b_63,

    ):
        """Generate with tree drafting and return the final token ids.

        Args:
            input_ids: Prompt ids; only batch size 1 is supported.
            temperature: Values above 1e-5 enable sampling; otherwise
                decoding is greedy.
            top_p: Forwarded to ``prepare_logits_processor`` when sampling.
            top_k: Forwarded to ``prepare_logits_processor`` when sampling.
            max_new_tokens: Stop once more than this many tokens were added.
            max_length: Upper bound on both the number of steps and the
                total sequence length.
            tree_choices: Layout of the draft tree.

        Returns:
            The prompt ids followed by the generated ids.
        """
        (input_ids, logits_processor, tree_buffers,
         past_key_values, past_key_values_data, current_length_data) = self._begin_generation(
            input_ids, temperature, top_p, top_k, tree_choices
        )

        input_len = input_ids.shape[1]
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_length):
            input_ids, tree_logits, new_token, hidden_state, sample_token = self._speculative_step(
                input_ids, tree_logits, new_token, hidden_state, sample_token,
                tree_buffers, past_key_values, past_key_values_data, current_length_data,
                logits_processor,
            )

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                return input_ids
            if new_token > max_new_tokens:
                return input_ids
            if input_ids.shape[1] > max_length:
                return input_ids
        # Step budget exhausted without meeting a stop condition: still
        # hand back what was generated instead of None.
        return input_ids

    @torch.no_grad()
    def ea_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_steps=512,
            tree_choices=mc_sim_7b_63,

    ):
        """Generate with tree drafting, yielding the ids after every step.

        Args:
            input_ids: Prompt ids; only batch size 1 is supported.
            temperature: Values above 1e-5 enable sampling; otherwise
                decoding is greedy.
            top_p: Forwarded to ``prepare_logits_processor`` when sampling.
            top_k: Forwarded to ``prepare_logits_processor`` when sampling.
            max_steps: Maximum number of draft/verify steps.
            tree_choices: Layout of the draft tree.

        Yields:
            The full token ids (prompt plus everything generated so far).
        """
        (input_ids, logits_processor, tree_buffers,
         past_key_values, past_key_values_data, current_length_data) = self._begin_generation(
            input_ids, temperature, top_p, top_k, tree_choices
        )

        input_len = input_ids.shape[1]
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_steps):
            input_ids, tree_logits, new_token, hidden_state, sample_token = self._speculative_step(
                input_ids, tree_logits, new_token, hidden_state, sample_token,
                tree_buffers, past_key_values, past_key_values_data, current_length_data,
                logits_processor,
            )

            yield input_ids

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            # Fixed budgets: at most 1024 new tokens and 1960 total ids.
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break

    @torch.no_grad()
    def naive_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_steps=512,
            tree_choices=mc_sim_7b_63,

    ):
        """Greedy decoding with the base model only, one token per step.

        Takes the same arguments as ``ea_generate``. The sampling arguments
        and ``tree_choices`` only affect the shared setup; every token is
        the argmax of the base model's logits.

        Yields:
            The full token ids (prompt plus everything generated so far).
        """
        (input_ids, logits_processor, tree_buffers,
         past_key_values, past_key_values_data, current_length_data) = self._begin_generation(
            input_ids, temperature, top_p, top_k, tree_choices
        )

        input_len = input_ids.shape[1]
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        for idx in range(max_steps):
            input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            # Count the token; the counter used to stay at 0, which made
            # the budget check below unreachable.
            new_token += 1

            yield input_ids

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            # Fixed budgets: at most 1024 new tokens and 1960 total ids.
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break
model/kv_cache.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class KVCache:
    """
    A key-value cache backed by a preallocated tensor.

    The cache never reallocates: new entries are written into ``data`` in place and
    ``current_length`` records how much of the growing dimension is currently valid.

    Attributes:
        data (torch.Tensor): The preallocated tensor storing keys or values.
        current_length (torch.Tensor): A 0-dim integer tensor holding the number of valid
            entries. It is updated in place (``fill_`` / ``add_``), so it may be a view
            into a larger bookkeeping tensor shared with other caches.
    """

    def __init__(self, data, current_length):
        """
        Initialize the KVCache.

        Args:
            data (torch.Tensor): Preallocated tensor to store the keys or values.
            current_length (torch.Tensor): 0-dim integer tensor with the initial valid length.
        """
        self.data = data
        self.current_length = current_length

    @property
    def shape(self):
        """Return the 4-tuple shape of ``data`` with dimension 2 replaced by the valid length."""
        return (
            self.data.shape[0],
            self.data.shape[1],
            self.current_length.item(),
            self.data.shape[3],
        )

    def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
        """
        Copy the entries at ``indices`` to the region starting at ``prev_length``.

        Args:
            indices (torch.Tensor): Indices along ``dim`` of the entries to be copied.
            prev_length (int): Position along ``dim`` at which the copied entries are written.
            dim (int, optional): Dimension along which copying is performed. Default is 2.
        """
        # index_select materialises a new tensor, so source and destination may overlap safely.
        tgt = self.data.index_select(dim, indices)
        dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
        dst.copy_(tgt, non_blocking=True)
        self.current_length.fill_(prev_length + tgt.shape[dim])

    def cat(self, tensor: torch.Tensor, dim: int = 2):
        """
        Append ``tensor`` to the cache in place and return the valid portion.

        Args:
            tensor (torch.Tensor): The tensor to be appended.
            dim (int, optional): The dimension along which to append. Default is 2.

        Returns:
            torch.Tensor: A view of ``data`` along ``dim`` covering entries ``0`` up to the
            updated current length.
        """
        dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
        dst.copy_(tensor)
        self.current_length.add_(tensor.shape[dim])
        # Narrow along the same dimension that was written to (previously hard-coded to 2).
        return torch.narrow(self.data, dim, 0, self.current_length)
67
+
68
+
69
def initialize_past_key_values(model):
    """
    Initialize past key and value states for a given transformer model.

    One zero-filled tensor is preallocated per run of consecutive layers that live on the
    same device, and every layer receives a pair of KVCache objects (key, value) that are
    views into the tensor of its run.

    Args:
        model (nn.Module): The transformer model for which past key-value states need to be
            initialized. Its layers are read from ``model.model.layers`` or, failing that,
            from ``model.layers``.

    Returns:
        tuple:
            - past_key_values (list): For each layer, a two-element list ``[key, value]`` of
              KVCache objects.
            - past_key_values_data_list (list): The backing tensors, one per run of
              consecutive same-device layers, each of shape
              ``(2 * layers_in_run, 1, num_key_value_heads, max_position_embeddings, head_dim)``.
            - current_length_data (torch.Tensor): CPU tensor of ``2 * num_hidden_layers``
              longs tracking the current length of every cache.
    """
    config = model.config
    # The cache is allocated for a single sequence.
    batch_size = 1
    head_dim = config.hidden_size // config.num_attention_heads

    # Wrapped models expose the decoder stack as `model.model.layers`, bare ones as `model.layers`.
    try:
        layers = model.model.layers
    except AttributeError:
        layers = model.layers
    devices = [
        layers[i].self_attn.q_proj.weight.device
        for i in range(config.num_hidden_layers)
    ]

    # Split the layers into runs of consecutive layers sharing a device and remember, for
    # every layer, which run it belongs to and its position inside that run. Tracking the
    # run index directly (instead of deriving it from `device.index`) also works for CPU
    # devices, whose index is None, and for non-contiguous device ids.
    run_devices = []
    run_sizes = []
    layer_slots = []
    for device in devices:
        if not run_devices or device != run_devices[-1]:
            run_devices.append(device)
            run_sizes.append(0)
        layer_slots.append((len(run_devices) - 1, run_sizes[-1]))
        run_sizes[-1] += 1

    past_key_values_data_list = [
        torch.zeros(
            num_layers * 2,
            batch_size,
            config.num_key_value_heads,
            config.max_position_embeddings,
            head_dim,
            device=device,
            dtype=model.dtype,
        )
        for device, num_layers in zip(run_devices, run_sizes)
    ]

    # Initialize tensor to store the current length of the cached data for all layers.
    # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
    current_length_data = torch.zeros(
        config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
    )

    # Create a KVCache for the key (j == 0) and the value (j == 1) of every layer.
    past_key_values = [
        [
            KVCache(
                past_key_values_data_list[run][2 * offset + j],
                current_length_data[i * 2 + j],
            )
            for j in range(2)
        ]
        for i, (run, offset) in enumerate(layer_slots)
    ]
    return past_key_values, past_key_values_data_list, current_length_data
model/modeling_llama_kv.py ADDED
@@ -0,0 +1,1398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Source: https://github.com/huggingface/transformers/blob/v4.31-release/src/transformers/models/llama/modeling_llama.py
2
+ # Modifications are denoted by the symbol: [MODIFIED]
3
+
4
+
5
+ """ PyTorch LLaMA model."""
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ # [MODIFIED] Import from transformer library
16
+ from transformers.activations import ACT2FN
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ SequenceClassifierOutputWithPast,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import (
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from transformers import LlamaConfig
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ _CONFIG_FOR_DOC = "LlamaConfig"
34
+
35
+
36
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
37
+ def _make_causal_mask(
38
+ input_ids_shape: torch.Size,
39
+ dtype: torch.dtype,
40
+ device: torch.device,
41
+ past_key_values_length: int = 0,
42
+ ):
43
+ """
44
+ Create a causal mask for bi-directional self-attention.
45
+
46
+ Args:
47
+ input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
48
+ dtype (torch.dtype): The data type of the mask.
49
+ device (torch.device): The device on which the mask will be placed.
50
+ past_key_values_length (int, optional): The length of past key values. Default is 0.
51
+
52
+ Returns:
53
+ torch.Tensor: The causal mask tensor.
54
+ """
55
+ bsz, tgt_len = input_ids_shape
56
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
57
+ mask_cond = torch.arange(mask.size(-1), device=device)
58
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
59
+ mask = mask.to(dtype)
60
+
61
+ if past_key_values_length > 0:
62
+ mask = torch.cat(
63
+ [
64
+ torch.zeros(
65
+ tgt_len, past_key_values_length, dtype=dtype, device=device
66
+ ),
67
+ mask,
68
+ ],
69
+ dim=-1,
70
+ )
71
+ return mask[None, None, :, :].expand(
72
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
73
+ )
74
+
75
+
76
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
77
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
78
+ """
79
+ Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
80
+
81
+ Args:
82
+ mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`.
83
+ dtype (torch.dtype): The data type of the mask.
84
+ tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length.
85
+
86
+ Returns:
87
+ torch.Tensor: The expanded mask tensor.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
+ return inverted_mask.masked_fill(
97
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
98
+ )
99
+
100
+
101
+ import torch.nn as nn
102
+ import torch
103
+
104
+
105
class LlamaRMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learned per-feature scale.

    Args:
        hidden_size (int): Size of the last dimension of the inputs.
        eps (float, optional): Value added to the mean square before the reciprocal square
            root is taken. Default is 1e-6.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """
        Normalize ``hidden_states`` and apply the learned scale.

        Args:
            hidden_states (torch.Tensor): Input hidden states.

        Returns:
            torch.Tensor: The normalized and scaled hidden states.
        """
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 and cast back before scaling.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
134
+
135
+
136
class LlamaRotaryEmbedding(nn.Module):
    """
    Llama Rotary Positional Embedding Module.

    Precomputes cosine and sine tables for positions ``0 .. max_position_embeddings - 1`` and
    rebuilds them whenever a longer sequence is requested.

    Args:
        dim (int): The dimension of the embedding.
        max_position_embeddings (int, optional): The number of positions precomputed at
            construction time. Default is 2048.
        base (int, optional): The base of the inverse-frequency progression. Default is 10000.
        device (str, optional): The device on which the inverse frequencies are created.
            Default is None.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """
        Set the cosine and sine cache for positional embeddings.

        Args:
            seq_len (int): The number of positions to precompute.
            device (str): The device on which the cache tensors will be stored.
            dtype: The data type of the cache tensors.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )

    def forward(self, x, seq_len=None):
        """
        Forward pass of the LlamaRotaryEmbedding module.

        Args:
            x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].
                Only its device, dtype and (when ``seq_len`` is None) its sequence axis are used.
            seq_len (int, optional): The sequence length. If greater than the cached length,
                the cache will be rebuilt. When None, the length of the second-to-last axis
                of ``x`` is used.

        Returns:
            tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].
        """
        # The signature advertises `seq_len=None`; without this fallback the comparison below
        # would raise a TypeError for the default value.
        if seq_len is None:
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
207
+
208
+
209
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """
    LlamaRotaryEmbedding whose positions are divided by a constant factor.

    This class adds linear scaling to LlamaRotaryEmbedding. Credits to the Reddit user /u/kaiokendev.

    Args:
        dim (int): The dimension of the embedding.
        max_position_embeddings (int, optional): The number of positions precomputed at
            construction time. Default is 2048.
        base (int, optional): The base of the inverse-frequency progression. Default is 10000.
        device (str or torch.device, optional): The device on which the inverse frequencies
            are created. Default is None.
        scaling_factor (float, optional): The divisor applied to the position indices.
            Default is 1.0.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Assigned before the parent constructor runs, because that constructor builds the
        # cos/sin cache through `_set_cos_sin_cache`, which reads the factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """
        Rebuild the cosine and sine tables using linearly scaled positions.

        Args:
            seq_len (int): The number of positions to precompute.
            device (str or torch.device): The device where the cache should be stored.
            dtype: The data type for the cache.
        """
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        scaled_positions = positions / self.scaling_factor

        angles = torch.einsum("i,j->ij", scaled_positions, self.inv_freq)
        # The angle table is duplicated along the last axis so it spans the full `dim`.
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()[None, None, :, :].to(dtype)
        sin_table = table.sin()[None, None, :, :].to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
258
+
259
+
260
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """
    LlamaRotaryEmbedding extended with Dynamic NTK scaling.

    When more positions are requested than ``max_position_embeddings``, the base of the
    inverse frequencies is enlarged before the tables are rebuilt.

    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        """
        Initialize the LlamaDynamicNTKScalingRotaryEmbedding.

        Args:
            dim (int): The dimensionality of the embedding.
            max_position_embeddings (int, optional): Number of positions precomputed at
                construction time and the threshold above which the base is rescaled.
                Default is 2048.
            base (int, optional): Base of the inverse-frequency progression. Default is 10000.
            device: The device on which the inverse frequencies are created. Default is None.
            scaling_factor (float, optional): Scaling factor for NTK scaling. Default is 1.0.
        """
        # Assigned before the parent constructor runs, because that constructor builds the
        # cos/sin cache through `_set_cos_sin_cache`, which may read the factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """
        Rebuild the cosine and sine tables, rescaling the base for long sequences.

        Args:
            seq_len (int): The number of positions to precompute.
            device: The device to place tensors on.
            dtype: The data type of tensors.
        """
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            stretch = (
                self.scaling_factor * seq_len / self.max_position_embeddings
            ) - (self.scaling_factor - 1)
            scaled_base = self.base * stretch ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            rescaled_inv_freq = 1.0 / (scaled_base ** exponents)
            self.register_buffer("inv_freq", rescaled_inv_freq)

        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()[None, None, :, :].to(dtype)
        sin_table = table.sin()[None, None, :, :].to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
321
+
322
+
323
def rotate_half(x):
    """
    Swap the two halves of the last dimension, negating the half that moves to the front.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: ``cat(-second_half, first_half)`` along the last dimension.
    """
    midpoint = x.shape[-1] // 2
    first_half, second_half = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-second_half, first_half), dim=-1)
336
+
337
+
338
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to query and key tensors.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        cos (torch.Tensor): Cosine table; its two leading axes are squeezed away before lookup.
        sin (torch.Tensor): Sine table; its two leading axes are squeezed away before lookup.
        position_ids (torch.Tensor): Position IDs used to index the first remaining axis of
            the tables.

    Returns:
        tuple: The query and key tensors with rotary position embeddings applied.
    """
    # Look up the table rows for the requested positions, then add a broadcast axis at
    # position 1 so the result lines up with `q` and `k`.
    cos_rows = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    sin_rows = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)

    def _rotate(states):
        return (states * cos_rows) + (rotate_half(states) * sin_rows)

    return _rotate(q), _rotate(k)
359
+
360
+
361
class LlamaMLP(nn.Module):
    """
    Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Configuration providing ``pretraining_tp``, ``hidden_size``,
            ``intermediate_size`` and ``hidden_act``.

    Attributes:
        pretraining_tp (int): When greater than 1, the projections are evaluated slice by
            slice over this many weight slices.
        hidden_size (int): The size of the input and output features.
        intermediate_size (int): The size of the gated intermediate features.
        gate_proj (nn.Linear): Bias-free projection feeding the activation.
        up_proj (nn.Linear): Bias-free projection multiplied with the activated gate.
        down_proj (nn.Linear): Bias-free projection back to ``hidden_size``.
        act_fn: The activation looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """
        Forward pass of the MLP.

        Args:
            x: Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        if self.pretraining_tp <= 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        num_slices = self.pretraining_tp
        slice_width = self.intermediate_size // num_slices
        gate_weights = self.gate_proj.weight.split(slice_width, dim=0)
        up_weights = self.up_proj.weight.split(slice_width, dim=0)
        down_weights = self.down_proj.weight.split(slice_width, dim=1)

        # Only the first `num_slices` weight slices take part in the computation.
        gate_out = torch.cat(
            [F.linear(x, weight) for weight in gate_weights[:num_slices]], dim=-1
        )
        up_out = torch.cat(
            [F.linear(x, weight) for weight in up_weights[:num_slices]], dim=-1
        )

        gated_parts = (self.act_fn(gate_out) * up_out).split(slice_width, dim=2)
        partial_outputs = [
            F.linear(gated_parts[index], down_weights[index])
            for index in range(num_slices)
        ]
        return sum(partial_outputs)
424
+
425
+
426
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times, keeping the copies of a head adjacent.

    Args:
        hidden_states (torch.Tensor): Input tensor with shape (batch, num_key_value_heads, seqlen, head_dim).
        n_rep (int): Number of times to repeat each head.

    Returns:
        torch.Tensor: Tensor with shape (batch, num_key_value_heads * n_rep, seqlen, head_dim);
        the input itself is returned when ``n_rep`` is 1.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    broadcast = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return broadcast.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
444
+
445
+
446
class LlamaAttention(nn.Module):
    """
    Multi-headed attention with rotary position embeddings and an in-place key/value cache.

    Args:
        config (LlamaConfig): Configuration for the attention module.

    Attributes:
        config (LlamaConfig): Configuration for the attention module.
        hidden_size (int): The size of the hidden layer.
        num_heads (int): The number of query heads.
        head_dim (int): The dimension of each attention head.
        num_key_value_heads (int): The number of key/value heads.
        num_key_value_groups (int): How many query heads share one key/value head.
        pretraining_tp (int): When greater than 1, the projections are evaluated slice by
            slice over this many weight slices.
        max_position_embeddings (int): Number of positions the rotary tables are built for.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        key_value_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, key_value_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, key_value_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``."""
        rope_scaling = self.config.rope_scaling
        if rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim, max_position_embeddings=self.max_position_embeddings
            )
            return

        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            rotary_cls = LlamaLinearScalingRotaryEmbedding
        elif scaling_type == "dynamic":
            rotary_cls = LlamaDynamicNTKScalingRotaryEmbedding
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
        self.rotary_emb = rotary_cls(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to ``(bsz, num_heads, seq_len, head_dim)`` in contiguous memory."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Compute self-attention over ``hidden_states``.

        Args:
            hidden_states (torch.Tensor): Input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask (torch.Tensor, optional): Additive mask of shape
                ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids (torch.LongTensor, optional): Positions used to index the rotary tables.
            past_key_value (optional): Pair ``(key_cache, value_cache)``; each element must
                expose ``shape`` and ``cat(tensor, dim=...)``. The new keys and values are
                appended to these caches in place.
            output_attentions (bool): Whether to return the attention weights.
            use_cache (bool): Accepted for interface compatibility; not read by this method.

        Returns:
            tuple: ``(attn_output, attn_weights or None, None)``. The last element is always
            None because the caches are updated in place instead of being returned.

        Raises:
            ValueError: If the attention weights, the attention mask or the attention output
                do not have the expected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.pretraining_tp > 1:
            num_slices = self.pretraining_tp
            query_slice_width = (self.num_heads * self.head_dim) // num_slices
            key_value_slice_width = (
                self.num_key_value_heads * self.head_dim
            ) // num_slices

            def _project_in_slices(weight, slice_width):
                # Apply the projection one weight slice at a time and concatenate the results.
                weight_slices = weight.split(slice_width, dim=0)
                pieces = [
                    F.linear(hidden_states, weight_slices[index])
                    for index in range(num_slices)
                ]
                return torch.cat(pieces, dim=-1)

            query_states = _project_in_slices(self.q_proj.weight, query_slice_width)
            key_states = _project_in_slices(self.k_proj.weight, key_value_slice_width)
            value_states = _project_in_slices(self.v_proj.weight, key_value_slice_width)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # Split the feature axis into heads: (bsz, heads, q_len, head_dim).
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )

        # [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization
        # The caches are extended in place and hand back everything stored so far.
        if past_key_value is not None:
            key_cache, value_cache = past_key_value[0], past_key_value[1]
            key_states = key_cache.cat(key_states, dim=2)
            value_states = value_cache.cat(value_states, dim=2)
        # Nothing is returned as cache state; the in-place caches already hold it.
        past_key_value = None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = scores / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back into the feature axis: (bsz, q_len, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            output_slice_width = self.hidden_size // self.pretraining_tp
            output_pieces = attn_output.split(output_slice_width, dim=2)
            o_proj_slices = self.o_proj.weight.split(output_slice_width, dim=1)
            attn_output = sum(
                [
                    F.linear(output_pieces[index], o_proj_slices[index])
                    for index in range(self.pretraining_tp)
                ]
            )
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
652
+
653
+
654
class LlamaDecoderLayer(nn.Module):
    """
    One pre-norm transformer block of the Llama decoder.

    The block runs RMS-normalised self-attention followed by an RMS-normalised
    MLP; each sub-block output is added back onto its un-normalised input.

    Args:
        config (LlamaConfig): Supplies `hidden_size` and `rms_norm_eps`, and is
            forwarded to the attention and MLP sub-modules.

    Attributes:
        hidden_size (int): Copied from `config.hidden_size`.
        self_attn (LlamaAttention): Self-attention sub-module.
        mlp (LlamaMLP): Feed-forward sub-module.
        input_layernorm (LlamaRMSNorm): Norm applied before self-attention.
        post_attention_layernorm (LlamaRMSNorm): Norm applied before the MLP.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore state-dict ordering) stays stable.
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the decoder block on `hidden_states`.

        Args:
            hidden_states (torch.FloatTensor): Block input; per the attention
                module's contract this is `(batch, seq_len, embed_dim)`.
            attention_mask (torch.FloatTensor, optional): Additive mask handed
                unchanged to the attention module.
            position_ids (torch.LongTensor, optional): Position indices handed
                unchanged to the attention module.
            past_key_value (Tuple[torch.FloatTensor], optional): Cached
                key/value state handed unchanged to the attention module.
            output_attentions (bool, optional): When truthy, the attention
                weights are included in the returned tuple.
            use_cache (bool, optional): When truthy, the attention module's
                present key/value state is included in the returned tuple.

        Returns:
            tuple: `(hidden_states,)`, followed by the attention weights if
            `output_attentions` is truthy, followed by the present key/value
            state if `use_cache` is truthy (in that order).
        """
        # Attention sub-block: normalise, attend, then add the skip connection.
        attn_out, attn_probs, kv_present = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block with its own skip connection.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        block_out = after_attn + mlp_out

        # Optional extras are appended in a fixed order so callers can index
        # the tuple based on the flags they passed.
        result = [block_out]
        if output_attentions:
            result.append(attn_probs)
        if use_cache:
            result.append(kv_present)
        return tuple(result)
742
+
743
+
744
# Shared class-level documentation preamble. It is passed to
# `add_start_docstrings` on the model classes below, so its text ends up in
# their generated docstrings; treat it as runtime text, not as a comment.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
759
+
760
+
761
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Common base for the Llama model classes in this file: declares the
    # class-level attributes read by `PreTrainedModel` and provides weight
    # initialisation plus the gradient-checkpointing toggle.
    # (Documented with comments rather than a docstring so the text produced
    # by `add_start_docstrings` is left untouched.)
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialise `module` in place if it is a Linear or Embedding layer.

        Weights are drawn from N(0, `config.initializer_range`). A Linear bias,
        when present, is zeroed; an Embedding's padding row, when a padding
        index is set, is zeroed. Any other module type is left untouched.
        """
        init_std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` when it is a `LlamaModel`."""
        if not isinstance(module, LlamaModel):
            return
        module.gradient_checkpointing = value
786
+
787
+
788
# Shared argument documentation for the `forward` methods below; it is passed
# to `add_start_docstrings_to_model_forward`, so it is runtime text.
# NOTE(review): the two "head is masked / not masked" bullets at the end of the
# `attention_mask` entry, and the references to `decoder_input_ids` and
# cross-attention blocks, look inherited from an encoder-decoder model's docs
# and do not correspond to any parameter of these forwards -- confirm before
# relying on them.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
850
+
851
+
852
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding -> stack of decoder layers -> final RMS norm.
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Build the additive 4-D attention mask handed to every decoder layer.

        Combines (a) a causal mask, created only when more than one query
        position is present, (b) the expanded padding mask when
        `attention_mask` is given, and (c) the optional `self.tree_mask`, whose
        zero entries are masked out in the trailing `tree_len x tree_len`
        corner of the combined mask.

        Returns the combined mask, or `None` when neither (a) nor (b) applies
        and no tree mask is set.
        """
        tgt_len = input_shape[-1]
        embed_device = inputs_embeds.device

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        mask = None
        if tgt_len > 1:
            mask = _make_causal_mask(
                input_shape,
                # inputs_embeds.dtype,
                torch.float32,  # [MODIFIED] force to cast to float32
                device=embed_device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            padding_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=tgt_len
            ).to(embed_device)
            if mask is None:
                mask = padding_mask
            else:
                mask = padding_mask + mask

        # `tree_mask` is an optional attribute set from outside this class;
        # when present, positions where it is 0 receive the most negative
        # value currently in the mask.
        tree_mask = getattr(self, "tree_mask", None)
        if tree_mask is not None:
            tree_len = tree_mask.size(-1)
            fill_value = mask.min()
            corner = mask[:, :, -tree_len:, -tree_len:]  # view into `mask`
            corner[tree_mask == 0] = fill_value

        return mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,  # [MODIFIED] past_key_value is KVCache class
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # No docstring here on purpose: the decorator above assembles the
        # documentation for this method.

        # Unset flags fall back to the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be supplied.
        have_ids = input_ids is not None
        have_embeds = inputs_embeds is not None
        if have_ids and have_embeds:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        if not have_ids and not have_embeds:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )
        if have_ids:
            batch_size, seq_length = input_ids.shape
        else:
            batch_size, seq_length, _ = inputs_embeds.shape

        # Cached length is read from dim 2 of the first layer's first entry.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        if position_ids is None:
            device = input_ids.device if have_ids else inputs_embeds.device
            # Default positions continue from the end of the cache.
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Without an explicit mask every position (cached + new) is attended.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        def make_checkpointed_call(layer):
            # torch.utils.checkpoint forwards positional tensors only, so the
            # non-tensor trailing arguments are bound here.
            def run_layer(*inputs):
                # `inputs` already ends with None for past_key_value; the
                # extra trailing None fills `use_cache`.
                return layer(*inputs, output_attentions, None)

            return run_layer

        # Per-layer collections; converted to tuples (or None) after the loop.
        hidden_trace = []
        attn_trace = []
        cache_trace = []
        cache_slot = 2 if output_attentions else 1

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            past_key_value = None
            if past_key_values is not None:
                past_key_value = past_key_values[idx]

            if checkpointing:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    make_checkpointed_call(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                cache_trace.append(layer_outputs[cache_slot])
            if output_attentions:
                attn_trace.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = tuple(hidden_trace) if output_hidden_states else None
        all_self_attns = tuple(attn_trace) if output_attentions else None
        next_cache = tuple(cache_trace) if use_cache else None

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )

        # Tuple form drops every entry that was not requested.
        candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
        return tuple(item for item in candidates if item is not None)
1079
+
1080
+
1081
class LlamaForCausalLM(LlamaPreTrainedModel):
    """`LlamaModel` backbone followed by a bias-free linear language-model head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,  # [MODIFIED] past_key_value is KVCache class
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # The docstring above is kept verbatim: it is consumed by the two
        # decorators on this method.

        # Unset flags fall back to the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.pretraining_tp > 1:
            # Project with each vocabulary slice of the head separately, then
            # stitch the partial logits back together along the vocab axis.
            head_slices = self.lm_head.weight.split(
                self.vocab_size // self.pretraining_tp, dim=0
            )
            partial_logits = [
                F.linear(hidden_states, head_slices[i])
                for i in range(self.pretraining_tp)
            ]
            logits = torch.cat(partial_logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten for the loss.
            pred = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            target = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            target = target.to(pred.device)
            loss = CrossEntropyLoss()(pred, target)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        With a non-empty cache only the last token (and last position) is fed.
        When no `position_ids` are supplied in `kwargs` they are derived from
        `attention_mask`. `inputs_embeds` is used only when `past_key_values`
        is `None`; otherwise `input_ids` is used.
        """
        has_cache = bool(past_key_values)
        if has_cache:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_cache:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["position_ids"] = position_ids
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Return a tuple-of-tuples copy of the cache with dim 0 of every
        tensor re-indexed by `beam_idx` (moved to each tensor's device)."""
        return tuple(
            tuple(
                state.index_select(0, beam_idx.to(state.device))
                for state in layer_past
            )
            for layer_past in past_key_values
        )
1267
+
1268
+
1269
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    # `LlamaModel` backbone plus a bias-free linear `score` head; logits are
    # pooled from one position per row (see `forward`).
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.score(transformer_outputs[0])

        source = input_ids if input_ids is not None else inputs_embeds
        batch_size = source.shape[0]

        pad_id = self.config.pad_token_id
        if pad_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )

        # Pick the position to pool from: the index given by (number of
        # non-pad tokens - 1) per row when both a pad id and input_ids are
        # available, otherwise simply the last position.
        if pad_id is None or input_ids is None:
            sequence_lengths = -1
        else:
            non_pad_count = torch.ne(input_ids, pad_id).sum(-1)
            sequence_lengths = (non_pad_count - 1).to(logits.device)

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem = self.config.problem_type
            if problem == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif problem == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        output = (pooled_logits,) + transformer_outputs[1:]
        if loss is None:
            return output
        return (loss,) + output
model/utils.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+
5
+ # TODO
6
+ # from transformers import LlamaTokenizer
7
+ # tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/")
8
+
9
# Module-level constant; generate_tree_buffers multiplies it into the offsets
# it writes to `tree_indices`, so changing it changes the flat index layout.
TOPK = 10  # topk for sparse tree
10
+
11
+
12
+
13
+ from transformers.generation.logits_process import (
14
+ LogitsProcessorList,
15
+ RepetitionPenaltyLogitsProcessor,
16
+ TemperatureLogitsWarper,
17
+ TopKLogitsWarper,
18
+ TopPLogitsWarper,
19
+ )
20
def prepare_logits_processor(
        temperature=0.0, repetition_penalty=0.0, top_p=0.0, top_k=0
) -> LogitsProcessorList:
    """Build the list of logits processors implied by the sampling settings.

    Each stage is added only when its setting is in the range where it has an
    effect, in this fixed order:

    - temperature warper: `temperature >= 1e-5` and `temperature != 1.0`
    - repetition penalty: `repetition_penalty > 1.0`
    - top-p warper:       `1e-8 <= top_p < 1.0`
    - top-k warper:       `top_k > 0`

    With the default arguments the returned list is empty.
    """
    # (enabled?, processor class, constructor argument), in application order.
    stages = (
        (temperature >= 1e-5 and temperature != 1.0, TemperatureLogitsWarper, temperature),
        (repetition_penalty > 1.0, RepetitionPenaltyLogitsProcessor, repetition_penalty),
        (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
        (top_k > 0, TopKLogitsWarper, top_k),
    )

    processors = LogitsProcessorList()
    for enabled, processor_cls, setting in stages:
        if enabled:
            processors.append(processor_cls(setting))
    return processors
33
+
34
+
35
+
36
+ # test_processor = prepare_logits_processor(
37
+ # 0.0, 0.0, -1, 1
38
+ # )
39
+
40
+
41
def pad_path(path, length, pad_value=-2):
    """
    Right-pad a list with a filler value until it reaches a target length.

    Parameters:
    - path (list): The list to pad. It is not modified.
    - length (int): The target length of the result.
    - pad_value (optional, default=-2): The filler appended after the
      original elements.

    Returns:
    - list: A new list holding the elements of `path` followed by as many
      copies of `pad_value` as are needed to reach `length`.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    When `path` already has `length` or more elements nothing is appended;
    the result is then an unpadded copy of `path` (it is never truncated).
    """
    shortfall = length - len(path)
    # range() of a non-positive count is empty, so no filler is produced when
    # the path is already long enough.
    filler = [pad_value for _ in range(shortfall)]
    return path + filler
66
+
67
+
68
def generate_tree_buffers(tree_choices, device="cuda"):
    """
    Precompute the tensors that describe a fixed draft tree.

    Nodes are ordered by depth first and lexicographically within a depth;
    index 0 is reserved for the root, so the node for the k-th sorted path
    lives at index k + 1.

    Parameters:
    - tree_choices (list of list of int): One path per tree node. Every proper
      prefix of a path must itself be listed, because ancestors are located
      with list.index.
    - device: Device the resulting tensors are moved to.

    Returns:
    - dict with the keys
      "tree_attn_mask": float tensor of shape (1, 1, n, n), n = len(tree_choices) + 1;
          each row has ones on its own column, on column 0 and on the columns
          of its ancestors.
      "tree_indices": long tensor of shape (n,); entry 0 is 0, every other entry
          is `path[-1] + TOPK * (level + shift) + 1`, where level is depth - 1
          and shift counts parent changes seen so far inside the depth levels.
      "tree_position_ids": long tensor of shape (n,) holding each node's depth
          (0 for the root).
      "retrieve_indices": long tensor with one row per maximal path; a row
          starts with 0 (root), continues with the node indices along the
          path, and is padded with -1.
    """
    # Canonical node order: shallower paths first, ties broken lexicographically.
    ordered = sorted(tree_choices, key=lambda x: (len(x), x))
    node_total = len(ordered) + 1

    # level_sizes[d - 1] is the number of paths of depth d.
    level_sizes = []
    last_depth = 0
    for choice in ordered:
        depth = len(choice)
        if depth != last_depth:
            level_sizes.append(0)
        level_sizes[depth - 1] += 1
        last_depth = depth

    # Attention mask: every node sees itself, the root and its ancestors.
    attn = torch.eye(node_total, node_total)
    attn[:, 0] = 1
    for row, choice in enumerate(ordered, start=1):
        if len(choice) == 1:
            continue
        ancestors = [ordered.index(choice[:k]) + 1 for k in range(1, len(choice))]
        attn[row, ancestors] = 1

    # Flat index of every node inside the concatenated draft outputs.
    tree_indices = torch.zeros(node_total, dtype=torch.long)
    tree_indices[0] = 0
    shift = 0
    cursor = 0
    for level, size in enumerate(level_sizes):
        for offset, choice in enumerate(ordered[cursor:cursor + size]):
            prefix = choice[:-1]
            if offset == 0:
                # The first node of a level only records its parent.
                parent = prefix
            elif prefix != parent:
                shift += 1
                parent = prefix
            tree_indices[cursor + offset + 1] = choice[-1] + TOPK * (level + shift) + 1
        cursor += size

    # Position id of a node is its depth; the root keeps 0.
    tree_position_ids = torch.zeros(node_total, dtype=torch.long)
    cursor = 0
    for level, size in enumerate(level_sizes):
        tree_position_ids[cursor + 1: cursor + size + 1] = level + 1
        cursor += size

    # Walk the paths from deepest to shallowest and keep only those that are
    # not a prefix of a path that was already emitted.
    covered = []
    index_paths = []
    for choice in reversed(ordered):
        if choice in covered:
            continue
        chain = []
        for k in range(1, len(choice) + 1):
            chain.append(ordered.index(choice[:k]))
            covered.append(choice[:k])
        index_paths.append(chain)
    longest = max([len(chain) for chain in index_paths])
    padded = [pad_path(chain, longest) for chain in index_paths]
    # +1 shifts to node indices (root is 0) and turns the -2 padding into -1.
    retrieve_indices = torch.tensor(padded, dtype=torch.long) + 1
    root_column = torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long)
    retrieve_indices = torch.cat([root_column, retrieve_indices], dim=1)

    buffers = {
        "tree_attn_mask": attn.unsqueeze(0).unsqueeze(0),
        "tree_indices": tree_indices,
        "tree_position_ids": tree_position_ids,
        "retrieve_indices": retrieve_indices,
    }

    # Move every buffer to the requested device.
    on_device = {}
    for name, value in buffers.items():
        if isinstance(value, torch.Tensor):
            on_device[name] = value.clone().to(device)
        else:
            on_device[name] = torch.tensor(value, device=device)
    return on_device
160
+
161
+
162
def initialize_tree(input_ids, model, tree_attn_mask, past_key_values, logits_processor):
    """
    Run the first forward pass and install the tree attention mask.

    Args:
    - input_ids: Token ids passed to the model as its first positional argument.
    - model: Callable returning the 5-tuple
      (tree_logits, outputs, logits, hidden_state, sample_token); it must expose
      `base_model.model`, on which `tree_mask` is set.
    - tree_attn_mask: Mask stored on `model.base_model.model.tree_mask`.
    - past_key_values: Forwarded to the model unchanged.
    - logits_processor: Forwarded to the model unchanged.

    Returns:
    - tuple: (tree_logits, logits, hidden_state, sample_token); the model's
      `outputs` element is dropped.
    """
    forward_result = model(
        input_ids,
        past_key_values=past_key_values,
        output_orig=True,
        logits_processor=logits_processor,
    )
    tree_logits, _base_outputs, logits, hidden_state, sample_token = forward_result
    # The mask is installed only after the forward pass above has run.
    model.base_model.model.tree_mask = tree_attn_mask
    return tree_logits, logits, hidden_state, sample_token
169
+
170
+
171
def reset_tree_mode(
    model,
):
    """
    Clear the tree-decoding state kept on the wrapped base model.

    Args:
    - model: Object exposing `base_model.model`; its `tree_mask` and
      `tree_mode` attributes are both set to None.
    """
    inner = model.base_model.model
    for attribute in ("tree_mask", "tree_mode"):
        setattr(inner, attribute, None)
177
+
178
+
179
def reset_past_key_values(passed_key_values):
    """
    Zero the recorded length of every cached key/value entry.

    Only the `current_length` counters are reset (in place, via `fill_`); the
    cache objects themselves are left as they are.

    Args:
    - passed_key_values: Sequence with one item per layer; items 0 and 1 of
      each layer must expose a `current_length` attribute supporting `fill_`.

    Returns:
    - passed_key_values: The same object that was passed in, after the reset.
    """
    for layer_cache in passed_key_values:
        for slot in (0, 1):
            layer_cache[slot].current_length.fill_(0)
    return passed_key_values
197
+
198
+
199
def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor):
    """
    Lay the drafted tokens out on the tree and expand them into candidate paths.

    Args:
    - tree_logits: Indexable pair; item 0 holds the drafted token ids and
      item 1 their draft probabilities (item 1 is read only when a logits
      processor is given).
    - tree_indices (torch.Tensor): For each tree node, the position of its token
      in the flat sequence [sample_token[0], drafted tokens...].
    - retrieve_indices (torch.Tensor): Per-path node indices; -1 selects the
      padding slot appended after the last node.
    - sample_token: Its first item is the root token placed in front of the drafts.
    - logits_processor: When None, no probabilities are gathered.

    Returns:
    - cart_candidates (torch.Tensor): Token ids per path; padding slots hold 0.
    - cart_candidates_prob (torch.Tensor or None): Draft probabilities aligned
      with cart_candidates (root and padding slots hold 1), or None when
      logits_processor is None.
    - tree_candidates (torch.Tensor): Tree-ordered token ids with a leading
      dimension of size 1.
    """
    root_token = sample_token[0]
    drafted_tokens = tree_logits[0]

    flat_tokens = torch.cat([root_token, drafted_tokens.view(-1)], dim=-1)
    tree_candidates = flat_tokens[tree_indices]

    # One extra trailing slot so that retrieve index -1 resolves to token 0.
    token_pad = torch.zeros((1), dtype=torch.long, device=tree_candidates.device)
    cart_candidates = torch.cat([tree_candidates, token_pad], dim=0)[retrieve_indices]

    cart_candidates_prob = None
    if logits_processor is not None:
        drafted_probs = tree_logits[1]
        # The root token is given probability 1.
        root_prob = torch.ones(1, device=drafted_probs.device, dtype=torch.float32)
        flat_probs = torch.cat([root_prob, drafted_probs.view(-1)], dim=-1)
        tree_probs = flat_probs[tree_indices]
        # The padding slot is given probability 1 as well.
        prob_pad = torch.ones((1), dtype=torch.float32, device=tree_probs.device)
        cart_candidates_prob = torch.cat([tree_probs, prob_pad], dim=0)[retrieve_indices]

    # Add a leading dimension to the tree-ordered tokens.
    return cart_candidates, cart_candidates_prob, tree_candidates.unsqueeze(0)
239
+
240
+
241
def tree_decoding(
    model,
    tree_candidates,
    past_key_values,
    tree_position_ids,
    input_ids,
    retrieve_indices,
):
    """
    Run the model once over all tree nodes and gather logits per candidate path.

    Args:
    - model: Callable returning the 3-tuple (outputs, tree_logits, hidden_state).
    - tree_candidates: Tree-ordered token ids passed to the model.
    - past_key_values: Forwarded to the model unchanged.
    - tree_position_ids (torch.Tensor): Positions relative to the root; they are
      offset by the current sequence length `input_ids.shape[1]`.
    - input_ids (torch.Tensor): Only its second dimension is read.
    - retrieve_indices (torch.Tensor): Node indices of every candidate path.

    Returns:
    - logits: `tree_logits[0, retrieve_indices]`, i.e. the logits arranged per path.
    - hidden_state: As returned by the model.
    - outputs: As returned by the model.
    """
    # Shift the tree-relative positions behind the tokens already in the sequence.
    absolute_positions = tree_position_ids + input_ids.shape[1]

    forward_kwargs = dict(
        output_orig=True,
        past_key_values=past_key_values,
        position_ids=absolute_positions,
        init=False,
    )
    outputs, tree_logits, hidden_state = model(tree_candidates, **forward_kwargs)

    per_path_logits = tree_logits[0, retrieve_indices]
    return per_path_logits, hidden_state, outputs
264
+
265
+
266
def evaluate_posterior(
    logits, candidates, logits_processor, cart_candidates_prob
):
    """
    Choose the candidate path to accept and what to draw the next token from.

    Without a logits processor the choice is greedy: a drafted token is accepted
    while it equals the argmax of the model's logits at the previous position.
    With a logits processor the drafted tokens are verified by rejection
    sampling against the processed target distribution.

    Args:
    - logits (torch.Tensor): Model logits indexed as [candidate, position, vocab].
    - candidates (torch.Tensor): Candidate token ids indexed as
      [candidate, position]; column 0 is the already-sampled root token. In the
      sampling branch token id 0 is treated as padding and skipped.
    - logits_processor: None for greedy decoding, otherwise a callable invoked
      as `logits_processor(None, scores)` with scores of shape (1, vocab).
    - cart_candidates_prob (torch.Tensor or None): Draft probabilities aligned
      with `candidates`; only read in the sampling branch.

    Returns:
    - best_candidate (torch.Tensor): Index of the chosen candidate path.
    - accept_length: Number of accepted drafted tokens (root excluded); a
      tensor in the greedy branch, an int in the sampling branch.
    - third value: In the greedy branch the raw logits following the accepted
      prefix; in the sampling branch a probability vector to sample the next
      token from.
    """
    if logits_processor is None:
        # Greedy: a drafted token matches when it equals the argmax prediction.
        posterior_mask = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        # cumprod zeroes everything after the first mismatch on each path.
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length, logits[best_candidate, accept_length]

    # Sampling branch. accept_length counts the root token as well.
    accept_length = 1
    accept_cand = candidates[0][:1]
    best_candidate = 0
    # Initialised here so a path length of 1 (no loop iteration) stays valid.
    adjustflag = False
    gtp = None
    for i in range(1, candidates.shape[1]):
        if i != accept_length:
            # Nothing was accepted at the previous depth: stop.
            break
        adjustflag = False
        # Paths that share the prefix accepted so far.
        is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
        fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
        gt_logits = logits[fi, i - 1][None]
        gt_logits = logits_processor(None, gt_logits)[0]
        gtp = torch.softmax(gt_logits, dim=0)
        for j in range(candidates.shape[0]):
            if not is_eq[j]:
                continue
            # Drawn before the padding check, as before, so the random stream
            # is consumed in the same order.
            r = random.random()
            x = candidates[j, i]
            if x == 0:
                continue
            px = gtp[x]
            qx = cart_candidates_prob[j, i]
            acp = px / qx
            if r <= acp:
                accept_cand = torch.cat((accept_cand, x[None]), dim=0)
                accept_length += 1
                best_candidate = j
                # The residual built from earlier rejections belongs to this
                # depth only; after an acceptance it must not be sampled from.
                adjustflag = False
                break
            # Rejected: remove the draft mass and renormalise the residual.
            gtp[x] = max(px - qx, 0)
            gtp = gtp / gtp.sum()
            adjustflag = True
    if adjustflag:
        sample_p = gtp
    else:
        # Apply the same processing that the acceptance test used, so the next
        # token is drawn from the processed target distribution.
        gt_logits = logits[best_candidate, accept_length - 1][None]
        gt_logits = logits_processor(None, gt_logits)[0]
        sample_p = torch.softmax(gt_logits, dim=0)
    return torch.tensor(best_candidate), accept_length - 1, sample_p
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
@torch.no_grad()
def update_inference_inputs(
    input_ids,
    candidates,
    best_candidate,
    accept_length,
    retrieve_indices,
    logits_processor,
    logits,
    tree_logits,
    new_token,
    past_key_values_data_list,
    current_length_data,
    model,
    hidden_state,
    hidden_state_new,
    sample_p
):
    """
    Commit the accepted tokens and draft the next tree.

    The accepted tokens are appended to `input_ids`, their cache entries are
    copied so they sit directly behind the previous sequence, the next token is
    chosen from `sample_p`, and the draft layer is asked for new tree logits.

    Args:
    - input_ids (torch.Tensor): Current sequence; dimension 1 is its length.
    - candidates (torch.Tensor): Candidate token ids per path.
    - best_candidate: Index of the accepted path.
    - accept_length: Number of accepted drafted tokens (root excluded).
    - retrieve_indices (torch.Tensor): Node indices of every candidate path.
    - logits_processor: None selects argmax on `sample_p`, otherwise the next
      token is drawn with torch.multinomial.
    - logits: Not read by this function.
    - tree_logits: Incoming value is not read; it is replaced by the new draft.
    - new_token: Running counter, increased by accept_length + 1.
    - past_key_values_data_list: Cache tensors updated in place.
    - current_length_data: Filled in place with the new sequence length.
    - model: Provides `ea_layer.topK_genrate` and `base_model.lm_head`.
    - hidden_state: Hidden states so far; the accepted new ones are appended
      along dimension 1.
    - hidden_state_new: Hidden states of the verified tree nodes.
    - sample_p: Scores or probabilities the next token is chosen from.

    Returns:
    - tuple: (input_ids, tree_logits, new_token, hidden_state, token).
    """
    prev_input_len = input_ids.shape[1]
    # Root token plus the accepted drafted tokens.
    kept = accept_length + 1

    # Cache positions of the accepted nodes, relative to the whole sequence.
    select_indices = retrieve_indices[best_candidate, :kept] + prev_input_len

    accepted_tokens = candidates[None, best_candidate, :kept].to(input_ids.device)
    input_ids = torch.cat([input_ids, accepted_tokens], dim=-1)

    # Pack the accepted cache entries right behind the previous sequence.
    for cache_data in past_key_values_data_list:
        tgt = cache_data[..., select_indices.to(cache_data.device), :]
        dst = cache_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :]
        dst.copy_(tgt, non_blocking=True)

    # Update the current length tensor (currently only support batch size is 1)
    current_length_data.fill_(prev_input_len + tgt.shape[-2])

    accept_hidden_state_new = hidden_state_new[:, retrieve_indices][:, best_candidate, :kept]

    if logits_processor is None:
        token = torch.argmax(sample_p)[None, None]
    else:
        token = torch.multinomial(sample_p, 1)[None]

    hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
    draft_input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
    tree_logits = model.ea_layer.topK_genrate(
        hidden_state,
        input_ids=draft_input_ids,
        head=model.base_model.lm_head,
        logits_processor=logits_processor,
    )

    new_token += kept

    return input_ids, tree_logits, new_token, hidden_state, token
406
+
407
if __name__ == "__main__":
    # Smoke test: build a temperature + top-p processor list and run it once
    # on random logits.
    logits = torch.randn(1, 5)
    tp = prepare_logits_processor(
        temperature=0.9, repetition_penalty=0, top_p=0.9, top_k=0
    )
    l = tp(None, logits)
    # prepare_logits_processor always returns a list object, so this never prints.
    if tp is None:
        print(tp)
model/utils_c.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
# generate_tree_buffers uses TOPK as the stride between sibling groups when it
# flattens a child's rank into a single index (value + TOPK * group).
# Presumably this must match the number of tokens drafted per parent -- TODO confirm.
TOPK = 10  # topk for sparse tree
4
+
5
+
6
def pad_path(path, length, pad_value=-2):
    """
    Extend a path on the right with a filler value up to a given length.

    Parameters:
    - path (list): The list to extend.
    - length (int): Length the result should have.
    - pad_value (optional, default=-2): Value used for the added entries.

    Returns:
    - list: A new list made of the path entries followed by the filler.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    When the path already holds `length` or more entries nothing is added and
    an unpadded copy of the path is returned.
    """
    missing = length - len(path)
    # Multiplying by zero or a negative count yields an empty tail.
    tail = [pad_value] * missing
    return path + tail
31
+
32
class node:
    """
    One node of the draft tree.

    Creating a node with a parent registers it in the parent's `children` list
    and sets its depth to one more than the parent's; a node without a parent
    has depth 0. The `index` attribute read by `all_index` is not set here; it
    has to be assigned from outside.
    """

    def __init__(self,parent=None,value=None,dict_key=None):
        self.parent = parent
        self.value = value
        if not parent:
            self.depth = 0
        else:
            self.depth = parent.depth + 1
            parent.children.append(self)
        self.children = []
        self.dict_key = dict_key

    def is_leaf(self):
        """Return True when no child has been attached to this node."""
        return not len(self.children)

    def all_index(self):
        """
        Return the `index` values from the top-most indexed ancestor down to
        this node.

        The walk stops at the node whose parent has no parent itself, so the
        parentless root does not contribute an entry.
        """
        chain = []
        cursor = self
        # Climb until the cursor's parent is the parentless root.
        while cursor.parent.parent:
            chain.append(cursor.index)
            cursor = cursor.parent
        chain.append(cursor.index)
        chain.reverse()
        return chain
51
+
52
+
53
+
54
class Tree:
    """
    Tree of `node` objects built from a list of choice paths.

    Paths are inserted shortest first (ties in lexicographic order), so the
    parent of a path, i.e. the path without its last element, must be part of
    the list; otherwise the parent lookup raises KeyError. `node_dic` maps each
    path (as a tuple) to its node, in insertion order.
    """

    def __init__(self,tree_list):
        self.root = node()
        self.node_dic = {}
        for path in sorted(tree_list, key=lambda x: (len(x), x)):
            value = path[-1]
            key = tuple(path)
            if len(path) == 1:
                parent = self.root
            else:
                parent = self.node_dic[key[:-1]]
            self.node_dic[key] = node(parent=parent, value=value, dict_key=key)
        self.indexnode()

    def max_depth(self):
        """Return the largest depth among the registered nodes."""
        return max([entry.depth for entry in self.node_dic.values()])

    def num_node_wchild(self):
        """Return how many registered nodes have at least one child."""
        return sum(1 for entry in self.node_dic.values() if not entry.is_leaf())

    def get_node_wchild(self):
        """Return the registered nodes that have children, in insertion order."""
        return [entry for entry in self.node_dic.values() if not entry.is_leaf()]

    def indexnode(self):
        """Number the nodes that have children 0, 1, 2, ... in insertion order."""
        next_index = 0
        for entry in self.node_dic.values():
            if entry.is_leaf():
                continue
            entry.index = next_index
            next_index += 1
93
+
94
+
95
+
96
+
97
def generate_tree_buffers(tree_choices, device="cuda"):
    """
    Build per-depth buffers for the tree nodes that have children.

    Only nodes with at least one child take part. They are ordered by depth and,
    within a depth, by path, which keeps children of the same parent adjacent.

    Parameters:
    - tree_choices (list of list of int): One path per tree node; every path's
      parent path must be listed as well (required by Tree).
    - device: Device the tensors are moved to.

    Returns:
    - dict with one list entry per depth 1 .. max_depth - 1:
      "attn_mask": float tensors of shape (1, 1, rows, cols); rows are the
          nodes of that depth, columns all nodes up to and including that
          depth. A row has ones on its own column and on its ancestors.
      "tree_indices": long tensors holding `value + TOPK * group` per node,
          where group counts the parent changes seen so far within the depth.
      "position_ids": long tensors of zeros, one entry per node of the depth.
      "repeat_nums": plain lists with the size of each run of consecutive
          nodes sharing a parent (not moved to the device).
      A tree with only depth-1 paths has no node with children; every list is
      then empty.
    """
    tree = Tree(tree_choices)
    nodes_wc = tree.get_node_wchild()
    tree_len = len(nodes_wc)

    # Nodes at the maximum depth are leaves, so only max_depth - 1 levels exist.
    depth_counts = [0 for _ in range(tree.max_depth() - 1)]
    for cur_node in nodes_wc:
        depth_counts[cur_node.depth - 1] += 1

    # Running totals: depth_ends[d] = number of nodes up to and including level d.
    depth_ends = []
    running = 0
    for count in depth_counts:
        running += count
        depth_ends.append(running)

    tree_attn_mask = torch.eye(tree_len, tree_len)
    for row, cur_node in enumerate(nodes_wc):
        tree_attn_mask[row, cur_node.all_index()] = 1

    # Per level: the rows of that level against all columns seen so far.
    tree_attn_mask_list = [
        tree_attn_mask[end - count:end, :end]
        for count, end in zip(depth_counts, depth_ends)
    ]

    tree_indices_list = [torch.zeros(count, dtype=torch.long) for count in depth_counts]
    repeat_nums = [[] for _ in depth_counts]
    start = 0
    for level, count in enumerate(depth_counts):
        bias = 0      # index of the current parent group within this level
        repeat_j = 0  # position at which the current parent group started
        parent = None
        for j in range(count):
            cur_node = nodes_wc[start + j]
            cur_parent = cur_node.parent
            if j == 0:
                parent = cur_parent
            elif cur_parent != parent:
                # A new parent group begins: close the previous run.
                bias += 1
                parent = cur_parent
                repeat_nums[level].append(j - repeat_j)
                repeat_j = j
            tree_indices_list[level][j] = cur_node.value + TOPK * bias
        # Size of the last run of the level.
        repeat_nums[level].append(count - repeat_j)
        start += count

    position_ids = [torch.zeros(count, dtype=torch.long) for count in depth_counts]

    tensor_buffers = {
        "attn_mask": [mask.unsqueeze(0).unsqueeze(0) for mask in tree_attn_mask_list],
        "tree_indices": tree_indices_list,
        "position_ids": position_ids,
    }

    # Move the tensor lists to the device; empty lists pass through unchanged
    # instead of failing on a first-element type probe.
    tree_buffers = {
        key: [tensor.clone().to(device) for tensor in tensors]
        for key, tensors in tensor_buffers.items()
    }
    tree_buffers["repeat_nums"] = repeat_nums
    return tree_buffers
177
+
178
+
179
def reset_past_key_values(passed_key_values):
    """
    Set the stored length of every cached key/value entry back to zero.

    The reset happens in place through `current_length.fill_(0)`; nothing else
    on the cache entries is touched.

    Args:
    - passed_key_values: Sequence with one item per layer; items 0 and 1 of
      each layer must expose a `current_length` attribute supporting `fill_`.

    Returns:
    - passed_key_values: The object that was passed in, after the reset.
    """
    for layer in passed_key_values:
        layer[0].current_length.fill_(0)
        layer[1].current_length.fill_(0)
    return passed_key_values
197
+
198
+
199
+
200
if __name__ == "__main__":
    # Manual check: build the buffers for the bundled tree and print them.
    from choices import mc_sim_7b_63

    a = generate_tree_buffers(mc_sim_7b_63)
    print(a)