mac committed on
Commit 97d9cf5 · 1 Parent(s): beb8609

upload_code

README.md CHANGED
@@ -1,14 +1,14 @@
  ---
- title: MiniCPM4.1 8B Demo
- emoji: 🦀
- colorFrom: pink
- colorTo: red
+ title: MiniCPM4.1 8B Eagle3 Streaming
+ emoji: 🚀
+ colorFrom: yellow
+ colorTo: blue
  sdk: gradio
- sdk_version: 5.46.0
+ sdk_version: 5.44.1
  app_file: app.py
  pinned: false
- license: apache-2.0
- short_description: chat with MiniCPM4.1-8B with speculative decoding
+ tags:
+ - anycoder
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,319 @@
1
+ # MiniCPM-4.1-8B-Eagle3
2
+
3
+ from pathlib import Path
4
+ import time
5
+ import logging
6
+ import gradio as gr
7
+ import torch
8
+ import spaces
9
+ import threading
10
+ from transformers import AutoTokenizer, TextIteratorStreamer
11
+ # Import model-related modules
12
+ from eagle.model.ea_model import EaModel
13
+ from utils_chatbot import organize_messages, stream2display_text, mtp_new_tokens
14
+
15
+ # Logging configuration
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ # Global model instance
21
+ global_model = None
22
+ # Global model cache (lives in the GPU process)
23
+ _gpu_model_cache = None
24
+ # Global model configuration
25
+ model_config = dict(
26
+ base_model_path = "openbmb/MiniCPM4.1-8B",
27
+ ea_model_path = "openbmb/MiniCPM4.1-8B-Eagle3/MiniCPM4_1-8B-Eagle3-bf16",
28
+ total_token=40,
29
+ depth=3,
30
+ top_k=10,
31
+ threshold=1.0,
32
+ use_eagle3=True,
33
+ device_map = "cpu",
34
+ trust_remote_code=True,
35
+ )
36
+
37
+ # Load the tokenizer ahead of time
38
+ tokenizer = AutoTokenizer.from_pretrained(
39
+ "openbmb/MiniCPM4.1-8B",
40
+ use_fast=False,
41
+ device_map="cpu",
42
+ )
43
+
44
+ def _initialize_gpu_model():
45
+ """在GPU进程中获取模型并移到GPU"""
46
+ global _gpu_model_cache
47
+ if _gpu_model_cache is None:
48
+ logger.info(f"在GPU进程中初始化模型")
49
+ _gpu_model_cache = EaModel.from_pretrained(**model_config)
50
+ logger.info(f"模型在CPU上初始化完成")
51
+ return _gpu_model_cache
52
+
53
+ @spaces.GPU(duration=42) # default is 60
54
+ def gpu_handler(inputs):
55
+ prompt_text = tokenizer.apply_chat_template(
56
+ inputs,
57
+ tokenize=False,
58
+ add_generation_prompt=True,
59
+ )
60
+ model_inputs = tokenizer([prompt_text], return_tensors="pt")
61
+ inputs = {
62
+ "model_inputs": model_inputs,
63
+ "max_new_tokens": 65536,
64
+ "temperature": 0.6,
65
+ "top_p": 0.95,
66
+ "top_k": 50,
67
+ "max_length": 65536,
68
+ }
69
+
70
+ logger.info(f"向 GPU 搬运 global_model")
71
+
72
+ """GPU推理处理器"""
73
+ model = _initialize_gpu_model()
74
+
75
+ cuda_inputs = dict(
76
+ input_ids=inputs["model_inputs"].input_ids.to("cuda"),
77
+ # attention_mask=inputs["model_inputs"].attention_mask.to("cuda"),
78
+ max_new_tokens=inputs["max_new_tokens"],
79
+ temperature=inputs["temperature"],
80
+ top_p=inputs["top_p"],
81
+ top_k=inputs["top_k"],
82
+ max_length=inputs["max_length"],
83
+ )
84
+
85
+ model.base_model.to("cuda")
86
+ model.ea_layer.to("cuda")
87
+ model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("cuda")  # .to() is not in-place; assign the result
88
+
89
+ logger.info(f"pass inputs to global_model")
90
+
91
+ output_ids = model.eagenerate(**cuda_inputs)
92
+
93
+ logger.info(f"got outputs from global_model.eagenerate")
94
+ new_text = tokenizer.decode(
95
+ output_ids[0][model_inputs.input_ids.shape[1]:],
96
+ skip_special_tokens=True,
97
+ )
98
+
99
+ return new_text
100
+
101
+ @spaces.GPU(duration=60) # default is 60
102
+ def gpu_handler_s(
103
+ inputs,
104
+ history,
105
+ temperature,
106
+ top_p,
107
+ use_eagle,
108
+ ):
109
+ prompt_text = tokenizer.apply_chat_template(
110
+ inputs,
111
+ tokenize=False,
112
+ add_generation_prompt=True,
113
+ )
114
+ model_inputs = tokenizer([prompt_text], return_tensors="pt")
115
+ inputs = {
116
+ "model_inputs": model_inputs,
117
+ "max_new_tokens": 4096,
118
+ "temperature": temperature,
119
+ "top_p": top_p,
120
+ "top_k": 50,
121
+ "max_length": 65536,
122
+ }
123
+
124
+ logger.info(f"向 GPU 搬运 global_model")
125
+
126
+ """GPU推理处理器"""
127
+ model = _initialize_gpu_model()
128
+
129
+ cuda_inputs = dict(
130
+ input_ids=inputs["model_inputs"].input_ids.to("cuda"),
131
+ # attention_mask=inputs["model_inputs"].attention_mask.to("cuda"),
132
+ max_new_tokens=inputs["max_new_tokens"],
133
+ temperature=inputs["temperature"],
134
+ top_p=inputs["top_p"],
135
+ top_k=inputs["top_k"],
136
+ max_length=inputs["max_length"],
137
+ )
138
+
139
+ model.base_model.to("cuda")
140
+ model.ea_layer.to("cuda")
141
+ model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("cuda")  # .to() is not in-place; assign the result
142
+
143
+ logger.info(f"pass inputs to global_model")
144
+
145
+ yield "", history
146
+
147
+ stop_token_ids = [
148
+ tokenizer.eos_token_id,
149
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
150
+ ]
151
+ gen_tk_count, existing_tk_count = 0, len(inputs["model_inputs"].input_ids[0])
152
+
153
+ stream_text, start_time = "", time.time()
154
+
155
+ generate_func = model.ea_generate if use_eagle else model.naive_generate
156
+
157
+ for output_ids in generate_func(**cuda_inputs):
158
+ # for output_ids in model.ea_generate(**cuda_inputs):
159
+ new_tokens, gen_tk_count = mtp_new_tokens(output_ids, gen_tk_count, existing_tk_count, stop_token_ids)
160
+ new_token_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
161
+ logger.info(f"[TOKEN]'''{new_token_text}'''")
162
+ stream_text += new_token_text
163
+ token_per_sec = gen_tk_count / (time.time() - start_time)
164
+ display_text = stream2display_text(stream_text, token_per_sec)
165
+ history[-1] = (history[-1][0], display_text)
166
+ yield "", history
167
+
168
+ # logger.info(f"all gen text: \n{stream_text}")
169
+ history[-1] = (history[-1][0], stream_text.replace("<|im_end|>", ""))
170
+ # Replace the history entry with the plain (non-display) text
171
+
172
+
173
+ class Model:
174
+ """模型封装类,不持有实际模型对象"""
175
+
176
+ def __init__(self):
177
+ logger.info(f"创建封装类")
178
+
179
+ def handler(self, inputs):
180
+ """非流式推理处理器"""
181
+ return gpu_handler(inputs)
182
+
183
+ def stream_handler(self, inputs, history, **kwargs):
184
+ """流式推理处理器"""
185
+ yield from gpu_handler_s(inputs, history, **kwargs)
186
+
187
+
188
+ def initialize_model():
189
+ """初始化全局模型"""
190
+ global global_model, _gpu_model_cache
191
+
192
+ # Default configuration
193
+ logger.info(f"="*50)
194
+ logger.info(f"启动 MiniCPM-4.1-8B-Eagle3 Chatbot 服务")
195
+ logger.info(f"="*50)
196
+
197
+ # Create the model wrapper
198
+ global_model = Model()
199
+
200
+ # Preload the model onto the CPU in the main process (for a faster first inference)
201
+ try:
202
+ logger.info("在主进程中预加载模型到 CPU...")
203
+ _gpu_model_cache = EaModel.from_pretrained(**model_config)
204
+ logger.info("模型在主进程CPU上预加载完成")
205
+ except Exception as e:
206
+ logger.warning(f"主进程预加载模型失败, 将在GPU进程中加载: {e}")
207
+ _gpu_model_cache = None
208
+
209
+ return global_model
210
+
211
+
212
+ def gen_response(message, history, temperature, top_p):
213
+ chat_msg_ls = organize_messages(message, history)
214
+
215
+ new_text = global_model.handler(chat_msg_ls)
216
+
217
+ history.append((message, new_text))
218
+ return "", history
219
+
220
+ def gen_response_stream(
221
+ message,
222
+ history,
223
+ temperature,
224
+ top_p,
225
+ use_eagle,
226
+ ):
227
+ chat_msg_ls = organize_messages(message, history)
228
+
229
+ history.append((message, ""))
230
+
231
+ sampling_kwargs = dict(
232
+ temperature = temperature,
233
+ top_p = top_p,
234
+ use_eagle = use_eagle,
235
+ )
236
+
237
+ yield from global_model.stream_handler(chat_msg_ls, history, **sampling_kwargs)
238
+
239
+ def create_app():
240
+ assets_path = Path.cwd().absolute()/"assets"
241
+ logger.info(f"设置静态资源路径: {assets_path}")
242
+ gr.set_static_paths(paths=[assets_path])
243
+ logger.info("静态资源路径设置完成")
244
+
245
+ theme = gr.themes.Soft(
246
+ primary_hue="blue",
247
+ secondary_hue="gray",
248
+ neutral_hue="slate",
249
+ font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
250
+ )
251
+
252
+ # # Add border styling to components
253
+ # theme = theme.set(
254
+ # primary_border_size='1px', # component outer border
255
+ # primary_border_color='*neutral_400', # slate-400 gray from the theme
256
+ # )
257
+
258
+ with gr.Blocks(
259
+ theme=theme,
260
+ css="""
261
+ .logo-container {
262
+ text-align: center;
263
+ margin: 0.5rem 0 1rem 0;
264
+ }
265
+ .logo-container img {
266
+ height: 96px;
267
+ width: auto;
268
+ max-width: 200px;
269
+ display: inline-block;
270
+ }
271
+ .input-box {
272
+ border: 1px solid #2f63b8;
273
+ border-radius: 8px;
274
+ }
275
+ """,
276
+ ) as demo:
277
+ with gr.Row():
278
+ with gr.Column(scale=1):
279
+ gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/OpenBMB-MiniCPM.png" alt="MiniCPM Logo"></div>')
280
+
281
+ blank_1 = gr.HTML("<div style='height:1px;'></div>")
282
+
283
+ temperature = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="Temperature", scale=1)
284
+ top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.01, label="Top-p", scale=1)
285
+ use_eagle = gr.Checkbox(label="Speculative Decoding", value=True)
286
+
287
+ blank_2 = gr.HTML("<div style='height:120px;'></div>")
288
+
289
+ clear = gr.Button("Clear History")
290
+
291
+ gr.Markdown(
292
+ """
293
+ Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>
294
+ """
295
+ )
296
+ with gr.Column(scale=4):
297
+ chatbot = gr.Chatbot(label="Chat History", placeholder="Input to start a new chat", height=500)
298
+ prompt = gr.Textbox(
299
+ label="Input Text",
300
+ placeholder="Type your message here...",
301
+ lines=1,
302
+ # submit_btn=True,
303
+ elem_classes=["input-box"], # 自定义 class 供 css 使用
304
+ )
305
+
306
+ prompt.submit(gen_response_stream, inputs=[prompt, chatbot, temperature, top_p, use_eagle], outputs=[prompt, chatbot])
307
+ clear.click(lambda: None, None, chatbot, queue=False)
308
+
309
+ return demo
310
+
311
+
312
+ if __name__ == "__main__":
313
+ # Initialize the model
314
+ initialize_model()
315
+
316
+ # Create and launch the app
317
+ demo = create_app()
318
+ demo.launch()
319
+
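For readers who want to exercise the same inference path outside of the Space, here is a minimal sketch that is not part of the commit. It reuses the EaModel, tokenizer, and ea_generate calls exactly as app.py uses them; the prompt, sampling values, and device handling are illustrative assumptions, and the @spaces.GPU decorator and Gradio UI are omitted.

# Minimal local sketch (assumptions: a CUDA device is available and the model IDs above are reachable).
# Mirrors app.py: load EaModel on CPU, move the pieces to the GPU, then stream with ea_generate.
import torch
from transformers import AutoTokenizer
from eagle.model.ea_model import EaModel

tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM4.1-8B", use_fast=False)
model = EaModel.from_pretrained(
    base_model_path="openbmb/MiniCPM4.1-8B",
    ea_model_path="openbmb/MiniCPM4.1-8B-Eagle3/MiniCPM4_1-8B-Eagle3-bf16",
    total_token=40, depth=3, top_k=10, threshold=1.0,
    use_eagle3=True, device_map="cpu", trust_remote_code=True,
)
model.base_model.to("cuda")
model.ea_layer.to("cuda")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}], tokenize=False, add_generation_prompt=True
)
input_ids = tokenizer([prompt], return_tensors="pt").input_ids.to("cuda")
prompt_len = input_ids.shape[1]

# ea_generate yields progressively longer output_ids; re-decode the generated suffix each step.
for output_ids in model.ea_generate(input_ids=input_ids, max_new_tokens=256,
                                    temperature=0.6, top_p=0.95, top_k=50, max_length=65536):
    print(tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True))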
eagle/model/__init__.py ADDED
File without changes
eagle/model/choices.py ADDED
@@ -0,0 +1,3 @@
1
+ mc_sim_7b_63 = [[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0]
2
+ ,[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0],
3
+ [0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]]
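A note on how this table is usually read (my interpretation of the EAGLE-style static tree, not stated in the commit): each inner list is one node of the draft tree, written as the path of per-depth top-k ranks from the root, so [0] is the highest-ranked first draft token and [0,1] is that token's second-ranked child. A small sketch that only inspects the structure:

# Sketch: summarize the static draft tree defined above (purely descriptive, no model needed).
from collections import Counter
from eagle.model.choices import mc_sim_7b_63

depth_counts = Counter(len(path) for path in mc_sim_7b_63)
print("total draft nodes:", len(mc_sim_7b_63))                  # 25 paths in this table
print("nodes per depth:", dict(sorted(depth_counts.items())))   # {1: 4, 2: 8, 3: 8, 4: 3, 5: 2}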
eagle/model/cnets.py ADDED
@@ -0,0 +1,887 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import copy
22
+ import os
23
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+
30
+ from transformers.activations import ACT2FN
31
+ from huggingface_hub import hf_hub_download
32
+
33
+
34
+ try:
35
+ from .configs import EConfig
36
+ from .utils_c import *
37
+ from .choices import *
38
+ except:
39
+ from configs import EConfig
40
+ from utils_c import *
41
+ from choices import *
42
+ from utils import prepare_logits_processor
43
+
44
+
45
+
46
+
47
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
48
+ def _make_causal_mask(
49
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
50
+ ):
51
+ """
52
+ Make causal mask used for bi-directional self-attention.
53
+ """
54
+ bsz, tgt_len = input_ids_shape
55
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
56
+ mask_cond = torch.arange(mask.size(-1), device=device)
57
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
58
+ mask = mask.to(dtype)
59
+
60
+ if past_key_values_length > 0:
61
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
62
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
63
+
64
+
65
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
66
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
67
+ """
68
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
69
+ """
70
+ bsz, src_len = mask.size()
71
+ tgt_len = tgt_len if tgt_len is not None else src_len
72
+
73
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
74
+
75
+ inverted_mask = 1.0 - expanded_mask
76
+
77
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
78
+
79
+
80
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81
+ """
82
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84
+ """
85
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86
+ if n_rep == 1:
87
+ return hidden_states
88
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
89
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
90
+
91
+
92
+ def rotate_half(x):
93
+ """Rotates half the hidden dims of the input."""
94
+ x1 = x[..., : x.shape[-1] // 2]
95
+ x2 = x[..., x.shape[-1] // 2:]
96
+ return torch.cat((-x2, x1), dim=-1)
97
+
98
+
99
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
100
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
101
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
102
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
103
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
104
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
105
+ q_embed = (q * cos) + (rotate_half(q) * sin)
106
+ k_embed = (k * cos) + (rotate_half(k) * sin)
107
+ return q_embed, k_embed
108
+
109
+
110
+ class LlamaRotaryEmbedding(torch.nn.Module):
111
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
112
+ super().__init__()
113
+
114
+ self.dim = dim
115
+ self.max_position_embeddings = max_position_embeddings
116
+ self.base = base
117
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
118
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
119
+
120
+ # Build here to make `torch.jit.trace` work.
121
+ self._set_cos_sin_cache(
122
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
123
+ )
124
+
125
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
126
+ self.max_seq_len_cached = seq_len
127
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
128
+
129
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
130
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
131
+ emb = torch.cat((freqs, freqs), dim=-1)
132
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
133
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
134
+
135
+ def forward(self, x, seq_len=None):
136
+ # x: [bs, num_attention_heads, seq_len, head_size]
137
+ if seq_len > self.max_seq_len_cached:
138
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
+
140
+ return (
141
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
142
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
143
+ )
144
+
145
+
146
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
147
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
148
+
149
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
150
+ self.scaling_factor = scaling_factor
151
+ super().__init__(dim, max_position_embeddings, base, device)
152
+
153
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
154
+ self.max_seq_len_cached = seq_len
155
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
156
+ t = t / self.scaling_factor
157
+
158
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
159
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
163
+
164
+
165
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
166
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
167
+
168
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
169
+ self.scaling_factor = scaling_factor
170
+ super().__init__(dim, max_position_embeddings, base, device)
171
+
172
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
173
+ self.max_seq_len_cached = seq_len
174
+
175
+ if seq_len > self.max_position_embeddings:
176
+ base = self.base * (
177
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
178
+ ) ** (self.dim / (self.dim - 2))
179
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
180
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
181
+
182
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
183
+
184
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
185
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
186
+ emb = torch.cat((freqs, freqs), dim=-1)
187
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
188
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
189
+
190
+
191
+ class MiniCPMLongRoPE(LlamaRotaryEmbedding):
192
+ """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
193
+
194
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None):
195
+ self.short_factor = short_factor
196
+ self.long_factor = long_factor
197
+ self.original_max_position_embeddings = original_max_position_embeddings
198
+ scale = (max_position_embeddings / self.original_max_position_embeddings)
199
+ self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
200
+ super().__init__(dim, max_position_embeddings, base, device)
201
+
202
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
203
+ self.max_seq_len_cached = seq_len
204
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
205
+ if seq_len > self.original_max_position_embeddings:
206
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
207
+ else:
208
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
209
+
210
+ freqs = torch.mul(
211
+ torch.outer(t, 1.0 / ext_factors).to(device=device),
212
+ self.inv_freq.to(device=device).to(dtype)
213
+ )
214
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
215
+ emb = torch.cat((freqs, freqs), dim=-1)
216
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False)
217
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False)
218
+
219
+
220
+ class LlamaAttention(nn.Module):
221
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
222
+
223
+ def __init__(self, config):
224
+ super().__init__()
225
+ self.config = config
226
+ self.hidden_size = config.hidden_size
227
+ self.num_heads = config.num_attention_heads
228
+ self.head_dim = self.hidden_size // self.num_heads
229
+ self.num_key_value_heads = config.num_key_value_heads
230
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
231
+ self.max_position_embeddings = config.max_position_embeddings
232
+ self.rope_theta = config.rope_theta
233
+
234
+ if (self.head_dim * self.num_heads) != self.hidden_size:
235
+ raise ValueError(
236
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
237
+ f" and `num_heads`: {self.num_heads})."
238
+ )
239
+ self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False)
240
+ self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
241
+ self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
242
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
243
+ self._init_rope()
244
+
245
+ def _init_rope(self):
246
+ if self.config.rope_scaling is None:
247
+ if hasattr(self.config, "rope_theta"):
248
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
249
+ max_position_embeddings=self.max_position_embeddings,
250
+ base=self.config.rope_theta)
251
+ else:
252
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
253
+ max_position_embeddings=self.max_position_embeddings)
254
+ else:
255
+ scaling_type = self.config.rope_scaling["type"]
256
+ scaling_factor = self.config.rope_scaling["factor"]
257
+ if scaling_type == "linear":
258
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
259
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
260
+ )
261
+ elif scaling_type == "dynamic":
262
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
263
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
264
+ )
265
+ elif scaling_type == "longrope":
266
+ self.rotary_emb = MiniCPMLongRoPE(
267
+ self.head_dim,
268
+ max_position_embeddings=self.max_position_embeddings,
269
+ short_factor=self.config.rope_scaling['short_factor'],
270
+ long_factor=self.config.rope_scaling['long_factor'],
271
+ base=self.rope_theta,
272
+ original_max_position_embeddings=self.config.rope_scaling['original_max_position_embeddings']
273
+ )
274
+ else:
275
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
276
+
277
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
278
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ position_ids: Optional[torch.LongTensor] = None,
285
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
286
+ output_attentions: bool = False,
287
+ use_cache: bool = False,
288
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
289
+ bsz, q_len, _ = hidden_states.size()
290
+
291
+ if self.config.pretraining_tp > 1:
292
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
293
+ query_slices = self.q_proj.weight.split(
294
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
295
+ )
296
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
297
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
298
+
299
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
300
+ query_states = torch.cat(query_states, dim=-1)
301
+
302
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
303
+ key_states = torch.cat(key_states, dim=-1)
304
+
305
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
306
+ value_states = torch.cat(value_states, dim=-1)
307
+
308
+ else:
309
+ query_states = self.q_proj(hidden_states)
310
+ key_states = self.k_proj(hidden_states)
311
+ value_states = self.v_proj(hidden_states)
312
+
313
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
314
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
315
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
316
+
317
+ kv_seq_len = key_states.shape[-2]
318
+ if past_key_value is not None:
319
+ kv_seq_len += past_key_value[0].shape[-2]
320
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
321
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
322
+
323
+ if past_key_value is not None:
324
+ # reuse k, v, self_attention
325
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
326
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
327
+
328
+ past_key_value = (key_states, value_states) if use_cache else None
329
+
330
+ # repeat k/v heads if n_kv_heads < n_heads
331
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
332
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
333
+
334
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
335
+
336
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
337
+ raise ValueError(
338
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
339
+ f" {attn_weights.size()}"
340
+ )
341
+
342
+ if attention_mask is not None:
343
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
344
+ raise ValueError(
345
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
346
+ )
347
+ attn_weights = attn_weights + attention_mask
348
+
349
+ # upcast attention to fp32
350
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
351
+ attn_output = torch.matmul(attn_weights, value_states)
352
+
353
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
354
+ raise ValueError(
355
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
356
+ f" {attn_output.size()}"
357
+ )
358
+
359
+ attn_output = attn_output.transpose(1, 2).contiguous()
360
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
361
+
362
+ if self.config.pretraining_tp > 1:
363
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
364
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
365
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
366
+ else:
367
+ attn_output = self.o_proj(attn_output)
368
+
369
+ if not output_attentions:
370
+ attn_weights = None
371
+
372
+ return attn_output, attn_weights, past_key_value
373
+
374
+
375
+ class LlamaMLP(nn.Module):
376
+ def __init__(self, config):
377
+ super().__init__()
378
+ self.config = config
379
+ self.hidden_size = config.hidden_size
380
+ self.intermediate_size = config.intermediate_size
381
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
382
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
383
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
384
+ self.act_fn = ACT2FN[config.hidden_act]
385
+
386
+ def forward(self, x):
387
+ if self.config.pretraining_tp > 1:
388
+ slice = self.intermediate_size // self.config.pretraining_tp
389
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
390
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
391
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
392
+
393
+ gate_proj = torch.cat(
394
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
395
+ )
396
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
397
+
398
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
399
+ down_proj = [
400
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
401
+ ]
402
+ down_proj = sum(down_proj)
403
+ else:
404
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
405
+
406
+ return down_proj
407
+
408
+
409
+ class LlamaRMSNorm(nn.Module):
410
+ def __init__(self, hidden_size, eps=1e-6):
411
+ """
412
+ LlamaRMSNorm is equivalent to T5LayerNorm
413
+ """
414
+ super().__init__()
415
+ self.weight = nn.Parameter(torch.ones(hidden_size))
416
+ self.variance_epsilon = eps
417
+
418
+ def forward(self, hidden_states):
419
+ input_dtype = hidden_states.dtype
420
+ hidden_states = hidden_states.to(torch.float32)
421
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
422
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
423
+ return self.weight * hidden_states.to(input_dtype)
424
+
425
+
426
+ class LlamaDecoderLayeremb(nn.Module):
427
+ def __init__(self, config, last=True):
428
+ super().__init__()
429
+ self.hidden_size = config.hidden_size
430
+ self.self_attn = LlamaAttention(config=config)
431
+ self.mlp = LlamaMLP(config)
432
+ self.last = last
433
+ # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
434
+ self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
435
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
436
+ # if self.index!=0:
437
+
438
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
439
+
440
+ def forward(
441
+ self,
442
+ input_emb: torch.Tensor,
443
+ hidden_states: torch.Tensor,
444
+ attention_mask: Optional[torch.Tensor] = None,
445
+ position_ids: Optional[torch.LongTensor] = None,
446
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
447
+ output_attentions: Optional[bool] = False,
448
+ use_cache: Optional[bool] = False,
449
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
450
+ """
451
+ Args:
452
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
453
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
454
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
455
+ output_attentions (`bool`, *optional*):
456
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
457
+ returned tensors for more detail.
458
+ use_cache (`bool`, *optional*):
459
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
460
+ (see `past_key_values`).
461
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
462
+ """
463
+
464
+ residual = hidden_states
465
+
466
+ hidden_states = self.hidden_norm(hidden_states)
467
+ input_emb = self.input_layernorm(input_emb)
468
+
469
+ hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
470
+
471
+
472
+ # cache_hidden.append(hidden_states)
473
+
474
+ # Self Attention
475
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
476
+ hidden_states=hidden_states,
477
+ attention_mask=attention_mask,
478
+ position_ids=position_ids,
479
+ past_key_value=past_key_value,
480
+ output_attentions=output_attentions,
481
+ use_cache=use_cache,
482
+ )
483
+ hidden_states = residual + hidden_states
484
+
485
+ # Fully Connected
486
+ residual = hidden_states
487
+ hidden_states = self.post_attention_layernorm(hidden_states)
488
+ hidden_states = self.mlp(hidden_states)
489
+ hidden_states = residual + hidden_states
490
+
491
+ outputs = (hidden_states,)
492
+
493
+ if output_attentions:
494
+ outputs += (self_attn_weights,)
495
+
496
+ if use_cache:
497
+ outputs += (present_key_value,)
498
+
499
+ return outputs
500
+
501
+
502
+ @torch.no_grad()
503
+ def padding(tensor, left=True):
504
+ zeropadding = torch.zeros_like(tensor[:, -1:])
505
+ if left:
506
+ tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1)
507
+ else:
508
+ tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1)
509
+ return tensor
510
+
511
+
512
+
513
+ def len_list(x, n):
514
+ return [i for i in x if len(i) <= n]
515
+
516
+
517
+ class Model(nn.Module):
518
+ def __init__(self, config, load_emb=False, path=None, bias=True, total_tokens=63, depth=5, top_k=8, threshold=1.0):
519
+ super().__init__()
520
+ self.config=config
521
+ self.gradient_checkpointing = True
522
+ self.padding_idx = config.pad_token_id
523
+ self.vocab_size = config.vocab_size
524
+
525
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
526
+ self.lm_head=nn.Linear(config.hidden_size,config.draft_vocab_size,bias=False)
527
+ if load_emb and not hasattr(config, "target_hidden_size"):
528
+ from safetensors import safe_open
529
+ import json
530
+ try:
531
+ index_json_path = os.path.join(path, "model.safetensors.index.json")
532
+ if not os.path.exists(index_json_path):
533
+ index_json_path = hf_hub_download(path, "model.safetensors.index.json")
534
+ with open(index_json_path, "r") as f:
535
+ index_json = json.loads(f.read())
536
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
537
+ local_emb_path = os.path.join(path, emb_path)
538
+ if not os.path.exists(local_emb_path):
539
+ local_emb_path = hf_hub_download(path, emb_path)
540
+ with safe_open(local_emb_path,
541
+ framework="pt",
542
+ device="cpu") as f:
543
+ tensor_slice = f.get_slice("model.embed_tokens.weight")
544
+ vocab_size, hidden_dim = tensor_slice.get_shape()
545
+ tensor = tensor_slice[:, :hidden_dim].float()
546
+ except:
547
+ index_json_path = os.path.join(path, "pytorch_model.bin.index.json")
548
+ if not os.path.exists(index_json_path):
549
+ index_json_path = hf_hub_download(path, "pytorch_model.bin.index.json")
550
+ with open(index_json_path, "r") as f:
551
+ index_json = json.loads(f.read())
552
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
553
+ local_emb_path = os.path.join(path, emb_path)
554
+ if not os.path.exists(local_emb_path):
555
+ local_emb_path = hf_hub_download(path, emb_path)
556
+ weights = torch.load(local_emb_path)
557
+ tensor = weights["model.embed_tokens.weight"].float()
558
+ self.embed_tokens.weight.data = tensor
559
+
560
+ self.top_k = top_k
561
+ self.total_tokens = total_tokens - 1
562
+ self.depth = depth
563
+ self.threshold = math.log(threshold)
564
+ # print("total_tokens",total_tokens)
565
+ # print("depth",depth)
566
+ # print("top_k",top_k)
567
+ # print("threshold",threshold)
568
+ self.hidden_size = config.hidden_size
569
+ self.midlayer = LlamaDecoderLayeremb(config)
570
+ if hasattr(config, "target_hidden_size"):
571
+ self.fc = nn.Linear(config.target_hidden_size * 3, self.hidden_size, bias=False)
572
+ else:
573
+ self.fc = nn.Linear(config.hidden_size * 3, self.hidden_size, bias=False)
574
+ self.norm=LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
575
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
576
+
577
+ d2t=torch.zeros((config.draft_vocab_size),dtype=torch.long)
578
+ t2d=torch.zeros((config.vocab_size),dtype=torch.bool)
579
+ self.register_buffer("d2t", d2t)
580
+ self.register_buffer("t2d", t2d)
581
+
582
+ for param in self.embed_tokens.parameters():
583
+ param.requires_grad = False
584
+
585
+ def init_tree(self):
586
+ self.tree_mask_init = torch.eye(self.top_k, device=self.embed_tokens.weight.device)[None, None]
587
+ self.position_ids = torch.zeros(self.top_k, device=self.embed_tokens.weight.device, dtype=torch.long)
588
+ self.tree_mask_init = self.tree_mask_init.to(self.embed_tokens.weight.device)
589
+
590
+ def reset(self):
591
+ self.tree_mask = None
592
+
593
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
594
+ # create causal mask
595
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
596
+ combined_attention_mask = None
597
+ if input_shape[-1] > 1:
598
+ combined_attention_mask = _make_causal_mask(
599
+ input_shape,
600
+ # inputs_embeds.dtype,
601
+ torch.float32, # [MODIFIED] force to cast to float32
602
+ device=inputs_embeds.device,
603
+ past_key_values_length=past_key_values_length,
604
+ )
605
+
606
+ if attention_mask is not None:
607
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
608
+ expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to(
609
+ inputs_embeds.device
610
+ )
611
+ combined_attention_mask = (
612
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
613
+ )
614
+
615
+ # [MODIFIED] add tree mask
616
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
617
+ tree_mask = self.tree_mask
618
+ _, _, tree_shape0, tree_shape1 = tree_mask.shape
619
+ combined_attention_mask[:, :, -tree_shape0:, -tree_shape1:][
620
+ tree_mask == 0
621
+ ] = torch.finfo(torch.float32).min
622
+
623
+ return combined_attention_mask
624
+
625
+ def forward(
626
+ self,
627
+ hidden_states,
628
+ input_ids,
629
+ attention_mask: Optional[torch.Tensor] = None,
630
+ position_ids: Optional[torch.LongTensor] = None,
631
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
632
+ inputs_embeds: Optional[torch.FloatTensor] = None,
633
+ use_cache: Optional[bool] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ return_dict: Optional[bool] = None,
637
+ std=None
638
+ ):
639
+ batch_size, seq_length, _ = hidden_states.shape
640
+ seq_length_with_past = seq_length
641
+ past_key_values_length = 0
642
+
643
+ with torch.no_grad():
644
+ inputs_embeds = self.embed_tokens(input_ids)
645
+ # inputs_embeds = inputs_embeds.detach()
646
+
647
+ # if std is not None:
648
+ # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std
649
+ # inputs_embeds=inputs_embeds+noise
650
+
651
+ if past_key_values is not None:
652
+ past_key_values_length = past_key_values[0][0].shape[2]
653
+ seq_length_with_past = seq_length_with_past + past_key_values_length
654
+ if position_ids is None:
655
+ device = hidden_states.device if hidden_states is not None else inputs_embeds.device
656
+ position_ids = torch.arange(
657
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
658
+ )
659
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
660
+ else:
661
+ position_ids = position_ids.view(-1, seq_length).long()
662
+
663
+ #position_ids=position_ids//4
664
+ if attention_mask is None:
665
+ attention_mask = torch.ones(
666
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
667
+ )
668
+ attention_mask = self._prepare_decoder_attention_mask(
669
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
670
+ )
671
+
672
+ # if self.gradient_checkpointing and self.training:
673
+ # if use_cache:
674
+ # use_cache = False
675
+
676
+ # hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1)))
677
+ inputs_embeds = inputs_embeds.to(hidden_states.dtype)
678
+ if hidden_states.shape[-1]!=inputs_embeds.shape[-1]:
679
+ hidden_states = self.fc(hidden_states)
680
+ # hidden_states = self.fc(hidden_states)
681
+
682
+ all_hidden_states = () if output_hidden_states else None
683
+ next_decoder_cache = () if use_cache else None
684
+
685
+ past_key_value = past_key_values[0] if past_key_values is not None else None
686
+ layer_outputs = self.midlayer(
687
+ input_emb=inputs_embeds,
688
+ hidden_states=hidden_states,
689
+ attention_mask=attention_mask,
690
+ position_ids=position_ids,
691
+ past_key_value=past_key_value,
692
+ output_attentions=output_attentions,
693
+ use_cache=True,
694
+ )
695
+ if use_cache:
696
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
697
+ hidden_states = layer_outputs[0]
698
+
699
+
700
+ if use_cache:
701
+ return hidden_states, next_decoder_cache
702
+
703
+ return hidden_states
704
+
705
+ def reset_kv(self):
706
+ self.stable_kv = None
707
+
708
+ @torch.no_grad()
709
+ def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
710
+
711
+ input_ids = input_ids.to(hidden_states.device)
712
+ total_tokens = self.total_tokens
713
+ depth = self.depth
714
+ top_k = self.top_k
715
+
716
+ sample_token = input_ids[:, -1]
717
+
718
+ scores_list = []
719
+ parents_list = []
720
+ ss_token = []
721
+
722
+ input_ids = input_ids[:, 1:]
723
+ input_ids = input_ids.to(hidden_states.device)
724
+
725
+ len_posi = input_ids.shape[1]
726
+ self.reset()
727
+
728
+ # with Timer("draft many"):
729
+ if hasattr(self, "stable_kv") and self.stable_kv is not None:
730
+ kv_len = self.stable_kv[0][0].shape[2]
731
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids[:, kv_len:],
732
+ past_key_values=self.stable_kv, use_cache=True)
733
+ else:
734
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True)
735
+ self.stable_kv = past_key_values
736
+ last_hidden = out_hidden[:, -1]
737
+
738
+ # last_headout = head(last_hidden)
739
+ last_headout = self.lm_head(self.norm(last_hidden))
740
+
741
+ last_p = self.logsoftmax(last_headout)
742
+ top = torch.topk(last_p, top_k, dim=-1)
743
+ topk_index, topk_p = top.indices, top.values
744
+ scores = topk_p[0]
745
+ scores_list.append(scores[None])
746
+ parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device))
747
+ if self.config.vocab_size==self.config.draft_vocab_size:
748
+ ss_token.append(topk_index)
749
+ input_ids = topk_index
750
+ else:
751
+ ss_token.append(topk_index+self.d2t[topk_index])
752
+ input_ids = topk_index+self.d2t[topk_index]
753
+ input_hidden = last_hidden[None].repeat(1, top_k, 1)
754
+ tree_mask = self.tree_mask_init
755
+ topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device)
756
+
757
+ # 4
758
+ for i in range(depth):
759
+ self.tree_mask = tree_mask
760
+ position_ids = len_posi + self.position_ids
761
+ # with Timer("draft one"):
762
+ out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values,
763
+ position_ids=position_ids, use_cache=True)
764
+ len_posi += 1
765
+
766
+ # with Timer("sort1"):
767
+ bias1 = top_k if i > 0 else 0
768
+ bias2 = max(0, i - 1)
769
+ bias = 1 + top_k ** 2 * bias2 + bias1
770
+ parents = (topk_cs_index + bias)
771
+ parents_list.append(parents)
772
+
773
+ last_headout = self.lm_head(self.norm(out_hidden[0]))
774
+ last_p = self.logsoftmax(last_headout)
775
+
776
+ top = torch.topk(last_p, top_k, dim=-1)
777
+ topk_index, topk_p = top.indices, top.values
778
+
779
+ cu_scores = topk_p + scores[:, None]
780
+
781
+ topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1)
782
+ topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
783
+ scores = topk_cs_p
784
+
785
+ out_ids = topk_cs_index // top_k
786
+ input_hidden = out_hidden[:, out_ids]
787
+
788
+ input_ids = topk_index.view(-1)[topk_cs_index][None]
789
+
790
+ if self.config.vocab_size == self.config.draft_vocab_size:
791
+ ss_token.append(topk_index)
792
+ else:
793
+ input_ids = input_ids + self.d2t[input_ids]
794
+ ss_token.append(topk_index+self.d2t[topk_index])
795
+ scores_list.append(cu_scores)
796
+
797
+ # <mod> JQZ 250912
798
+ # tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3)
799
+ # <before-after> for dynamic moving between cpu and gpu
800
+ out_ids_for_mask = out_ids.to(tree_mask.device)
801
+ tree_mask = torch.cat((tree_mask[:, :, out_ids_for_mask], self.tree_mask_init), dim=3)
802
+ # </mod>
803
+
804
+
805
+ scores_list = torch.cat(scores_list, dim=0).view(-1)
806
+ ss_token_list = torch.cat(ss_token, dim=0).view(-1)
807
+ top_scores = torch.topk(scores_list, total_tokens, dim=-1)
808
+ top_scores_index = top_scores.indices
809
+ top_scores_index = torch.sort(top_scores_index).values
810
+
811
+ draft_tokens = ss_token_list[top_scores_index]
812
+ draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
813
+
814
+ draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long()
815
+ mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False)
816
+ # mask_index[(top_scores_index[mask_index]!=draft_parents - 1)]=-1
817
+ mask_index[draft_parents == 0] = -1
818
+ mask_index = mask_index + 1
819
+ mask_index_list = mask_index.tolist()
820
+ # with Timer("mask"):
821
+ tree_mask = torch.eye(total_tokens + 1).bool()
822
+ tree_mask[:, 0] = True
823
+ for i in range(total_tokens):
824
+ tree_mask[i + 1].add_(tree_mask[mask_index_list[i]])
825
+
826
+
827
+ tree_position_ids = torch.sum(tree_mask, dim=1) - 1
828
+
829
+ tree_mask = tree_mask.float()[None, None]
830
+ draft_tokens = draft_tokens[None]
831
+
832
+ del parents_list, scores_list, ss_token, ss_token_list, draft_parents
833
+
834
+ # with Timer("retrieve"):
835
+
836
+ max_depth = torch.max(tree_position_ids) + 1
837
+ noleaf_index = torch.unique(mask_index).tolist()
838
+ noleaf_num = len(noleaf_index) - 1
839
+ leaf_num = total_tokens - noleaf_num
840
+
841
+ retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1
842
+ retrieve_indices = retrieve_indices.tolist()
843
+
844
+ rid = 0
845
+ position_ids_list = tree_position_ids.tolist()
846
+
847
+ for i in range(total_tokens + 1):
848
+ if i not in noleaf_index:
849
+ cid = i
850
+ depth = position_ids_list[i]
851
+ for j in reversed(range(depth + 1)):
852
+ retrieve_indices[rid][j] = cid
853
+ cid = mask_index_list[cid - 1]
854
+ rid += 1
855
+
856
+ if logits_processor is not None:
857
+ maxitem = total_tokens + 5
858
+
859
+ def custom_sort(lst):
860
+ # sort_keys=[len(list)]
861
+ sort_keys = []
862
+ for i in range(len(lst)):
863
+ sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
864
+ return sort_keys
865
+
866
+ retrieve_indices = sorted(retrieve_indices, key=custom_sort)
867
+
868
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
869
+ del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid
870
+ tree_position_ids = tree_position_ids.to(hidden_states.device)
871
+
872
+ return draft_tokens, retrieve_indices, tree_mask, tree_position_ids
873
+
874
+
875
+
876
+
877
+ import torch
878
+
879
+
880
+ def count_parameters(model):
881
+ return sum(p.numel() for p in model.parameters())
882
+
883
+
884
+ if __name__ == "__main__":
885
+ config = EConfig.from_pretrained('config.json')
886
+ model = Model(config, load_emb=False)
887
+ print(model)
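To make the frontier-expansion step inside Model.topK_genrate above easier to follow, here is a small self-contained sketch (toy tensors of my own choosing, not taken from the commit) of how cumulative log-probabilities are combined and the next top_k draft nodes are selected:

# Toy illustration of one expansion step in topK_genrate: every node in the current
# frontier proposes its top_k children, children are scored by parent log-prob plus
# child log-prob, and the best top_k across all parents become the new frontier.
import torch

top_k = 4
scores = torch.log(torch.tensor([0.5, 0.2, 0.2, 0.1]))           # cumulative log-probs of the current frontier
child_logp = torch.log_softmax(torch.randn(top_k, 100), dim=-1)  # toy next-token log-probs (vocab of 100)

topk_p, topk_index = torch.topk(child_logp, top_k, dim=-1)       # each parent keeps its top_k children
cu_scores = topk_p + scores[:, None]                             # add parent scores, as in cnets.py

topk_cs = torch.topk(cu_scores.view(-1), top_k)                  # best top_k children over all parents
out_ids = topk_cs.indices // top_k                               # which parent each survivor came from
next_tokens = topk_index.view(-1)[topk_cs.indices]               # their draft-vocab token ids
print(out_ids.tolist(), next_tokens.tolist(), topk_cs.values.tolist())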
eagle/model/cnets1.py ADDED
@@ -0,0 +1,835 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import copy
22
+ import os
23
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+
30
+ from transformers.activations import ACT2FN
31
+ from huggingface_hub import hf_hub_download
32
+
33
+
34
+ try:
35
+ from .configs import EConfig
36
+ from .utils_c import *
37
+ from .choices import *
38
+ except:
39
+ from configs import EConfig
40
+ from utils_c import *
41
+ from choices import *
42
+ from utils import prepare_logits_processor
43
+
44
+
45
+
46
+
47
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
48
+ def _make_causal_mask(
49
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
50
+ ):
51
+ """
52
+ Make causal mask used for bi-directional self-attention.
53
+ """
54
+ bsz, tgt_len = input_ids_shape
55
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
56
+ mask_cond = torch.arange(mask.size(-1), device=device)
57
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
58
+ mask = mask.to(dtype)
59
+
60
+ if past_key_values_length > 0:
61
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
62
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
63
+
64
+
65
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
66
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
67
+ """
68
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
69
+ """
70
+ bsz, src_len = mask.size()
71
+ tgt_len = tgt_len if tgt_len is not None else src_len
72
+
73
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
74
+
75
+ inverted_mask = 1.0 - expanded_mask
76
+
77
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
78
+
79
+
80
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81
+ """
82
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84
+ """
85
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86
+ if n_rep == 1:
87
+ return hidden_states
88
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
89
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
90
+
91
+
92
+ def rotate_half(x):
93
+ """Rotates half the hidden dims of the input."""
94
+ x1 = x[..., : x.shape[-1] // 2]
95
+ x2 = x[..., x.shape[-1] // 2:]
96
+ return torch.cat((-x2, x1), dim=-1)
97
+
98
+
99
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
100
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
101
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
102
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
103
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
104
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
105
+ q_embed = (q * cos) + (rotate_half(q) * sin)
106
+ k_embed = (k * cos) + (rotate_half(k) * sin)
107
+ return q_embed, k_embed
108
+
109
+
110
+ class LlamaRotaryEmbedding(torch.nn.Module):
111
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
112
+ super().__init__()
113
+
114
+ self.dim = dim
115
+ self.max_position_embeddings = max_position_embeddings
116
+ self.base = base
117
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
118
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
119
+
120
+ # Build here to make `torch.jit.trace` work.
121
+ self._set_cos_sin_cache(
122
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
123
+ )
124
+
125
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
126
+ self.max_seq_len_cached = seq_len
127
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
128
+
129
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
130
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
131
+ emb = torch.cat((freqs, freqs), dim=-1)
132
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
133
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
134
+
135
+ def forward(self, x, seq_len=None):
136
+ # x: [bs, num_attention_heads, seq_len, head_size]
137
+ if seq_len > self.max_seq_len_cached:
138
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
+
140
+ return (
141
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
142
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
143
+ )
144
+
145
+
146
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
147
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
148
+
149
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
150
+ self.scaling_factor = scaling_factor
151
+ super().__init__(dim, max_position_embeddings, base, device)
152
+
153
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
154
+ self.max_seq_len_cached = seq_len
155
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
156
+ t = t / self.scaling_factor
157
+
158
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
159
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
163
+
164
+
165
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
166
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
167
+
168
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
169
+ self.scaling_factor = scaling_factor
170
+ super().__init__(dim, max_position_embeddings, base, device)
171
+
172
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
173
+ self.max_seq_len_cached = seq_len
174
+
175
+ if seq_len > self.max_position_embeddings:
176
+ base = self.base * (
177
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
178
+ ) ** (self.dim / (self.dim - 2))
179
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
180
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
181
+
182
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
183
+
184
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
185
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
186
+ emb = torch.cat((freqs, freqs), dim=-1)
187
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
188
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
189
+
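+ # [added note] The dynamic-NTK variant above rescales the RoPE base only when
+ # seq_len exceeds max_position_embeddings, so short prompts keep the original
+ # frequencies while longer contexts get progressively slower rotations.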
190
+
191
+ class LlamaAttention(nn.Module):
192
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
193
+
194
+ def __init__(self, config):
195
+ super().__init__()
196
+ self.config = config
197
+ self.hidden_size = config.hidden_size
198
+ self.num_heads = config.num_attention_heads
199
+ self.head_dim = self.hidden_size // self.num_heads
200
+ self.num_key_value_heads = config.num_key_value_heads
201
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
202
+ self.max_position_embeddings = config.max_position_embeddings
203
+
204
+ if (self.head_dim * self.num_heads) != self.hidden_size:
205
+ raise ValueError(
206
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
207
+ f" and `num_heads`: {self.num_heads})."
208
+ )
209
+ if hasattr(config, "qkv_bias"):
210
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
211
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
212
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
213
+ else:
214
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
215
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
216
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
217
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
218
+ self._init_rope()
219
+
220
+ def _init_rope(self):
221
+ if self.config.rope_scaling is None:
222
+ if hasattr(self.config, "rope_theta"):
223
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
224
+ max_position_embeddings=self.max_position_embeddings,
225
+ base=self.config.rope_theta)
226
+ else:
227
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
228
+ max_position_embeddings=self.max_position_embeddings)
229
+ else:
230
+ scaling_type = self.config.rope_scaling["type"]
231
+ scaling_factor = self.config.rope_scaling["factor"]
232
+ if scaling_type == "linear":
233
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
234
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
235
+ )
236
+ elif scaling_type == "dynamic":
237
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
238
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
239
+ )
240
+ else:
241
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
242
+
243
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
244
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ position_ids: Optional[torch.LongTensor] = None,
251
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
252
+ output_attentions: bool = False,
253
+ use_cache: bool = False,
254
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
255
+ bsz, q_len, _ = hidden_states.size()
256
+
257
+ if self.config.pretraining_tp > 1:
258
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
259
+ query_slices = self.q_proj.weight.split(
260
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
261
+ )
262
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
263
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
264
+
265
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
266
+ query_states = torch.cat(query_states, dim=-1)
267
+
268
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
269
+ key_states = torch.cat(key_states, dim=-1)
270
+
271
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
272
+ value_states = torch.cat(value_states, dim=-1)
273
+
274
+ else:
275
+ query_states = self.q_proj(hidden_states)
276
+ key_states = self.k_proj(hidden_states)
277
+ value_states = self.v_proj(hidden_states)
278
+
279
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
280
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
281
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
+
283
+ kv_seq_len = key_states.shape[-2]
284
+ if past_key_value is not None:
285
+ kv_seq_len += past_key_value[0].shape[-2]
286
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
287
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
288
+
289
+ if past_key_value is not None:
290
+ # reuse k, v, self_attention
291
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
292
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
293
+
294
+ past_key_value = (key_states, value_states) if use_cache else None
295
+
296
+ # repeat k/v heads if n_kv_heads < n_heads
297
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
298
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
299
+
300
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
301
+
302
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
303
+ raise ValueError(
304
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
305
+ f" {attn_weights.size()}"
306
+ )
307
+
308
+ if attention_mask is not None:
309
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
310
+ raise ValueError(
311
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
312
+ )
313
+ attn_weights = attn_weights + attention_mask
314
+
315
+ # upcast attention to fp32
316
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
317
+ attn_output = torch.matmul(attn_weights, value_states)
318
+
319
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
320
+ raise ValueError(
321
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
322
+ f" {attn_output.size()}"
323
+ )
324
+
325
+ attn_output = attn_output.transpose(1, 2).contiguous()
326
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
327
+
328
+ if self.config.pretraining_tp > 1:
329
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
330
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
331
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
332
+ else:
333
+ attn_output = self.o_proj(attn_output)
334
+
335
+ if not output_attentions:
336
+ attn_weights = None
337
+
338
+ return attn_output, attn_weights, past_key_value
339
+
340
+
341
+ class LlamaMLP(nn.Module):
342
+ def __init__(self, config):
343
+ super().__init__()
344
+ self.config = config
345
+ self.hidden_size = config.hidden_size
346
+ self.intermediate_size = config.intermediate_size
347
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
348
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
349
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
350
+ self.act_fn = ACT2FN[config.hidden_act]
351
+
352
+ def forward(self, x):
353
+ if self.config.pretraining_tp > 1:
354
+ slice = self.intermediate_size // self.config.pretraining_tp
355
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
356
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
357
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
358
+
359
+ gate_proj = torch.cat(
360
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
361
+ )
362
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
363
+
364
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
365
+ down_proj = [
366
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
367
+ ]
368
+ down_proj = sum(down_proj)
369
+ else:
370
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
371
+
372
+ return down_proj
373
+
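+ # [added note] LlamaMLP is the standard SwiGLU block, down_proj(act(gate_proj(x)) * up_proj(x)),
+ # with an optional sliced path when pretraining_tp > 1 that reproduces tensor-parallel numerics.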
374
+
375
+ class LlamaRMSNorm(nn.Module):
376
+ def __init__(self, hidden_size, eps=1e-6):
377
+ """
378
+ LlamaRMSNorm is equivalent to T5LayerNorm
379
+ """
380
+ super().__init__()
381
+ self.weight = nn.Parameter(torch.ones(hidden_size))
382
+ self.variance_epsilon = eps
383
+
384
+ def forward(self, hidden_states):
385
+ input_dtype = hidden_states.dtype
386
+ hidden_states = hidden_states.to(torch.float32)
387
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
388
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
389
+ return self.weight * hidden_states.to(input_dtype)
390
+
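+ # [added note] RMSNorm normalizes by the root-mean-square of the hidden vector in
+ # float32 and rescales with a learned weight; unlike LayerNorm it has no mean
+ # subtraction and no bias.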
391
+
392
+ class LlamaDecoderLayer(nn.Module):
393
+ def __init__(self, config, index):
394
+ super().__init__()
395
+ self.hidden_size = config.hidden_size
396
+ self.self_attn = LlamaAttention(config=config)
397
+ self.mlp = LlamaMLP(config)
398
+ self.index = index
399
+ if self.index != 0:
400
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
401
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states: torch.Tensor,
406
+ attention_mask: Optional[torch.Tensor] = None,
407
+ position_ids: Optional[torch.LongTensor] = None,
408
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
409
+ output_attentions: Optional[bool] = False,
410
+ use_cache: Optional[bool] = False,
411
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
412
+ """
413
+ Args:
414
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
415
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
416
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
417
+ output_attentions (`bool`, *optional*):
418
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
419
+ returned tensors for more detail.
420
+ use_cache (`bool`, *optional*):
421
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
422
+ (see `past_key_values`).
423
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
424
+ """
425
+
426
+ residual = hidden_states
427
+
428
+ if self.index != 0:
429
+ hidden_states = self.input_layernorm(hidden_states)
430
+
431
+ # Self Attention
432
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
433
+ hidden_states=hidden_states,
434
+ attention_mask=attention_mask,
435
+ position_ids=position_ids,
436
+ past_key_value=past_key_value,
437
+ output_attentions=output_attentions,
438
+ use_cache=use_cache,
439
+ )
440
+ hidden_states = residual + hidden_states
441
+
442
+ # Fully Connected
443
+ residual = hidden_states
444
+ hidden_states = self.post_attention_layernorm(hidden_states)
445
+ hidden_states = self.mlp(hidden_states)
446
+ hidden_states = residual + hidden_states
447
+
448
+ outputs = (hidden_states,)
449
+
450
+ if output_attentions:
451
+ outputs += (self_attn_weights,)
452
+
453
+ if use_cache:
454
+ outputs += (present_key_value,)
455
+
456
+ return outputs
457
+
458
+
459
+ class I(nn.Module):
460
+ def __init__(self):
461
+ super().__init__()
462
+ self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))
463
+
464
+ def forward(self, x):
465
+ return x + self.dummy - self.dummy  # identity pass-through built around the dummy parameter
466
+
467
+
468
+ def len_list(x, n):
469
+ return [i for i in x if len(i) <= n]
470
+
471
+
472
+ class Model(nn.Module):
473
+ def __init__(self, config, load_emb=False, path=None, bias=True, total_tokens=63, depth=5, top_k=8, threshold=1.0):
474
+ super().__init__()
475
+
476
+ self.gradient_checkpointing = True
477
+ self.padding_idx = config.pad_token_id
478
+ self.vocab_size = config.vocab_size
479
+
480
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
481
+ if load_emb:
482
+ from safetensors import safe_open
483
+ import json
484
+ try:
485
+ index_json_path = os.path.join(path, "model.safetensors.index.json")
486
+ if not os.path.exists(index_json_path):
487
+ index_json_path = hf_hub_download(path, "model.safetensors.index.json")
488
+ with open(index_json_path, "r") as f:
489
+ index_json = json.loads(f.read())
490
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
491
+ local_emb_path = os.path.join(path, emb_path)
492
+ if not os.path.exists(local_emb_path):
493
+ local_emb_path = hf_hub_download(path, emb_path)
494
+ with safe_open(local_emb_path,
495
+ framework="pt",
496
+ device="cpu") as f:
497
+ tensor_slice = f.get_slice("model.embed_tokens.weight")
498
+ vocab_size, hidden_dim = tensor_slice.get_shape()
499
+ tensor = tensor_slice[:, :hidden_dim].float()
500
+ except Exception:  # fall back to the pytorch_model.bin index when no safetensors index is found
501
+ index_json_path = os.path.join(path, "pytorch_model.bin.index.json")
502
+ if not os.path.exists(index_json_path):
503
+ index_json_path = hf_hub_download(path, "pytorch_model.bin.index.json")
504
+ with open(index_json_path, "r") as f:
505
+ index_json = json.loads(f.read())
506
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
507
+ local_emb_path = os.path.join(path, emb_path)
508
+ if not os.path.exists(local_emb_path):
509
+ local_emb_path = hf_hub_download(path, emb_path)
510
+ weights = torch.load(local_emb_path)
511
+ tensor = weights["model.embed_tokens.weight"].float()
512
+ self.embed_tokens.weight.data = tensor
513
+
514
+ self.top_k = top_k
515
+ self.total_tokens = total_tokens - 1
516
+ self.depth = depth
517
+ self.threshold = math.log(threshold)
518
+ # print("total_tokens",total_tokens)
519
+ # print("depth",depth)
520
+ # print("top_k",top_k)
521
+ # print("threshold",threshold)
522
+
523
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, index) for index in range(config.num_hidden_layers)])
524
+ self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias)
525
+ self.act = ACT2FN[config.hidden_act]
526
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
527
+ for param in self.embed_tokens.parameters():
528
+ param.requires_grad = False
529
+
530
+ def init_tree(self):
531
+ self.tree_mask_init = torch.eye(self.top_k, device=self.embed_tokens.weight.device)[None, None]
532
+ self.position_ids = torch.zeros(self.top_k, device=self.embed_tokens.weight.device, dtype=torch.long)
533
+ self.tree_mask_init = self.tree_mask_init.to(self.embed_tokens.weight.device)
534
+
535
+ def reset(self):
536
+ self.tree_mask = None
537
+
538
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
539
+ # create causal mask
540
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
541
+ combined_attention_mask = None
542
+ if input_shape[-1] > 1:
543
+ combined_attention_mask = _make_causal_mask(
544
+ input_shape,
545
+ # inputs_embeds.dtype,
546
+ torch.float32, # [MODIFIED] force to cast to float32
547
+ device=inputs_embeds.device,
548
+ past_key_values_length=past_key_values_length,
549
+ )
550
+
551
+ if attention_mask is not None:
552
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
553
+ expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to(
554
+ inputs_embeds.device
555
+ )
556
+ combined_attention_mask = (
557
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
558
+ )
559
+
560
+ # [MODIFIED] add tree mask
561
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
562
+ tree_mask = self.tree_mask
563
+ _, _, tree_shape0, tree_shape1 = tree_mask.shape
564
+ combined_attention_mask[:, :, -tree_shape0:, -tree_shape1:][
565
+ tree_mask == 0
566
+ ] = torch.finfo(torch.float32).min
567
+
568
+ return combined_attention_mask
569
+
570
+ def forward(
571
+ self,
572
+ hidden_states,
573
+ input_ids,
574
+ attention_mask: Optional[torch.Tensor] = None,
575
+ position_ids: Optional[torch.LongTensor] = None,
576
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
577
+ inputs_embeds: Optional[torch.FloatTensor] = None,
578
+ use_cache: Optional[bool] = None,
579
+ output_attentions: Optional[bool] = None,
580
+ output_hidden_states: Optional[bool] = None,
581
+ return_dict: Optional[bool] = None,
582
+ std=None
583
+ ):
584
+ batch_size, seq_length, _ = hidden_states.shape
585
+ seq_length_with_past = seq_length
586
+ past_key_values_length = 0
587
+
588
+ with torch.no_grad():
589
+ inputs_embeds = self.embed_tokens(input_ids)
590
+ # inputs_embeds = inputs_embeds.detach()
591
+
592
+ # if std is not None:
593
+ # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std
594
+ # inputs_embeds=inputs_embeds+noise
595
+
596
+ if past_key_values is not None:
597
+ past_key_values_length = past_key_values[0][0].shape[2]
598
+ seq_length_with_past = seq_length_with_past + past_key_values_length
599
+ if position_ids is None:
600
+ device = hidden_states.device if hidden_states is not None else inputs_embeds.device
601
+ position_ids = torch.arange(
602
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
603
+ )
604
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
605
+ else:
606
+ position_ids = position_ids.view(-1, seq_length).long()
607
+
608
+ #position_ids=position_ids//4
609
+ if attention_mask is None:
610
+ attention_mask = torch.ones(
611
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
612
+ )
613
+ attention_mask = self._prepare_decoder_attention_mask(
614
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
615
+ )
616
+
617
+ # if self.gradient_checkpointing and self.training:
618
+ # if use_cache:
619
+ # use_cache = False
620
+
621
+ # hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1)))
622
+ inputs_embeds = inputs_embeds.to(hidden_states.dtype)
623
+ hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
624
+
625
+ all_hidden_states = () if output_hidden_states else None
626
+ next_decoder_cache = () if use_cache else None
627
+
628
+ for idx, decoder_layer in enumerate(self.layers):
629
+ if output_hidden_states:
630
+ all_hidden_states += (hidden_states,)
631
+
632
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
633
+
634
+ if self.gradient_checkpointing and self.training:
635
+
636
+ def create_custom_forward(module):
637
+ def custom_forward(*inputs):
638
+ # None for past_key_value
639
+ return module(*inputs, past_key_value, output_attentions)
640
+
641
+ return custom_forward
642
+
643
+ layer_outputs = torch.utils.checkpoint.checkpoint(
644
+ create_custom_forward(decoder_layer),
645
+ hidden_states,
646
+ attention_mask,
647
+ position_ids,
648
+ )
649
+ else:
650
+ layer_outputs = decoder_layer(
651
+ hidden_states,
652
+ attention_mask=attention_mask,
653
+ position_ids=position_ids,
654
+ past_key_value=past_key_value,
655
+ output_attentions=output_attentions,
656
+ use_cache=use_cache,
657
+ )
658
+
659
+ hidden_states = layer_outputs[0]
660
+
661
+ if use_cache:
662
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
663
+
664
+ if use_cache:
665
+ return hidden_states, next_decoder_cache
666
+
667
+ return hidden_states
668
+
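+ # [added note] forward() fuses the target model's hidden states with the token
+ # embeddings via self.fc (2*hidden_size -> hidden_size) before running the draft
+ # decoder layers; this fused feature is what the draft layers consume.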
669
+ def reset_kv(self):
670
+ self.stable_kv = None
671
+
672
+ @torch.no_grad()
673
+ def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
674
+
675
+ input_ids = input_ids.to(hidden_states.device)
676
+ total_tokens = self.total_tokens
677
+ depth = self.depth
678
+ top_k = self.top_k
679
+
680
+ sample_token = input_ids[:, -1]
681
+
682
+ scores_list = []
683
+ parents_list = []
684
+ ss_token = []
685
+
686
+ input_ids = input_ids[:, 1:]
687
+ input_ids = input_ids.to(hidden_states.device)
688
+
689
+ len_posi = input_ids.shape[1]
690
+ self.reset()
691
+
692
+ # with Timer("draft many"):
693
+ if hasattr(self, "stable_kv") and self.stable_kv is not None:
694
+ kv_len = self.stable_kv[0][0].shape[2]
695
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids[:, kv_len:],
696
+ past_key_values=self.stable_kv, use_cache=True)
697
+ else:
698
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True)
699
+ self.stable_kv = past_key_values
700
+ last_hidden = out_hidden[:, -1]
701
+
702
+ last_headout = head(last_hidden)
703
+
704
+ last_p = self.logsoftmax(last_headout)
705
+ top = torch.topk(last_p, top_k, dim=-1)
706
+ topk_index, topk_p = top.indices, top.values
707
+ scores = topk_p[0]
708
+ scores_list.append(scores[None])
709
+ parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device))
710
+ ss_token.append(topk_index)
711
+ input_ids = topk_index
712
+ input_hidden = last_hidden[None].repeat(1, top_k, 1)
713
+ tree_mask = self.tree_mask_init
714
+ topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device)
715
+
716
+ # 4
717
+ for i in range(depth):
718
+ self.tree_mask = tree_mask
719
+ position_ids = len_posi + self.position_ids
720
+ # with Timer("draft one"):
721
+ out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values,
722
+ position_ids=position_ids, use_cache=True)
723
+ len_posi += 1
724
+
725
+ # with Timer("sort1"):
726
+ bias1 = top_k if i > 0 else 0
727
+ bias2 = max(0, i - 1)
728
+ bias = 1 + top_k ** 2 * bias2 + bias1
729
+ parents = (topk_cs_index + bias)
730
+ parents_list.append(parents)
731
+
732
+ last_headout = head(out_hidden[0])
733
+ last_p = self.logsoftmax(last_headout)
734
+
735
+ top = torch.topk(last_p, top_k, dim=-1)
736
+ topk_index, topk_p = top.indices, top.values
737
+
738
+ cu_scores = topk_p + scores[:, None]
739
+
740
+ topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1)
741
+ topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
742
+ scores = topk_cs_p
743
+
744
+ out_ids = topk_cs_index // top_k
745
+ input_hidden = out_hidden[:, out_ids]
746
+
747
+ input_ids = topk_index.view(-1)[topk_cs_index][None]
748
+
749
+ ss_token.append(topk_index)
750
+ scores_list.append(cu_scores)
751
+ tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3)
752
+
753
+
754
+
755
+ scores_list = torch.cat(scores_list, dim=0).view(-1)
756
+ ss_token_list = torch.cat(ss_token, dim=0).view(-1)
757
+ top_scores = torch.topk(scores_list, total_tokens, dim=-1)
758
+ top_scores_index = top_scores.indices
759
+ top_scores_index = torch.sort(top_scores_index).values
760
+
761
+ draft_tokens = ss_token_list[top_scores_index]
762
+ draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
763
+
764
+ draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long()
765
+ mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False)
766
+ # mask_index[(top_scores_index[mask_index]!=draft_parents - 1)]=-1
767
+ mask_index[draft_parents == 0] = -1
768
+ mask_index = mask_index + 1
769
+ mask_index_list = mask_index.tolist()
770
+ # with Timer("mask"):
771
+ tree_mask = torch.eye(total_tokens + 1).bool()
772
+ tree_mask[:, 0] = True
773
+ for i in range(total_tokens):
774
+ tree_mask[i + 1].add_(tree_mask[mask_index_list[i]])
775
+
776
+
777
+ tree_position_ids = torch.sum(tree_mask, dim=1) - 1
778
+
779
+ tree_mask = tree_mask.float()[None, None]
780
+ draft_tokens = draft_tokens[None]
781
+
782
+ del parents_list, scores_list, ss_token, ss_token_list, draft_parents
783
+
784
+ # with Timer("retrieve"):
785
+
786
+ max_depth = torch.max(tree_position_ids) + 1
787
+ noleaf_index = torch.unique(mask_index).tolist()
788
+ noleaf_num = len(noleaf_index) - 1
789
+ leaf_num = total_tokens - noleaf_num
790
+
791
+ retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1
792
+ retrieve_indices = retrieve_indices.tolist()
793
+
794
+ rid = 0
795
+ position_ids_list = tree_position_ids.tolist()
796
+
797
+ for i in range(total_tokens + 1):
798
+ if i not in noleaf_index:
799
+ cid = i
800
+ depth = position_ids_list[i]
801
+ for j in reversed(range(depth + 1)):
802
+ retrieve_indices[rid][j] = cid
803
+ cid = mask_index_list[cid - 1]
804
+ rid += 1
805
+
806
+ if logits_processor is not None:
807
+ maxitem = total_tokens + 5
808
+
809
+ def custom_sort(lst):
810
+ # sort_keys=[len(list)]
811
+ sort_keys = []
812
+ for i in range(len(lst)):
813
+ sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
814
+ return sort_keys
815
+
816
+ retrieve_indices = sorted(retrieve_indices, key=custom_sort)
817
+
818
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
819
+ del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid
820
+ tree_position_ids = tree_position_ids.to(hidden_states.device)
821
+
822
+ return draft_tokens, retrieve_indices, tree_mask, tree_position_ids
823
+
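+ # [added note] topK_genrate builds the draft tree used for speculative decoding:
+ # draft_tokens are the candidate tokens (sample_token first), retrieve_indices maps
+ # each leaf back to its root-to-leaf path, and tree_mask / tree_position_ids encode
+ # the tree-shaped attention pattern for the target model's verification pass.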
824
+
825
+
826
+
827
+
828
+ def count_parameters(model):
829
+ return sum(p.numel() for p in model.parameters())
830
+
831
+
832
+ if __name__ == "__main__":
833
+ config = EConfig.from_pretrained('config.json')
834
+ model = Model(config, load_emb=False)
835
+ print(model)
eagle/model/configs.py ADDED
@@ -0,0 +1,147 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ class EConfig(PretrainedConfig):
3
+ r"""
4
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
5
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
6
+ defaults will yield a similar configuration to that of the LLaMA-7B.
7
+
8
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
9
+ documentation from [`PretrainedConfig`] for more information.
10
+
11
+
12
+ Args:
13
+ vocab_size (`int`, *optional*, defaults to 32000):
14
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
15
+ `inputs_ids` passed when calling [`LlamaModel`]
16
+ hidden_size (`int`, *optional*, defaults to 4096):
17
+ Dimension of the hidden representations.
18
+ intermediate_size (`int`, *optional*, defaults to 11008):
19
+ Dimension of the MLP representations.
20
+ num_hidden_layers (`int`, *optional*, defaults to 32):
21
+ Number of hidden layers in the Transformer encoder.
22
+ num_attention_heads (`int`, *optional*, defaults to 32):
23
+ Number of attention heads for each attention layer in the Transformer encoder.
24
+ num_key_value_heads (`int`, *optional*):
25
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
26
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
27
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
28
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
29
+ by meanpooling all the original heads within that group. For more details checkout [this
30
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
31
+ `num_attention_heads`.
32
+ pretraining_tp (`int`, *optional*, defaults to `1`):
33
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
34
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
35
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
36
+ issue](https://github.com/pytorch/pytorch/issues/76232).
37
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
38
+ The non-linear activation function (function or string) in the decoder.
39
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
40
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
41
+ just in case (e.g., 512 or 1024 or 2048).
42
+ initializer_range (`float`, *optional*, defaults to 0.02):
43
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
45
+ The epsilon used by the rms normalization layers.
46
+ use_cache (`bool`, *optional*, defaults to `True`):
47
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
48
+ relevant if `config.is_decoder=True`.
49
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
50
+ Whether to tie weight embeddings
51
+ rope_scaling (`Dict`, *optional*):
52
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
53
+ strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
54
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
55
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
56
+ these scaling strategies behave:
57
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
58
+ experimental feature, subject to breaking API changes in future versions.
59
+
60
+ Example:
61
+
62
+ ```python
63
+ >>> from transformers import LlamaModel, LlamaConfig
64
+
65
+ >>> # Initializing a LLaMA llama-7b style configuration
66
+ >>> configuration = LlamaConfig()
67
+
68
+ >>> # Initializing a model from the llama-7b style configuration
69
+ >>> model = LlamaModel(configuration)
70
+
71
+ >>> # Accessing the model configuration
72
+ >>> configuration = model.config
73
+ ```"""
74
+ model_type = "llama"
75
+ keys_to_ignore_at_inference = ["past_key_values"]
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_size=32000,
80
+ hidden_size=4096,
81
+ intermediate_size=11008,
82
+ num_hidden_layers=32,
83
+ num_attention_heads=32,
84
+ num_key_value_heads=None,
85
+ hidden_act="silu",
86
+ max_position_embeddings=2048,
87
+ initializer_range=0.02,
88
+ rms_norm_eps=1e-6,
89
+ use_cache=True,
90
+ pad_token_id=None,
91
+ bos_token_id=1,
92
+ eos_token_id=2,
93
+ pretraining_tp=1,
94
+ tie_word_embeddings=False,
95
+ rope_scaling=None,
96
+ rope_theta=10000,
97
+ **kwargs,
98
+ ):
99
+ self.vocab_size = vocab_size
100
+ self.max_position_embeddings = max_position_embeddings
101
+ self.hidden_size = hidden_size
102
+ self.intermediate_size = intermediate_size
103
+ self.num_hidden_layers = num_hidden_layers
104
+ self.num_attention_heads = num_attention_heads
105
+
106
+ # for backward compatibility
107
+ if num_key_value_heads is None:
108
+ num_key_value_heads = num_attention_heads
109
+
110
+ self.num_key_value_heads = num_key_value_heads
111
+ self.hidden_act = hidden_act
112
+ self.initializer_range = initializer_range
113
+ self.rms_norm_eps = rms_norm_eps
114
+ self.pretraining_tp = pretraining_tp
115
+ self.use_cache = use_cache
116
+ self.rope_scaling = rope_scaling
117
+ self.rope_theta = rope_theta
118
+ # self._rope_scaling_validation()
119
+
120
+ super().__init__(
121
+ pad_token_id=pad_token_id,
122
+ bos_token_id=bos_token_id,
123
+ eos_token_id=eos_token_id,
124
+ tie_word_embeddings=tie_word_embeddings,
125
+ **kwargs,
126
+ )
127
+
128
+ # def _rope_scaling_validation(self):
129
+ # """
130
+ # Validate the `rope_scaling` configuration.
131
+ # """
132
+ # if self.rope_scaling is None:
133
+ # return
134
+
135
+ # if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
136
+ # raise ValueError(
137
+ # "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
138
+ # f"got {self.rope_scaling}"
139
+ # )
140
+ # rope_scaling_type = self.rope_scaling.get("type", None)
141
+ # rope_scaling_factor = self.rope_scaling.get("factor", None)
142
+ # if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
143
+ # raise ValueError(
144
+ # f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
145
+ # )
146
+ # if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
147
+ # raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
eagle/model/configuration_minicpm.py ADDED
@@ -0,0 +1,196 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The OpenBMB Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ MiniCPM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
23
+
24
+
25
+ class MiniCPMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
28
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the MiniCPM-7B.
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+ Args:
33
+ vocab_size (`int`, *optional*, defaults to 32000):
34
+ Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
35
+ `inputs_ids` passed when calling [`MiniCPMModel`]
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 11008):
39
+ Dimension of the MLP representations.
40
+ num_hidden_layers (`int`, *optional*, defaults to 32):
41
+ Number of hidden layers in the Transformer decoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 32):
43
+ Number of attention heads for each attention layer in the Transformer decoder.
44
+ num_key_value_heads (`int`, *optional*):
45
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
46
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
47
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
48
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
49
+ by meanpooling all the original heads within that group. For more details checkout [this
50
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
51
+ `num_attention_heads`.
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
53
+ The non-linear activation function (function or string) in the decoder.
54
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
55
+ The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
56
+ MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
57
+ initializer_range (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
59
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
60
+ The epsilon used by the rms normalization layers.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
63
+ relevant if `config.is_decoder=True`.
64
+ pad_token_id (`int`, *optional*):
65
+ Padding token id.
66
+ bos_token_id (`int`, *optional*, defaults to 1):
67
+ Beginning of stream token id.
68
+ eos_token_id (`int`, *optional*, defaults to 2):
69
+ End of stream token id.
70
+ pretraining_tp (`int`, *optional*, defaults to 1):
71
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
72
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
73
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
74
+ issue](https://github.com/pytorch/pytorch/issues/76232).
75
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
76
+ Whether to tie weight embeddings
77
+ rope_theta (`float`, *optional*, defaults to 10000.0):
78
+ The base period of the RoPE embeddings.
79
+ rope_scaling (`Dict`, *optional*):
80
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
81
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
82
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
83
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
84
+ these scaling strategies behave:
85
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
86
+ experimental feature, subject to breaking API changes in future versions.
87
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
88
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
89
+ attention_dropout (`float`, *optional*, defaults to 0.0):
90
+ The dropout ratio for the attention probabilities.
91
+ ```python
92
+ >>> from transformers import MiniCPMModel, MiniCPMConfig
93
+ >>> # Initializing a MiniCPM minicpm-7b style configuration
94
+ >>> configuration = MiniCPMConfig()
95
+ >>> # Initializing a model from the minicpm-7b style configuration
96
+ >>> model = MiniCPMModel(configuration)
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = 'minicpm'
102
+ keys_to_ignore_at_inference = ['past_key_values']
103
+
104
+ def __init__(
105
+ self,
106
+ vocab_size=32000,
107
+ hidden_size=4096,
108
+ intermediate_size=11008,
109
+ num_hidden_layers=32,
110
+ num_attention_heads=32,
111
+ num_key_value_heads=None,
112
+ hidden_act='silu',
113
+ max_position_embeddings=2048,
114
+ initializer_range=0.02,
115
+ rms_norm_eps=1e-6,
116
+ use_cache=True,
117
+ pad_token_id=None,
118
+ bos_token_id=1,
119
+ eos_token_id=2,
120
+ pretraining_tp=1,
121
+ tie_word_embeddings=True,
122
+ rope_theta=10000.0,
123
+ rope_scaling=None,
124
+ attention_bias=False,
125
+ attention_dropout=0.0,
126
+ scale_emb=1,
127
+ dim_model_base=1,
128
+ scale_depth=1,
129
+ mup_denominator=32,
130
+ sparse_config=None,
131
+ **kwargs):
132
+
133
+ self.vocab_size = vocab_size
134
+ self.max_position_embeddings = max_position_embeddings
135
+ self.hidden_size = hidden_size
136
+ self.intermediate_size = intermediate_size
137
+ self.num_hidden_layers = num_hidden_layers
138
+ self.num_attention_heads = num_attention_heads
139
+
140
+ # for backward compatibility
141
+ if num_key_value_heads is None:
142
+ num_key_value_heads = num_attention_heads
143
+
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.hidden_act = hidden_act
146
+ self.initializer_range = initializer_range
147
+ self.rms_norm_eps = rms_norm_eps
148
+ self.pretraining_tp = pretraining_tp
149
+ self.use_cache = use_cache
150
+ self.rope_theta = rope_theta
151
+ self.rope_scaling = rope_scaling
152
+ # self._rope_scaling_validation()
153
+ self.attention_bias = attention_bias
154
+ self.attention_dropout = attention_dropout
155
+ self.scale_emb = scale_emb
156
+ self.dim_model_base = dim_model_base
157
+ self.scale_depth = scale_depth
158
+ # only used for Eagle Head
159
+ self.mup_denominator = mup_denominator
160
+
161
+ # sparse config
162
+ self.sparse_config = sparse_config
163
+
164
+ super().__init__(
165
+ pad_token_id=pad_token_id,
166
+ bos_token_id=bos_token_id,
167
+ eos_token_id=eos_token_id,
168
+ tie_word_embeddings=tie_word_embeddings,
169
+ **kwargs,
170
+ )
171
+ try:
172
+ import flash_attn
173
+ self._attn_implementation = 'flash_attention_2'
174
+ except:
175
+ pass
176
+
177
+ def _rope_scaling_validation(self):
178
+ """
179
+ Validate the `rope_scaling` configuration.
180
+ """
181
+ if self.rope_scaling is None:
182
+ return
183
+
184
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
185
+ raise ValueError(
186
+ '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
187
+ f'got {self.rope_scaling}'
188
+ )
189
+ rope_scaling_type = self.rope_scaling.get('type', None)
190
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
191
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
192
+ raise ValueError(
193
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
194
+ )
195
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
196
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
eagle/model/ea_model.py ADDED
@@ -0,0 +1,582 @@
1
+ import copy
2
+ import json
3
+ import time
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import hf_hub_download
8
+ from transformers import AutoTokenizer
9
+ import os
10
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
11
+
12
+ from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
13
+ from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
14
+ #from .modeling_qwen2_kv import LlamaForCausalLM as KVQwen2ForCausalLM
15
+ from .modeling_qwen2_kv import Qwen2ForCausalLM as KVQwen2ForCausalLM
16
+ from .utils import *
17
+ from .kv_cache import initialize_past_key_values
18
+
19
+ from .cnets import Model
20
+ from .cnets1 import Model as Model1
21
+ from .configs import EConfig
22
+
23
+ """ Modified to support Eagle-3, marked by <mod> xxx </mod> """
24
+ # from .modeling_minicpm_kv import HackConvertMiniCPMForCausalLM as KVMiniCPMForCausalLM # <mod> convert opensource impl to llama </mod>
25
+ from .modeling_minicpm_kv import MiniCPMForCausalLM as KVMiniCPMForCausalLM # <mod> use modified opensource impl </mod>
26
+
27
+ class EaModel(nn.Module):
28
+
29
+ def __init__(
30
+ self,
31
+ use_eagle3,
32
+ base_model,
33
+ base_model_name_or_path,
34
+ ea_model_path,
35
+ total_token,
36
+ depth,
37
+ top_k,
38
+ threshold,
39
+ ea_layer_state_dict,
40
+ ):
41
+
42
+ super().__init__()
43
+ self.base_model = base_model
44
+ self.config = base_model.config
45
+ self.hidden_size = base_model.lm_head.weight.shape[-1]
46
+ self.vocab_size = base_model.lm_head.weight.shape[0]
47
+ self.base_model_name_or_path = base_model_name_or_path
48
+ self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=False)
49
+ self.use_eagle3 = use_eagle3
50
+ config = EConfig.from_pretrained(ea_model_path)
51
+ with open(ea_model_path, "r") as f:
52
+ con = json.loads(f.read())
53
+ try:
54
+ bias = con["bias"]
55
+ except KeyError:
56
+ bias = True
57
+ if use_eagle3:
58
+ self.ea_layer = Model(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
59
+ threshold=threshold, path=base_model_name_or_path,load_emb=True)
60
+ else:
61
+ self.ea_layer = Model1(config, bias=bias, total_tokens=total_token, depth=depth, top_k=top_k,
62
+ threshold=threshold, path=base_model_name_or_path,load_emb=True)
63
+
64
+ low_memory = False
65
+
66
+ device = base_model.model.layers[-1].self_attn.q_proj.weight.device
67
+ if device != base_model.lm_head.weight.device:
68
+ self.ea_layer.diff_device = True
69
+ if not low_memory:
70
+ self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
71
+ else:
72
+ self.ea_layer.layer_device = device
73
+
74
+ else:
75
+ self.ea_layer.diff_device = False
76
+ if self.use_eagle3 and config.vocab_size == config.draft_vocab_size:
77
+ del self.ea_layer.d2t, self.ea_layer.t2d
78
+ load_ = self.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
79
+ self.ea_layer.to(self.base_model.dtype).to(device)
80
+ self.ea_layer.init_tree()
81
+
82
+ def get_tokenizer(self):
83
+ """Get the tokenizer of the base model.
84
+
85
+ Returns:
86
+ Tokenizer: The tokenizer of the base model.
87
+ """
88
+ return self.tokenizer
89
+
90
+ @classmethod
91
+ def from_pretrained(
92
+ cls,
93
+ use_eagle3=True,
94
+ base_model_path=None,
95
+ ea_model_path=None,
96
+ total_token=60,
97
+ depth=7,
98
+ top_k=10,
99
+ threshold=1.0,
100
+ **kwargs,
101
+ ):
102
+ # assert Type=="LLaMA" or "Mixtral"
103
+ Type = AutoConfig.from_pretrained(base_model_path, trust_remote_code=True).architectures[0]
104
+
105
+ if Type == 'LlamaForCausalLM':
106
+ base_model = KVLlamaForCausalLM.from_pretrained(
107
+ base_model_path, **kwargs
108
+ )
109
+ elif Type == 'Qwen2ForCausalLM':
110
+ base_model = KVQwen2ForCausalLM.from_pretrained(
111
+ base_model_path, **kwargs
112
+ )
113
+ elif Type == 'MiniCPMForCausalLM': # <mod> support MiniCPMForCausalLM
114
+ base_model = KVMiniCPMForCausalLM.from_pretrained(
115
+ base_model_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True,
116
+ )
117
+ # </mod>
118
+ else:
119
+ base_model = KVMixtralForCausalLM.from_pretrained(
120
+ base_model_path, **kwargs
121
+ )
122
+
123
+ # <mod>
124
+ # configpath = os.path.join(ea_model_path, "config.json")
125
+ # if not os.path.exists(configpath):
126
+ # configpath = hf_hub_download(ea_model_path, "config.json")
127
+
128
+ # try:
129
+ # load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
130
+ # if not os.path.exists(load_model_path):
131
+ # load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
132
+ # ea_layer_state_dict = torch.load(load_model_path,
133
+ # map_location=base_model.device)
134
+ # except:
135
+ # from safetensors.torch import load_file
136
+ # load_model_path = os.path.join(ea_model_path, "model.safetensors")
137
+ # if not os.path.exists(load_model_path):
138
+ # load_model_path = hf_hub_download(ea_model_path, "model.safetensors")
139
+ # ea_layer_state_dict = load_file(load_model_path)
140
+ # <before-after-mod> -------------------------------------------------
141
+ # ### <rewrite> new loading logic that also supports loading the Eagle head from a subfolder of a Hugging Face Hub repo
142
+ try:
143
+ configpath = os.path.join(ea_model_path, "config.json")
144
+ load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
145
+ if not os.path.exists(configpath):
146
+ configpath = hf_hub_download(ea_model_path, "config.json")
147
+ if not os.path.exists(load_model_path):
148
+ load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
149
+ except Exception:  # fall back to treating the last path component as a subfolder of a Hub repo
150
+ folder_names = ea_model_path.split("/")
151
+ repo = "/".join(folder_names[:-1])
152
+ subfolder = folder_names[-1]
153
+ configpath = hf_hub_download(repo_id=repo, subfolder=subfolder, filename="config.json")
154
+ load_model_path = hf_hub_download(repo_id=repo, subfolder=subfolder, filename="pytorch_model.bin")
155
+
156
+ ea_layer_state_dict = torch.load(load_model_path, map_location=base_model.device)
157
+ # </mod>
158
+
159
+ model = cls(
160
+ use_eagle3,
161
+ base_model,
162
+ base_model_path,
163
+ configpath,
164
+ total_token,
165
+ depth,
166
+ top_k,
167
+ threshold,
168
+ ea_layer_state_dict
169
+ )
170
+
171
+ if total_token == -1:
172
+ device = model.base_model.model.layers[0].self_attn.q_proj.weight.device
173
+ cans = [40, 48, 50, 56, 60]
174
+ x = [1, 1.05, 1.07, 1.1, 1.13]
175
+ times = []
176
+
177
+ for i in range(len(cans)):
178
+ length = cans[i]
179
+ input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device)
180
+ torch.cuda.synchronize()
181
+ start_time = time.time()
182
+ for _ in range(20):
183
+ torch.cuda.synchronize()
184
+ with torch.no_grad():
185
+ outputs = model.base_model(input_ids)
186
+ torch.cuda.synchronize()
187
+ torch.cuda.synchronize()
188
+ end_time = time.time()
189
+ times.append((end_time - start_time) / x[i])
190
+ total_token = cans[times.index(min(times))]
191
+ model.ea_layer.total_tokens = total_token - 1
192
+
193
+ return model
194
+
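+ # [added note] When total_token is -1, from_pretrained above auto-tunes the draft tree
+ # size: it times the base model on candidate lengths (40/48/50/56/60), divides each
+ # elapsed time by the corresponding weight in x, and keeps the fastest candidate.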
195
+ def forward(
196
+ self,
197
+ input_ids=None,
198
+ attention_mask=None,
199
+ past_key_values=None,
200
+ output_orig=False,
201
+ position_ids=None,
202
+ ):
203
+
204
+ with torch.inference_mode():
205
+ # Pass input through the base model
206
+ outputs = self.base_model.model(
207
+ input_ids=input_ids,
208
+ attention_mask=attention_mask,
209
+ past_key_values=past_key_values,
210
+ position_ids=position_ids,
211
+ )
212
+ if output_orig:
213
+ orig = self.base_model.lm_head(outputs[0])
214
+ hidden_states = outputs[0]
215
+
216
+ if output_orig:
217
+ return outputs, orig, hidden_states
218
+ else:
219
+ return outputs, hidden_states
220
+
221
+ @torch.no_grad()
222
+ def eagenerate(
223
+ self,
224
+ input_ids,
225
+ temperature=0.0,
226
+ top_p=0.0,
227
+ top_k=0.0,
228
+ max_new_tokens=512,
229
+ max_length=2048,
230
+ log=False,
231
+ is_llama3=False,
232
+
233
+ ):
234
+ if is_llama3:
235
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
236
+
237
+
238
+ if temperature > 1e-5:
239
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
240
+ else:
241
+ logits_processor = None
242
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
243
+ # Avoid modifying the input_ids in-place
244
+
245
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
246
+ input_ids = input_ids.clone()
247
+ self.ea_layer.reset_kv()
248
+
249
+ # Initialize the past key and value states
250
+ if hasattr(self, "past_key_values"):
251
+ past_key_values = self.past_key_values
252
+ past_key_values_data = self.past_key_values_data
253
+ current_length_data = self.current_length_data
254
+ # Reset the past key and value states
255
+ current_length_data.zero_()
256
+ else:
257
+ (
258
+ past_key_values,
259
+ past_key_values_data,
260
+ current_length_data,
261
+ ) = initialize_past_key_values(self.base_model,max_length=max_length)
262
+ self.past_key_values = past_key_values
263
+ self.past_key_values_data = past_key_values_data
264
+ self.current_length_data = current_length_data
265
+
266
+ input_len = input_ids.shape[1]
267
+ reset_tree_mode(self)
268
+ # prefill
269
+ draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
270
+ input_ids, self, past_key_values, logits_processor
271
+ )
272
+ new_token = 0
273
+ max_length = max_length - self.ea_layer.total_tokens - 10
274
+ for idx in range(max_length):
275
+ # with Timer("all"):
276
+ self.base_model.model.tree_mask = tree_mask
277
+
278
+ draft_tokens = draft_tokens.to(input_ids.device)
279
+ # Target model forward, get logits
280
+ logits, hidden_state_new, outputs = tree_decoding(
281
+ self,
282
+ draft_tokens,
283
+ past_key_values,
284
+ tree_position_ids,
285
+ input_ids,
286
+ retrieve_indices,
287
+ )
288
+ # retrieve_indices=tree_buffers["retrieve_indices"]
289
+ # logits = logits[0, retrieve_indices]
290
+ draft_tokens = torch.cat((draft_tokens, padding), dim=1)
291
+ candidates = draft_tokens[0, retrieve_indices]
292
+ # verification
293
+ best_candidate, accept_length, sample_p = evaluate_posterior(
294
+ logits, candidates, logits_processor
295
+ )
296
+ # print(accept_length)
297
+ # Adjusting the input sequence, draft model forward
298
+ input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
299
+ input_ids,
300
+ candidates,
301
+ best_candidate,
302
+ accept_length,
303
+ retrieve_indices,
304
+ logits_processor,
305
+ new_token,
306
+ past_key_values_data,
307
+ current_length_data,
308
+ self,
309
+ hidden_state_new,
310
+ sample_p
311
+ )
312
+
313
+ if is_llama3:
314
+ if stop_token_id in input_ids[0, input_len:].tolist():
315
+ break
316
+
317
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
318
+ break
319
+ if new_token > max_new_tokens:
320
+ break
321
+ if input_ids.shape[1] > max_length:
322
+ break
323
+ if not log:
324
+ return input_ids
325
+ else:
326
+ return input_ids, new_token, idx
327
+
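A hedged usage sketch for eagenerate (assumes an EaModel loaded on CUDA with its tokenizer attribute set; the prompt text and sampling values are illustrative, not taken from this repo):

# Batch size 1 only; log=True also returns the number of generated tokens and loop steps.
prompt = model.tokenizer.apply_chat_template(
    [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
    tokenize=False, add_generation_prompt=True,
)
input_ids = model.tokenizer([prompt], return_tensors="pt").input_ids.to("cuda")
output_ids, new_tokens, steps = model.eagenerate(
    input_ids, temperature=0.6, top_p=0.95, max_new_tokens=256, max_length=4096, log=True,
)
print(model.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))
print(f"{new_tokens} new tokens in roughly {steps + 1} verification passes")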
328
+ @torch.no_grad()
329
+ def naivegenerate(
330
+ self,
331
+ input_ids,
332
+ temperature=0.0,
333
+ top_p=0.0,
334
+ top_k=0.0,
335
+ max_new_tokens=512,
336
+ max_length=2048,
337
+ log=False,
338
+ is_llama3=False,
339
+
340
+ ):
341
+ if is_llama3:
342
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
343
+
344
+
345
+ if temperature > 1e-5:
346
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
347
+ else:
348
+ logits_processor = None
349
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
350
+ # Avoid modifying the input_ids in-place
351
+
352
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
353
+ input_ids = input_ids.clone()
354
+ self.ea_layer.reset_kv()
355
+
356
+ # Initialize the past key and value states
357
+ if hasattr(self, "past_key_values"):
358
+ past_key_values = self.past_key_values
359
+ past_key_values_data = self.past_key_values_data
360
+ current_length_data = self.current_length_data
361
+ # Reset the past key and value states
362
+ current_length_data.zero_()
363
+ else:
364
+ (
365
+ past_key_values,
366
+ past_key_values_data,
367
+ current_length_data,
368
+ ) = initialize_past_key_values(self.base_model,max_length=max_length)
369
+ self.past_key_values = past_key_values
370
+ self.past_key_values_data = past_key_values_data
371
+ self.current_length_data = current_length_data
372
+
373
+ input_len = input_ids.shape[1]
374
+ reset_tree_mode(self)
375
+ outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
376
+ new_token = 0
377
+ max_length = max_length - self.ea_layer.total_tokens - 10
378
+ for idx in range(max_length):
379
+ if logits_processor is not None:
380
+ logits = outputs.logits[:, -1]
381
+ logits = logits_processor(None, logits)
382
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
383
+ input_id = torch.multinomial(probabilities, 1)
384
+ else:
385
+ input_id = outputs.logits[:, -1:].argmax(dim=-1)
386
+ outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
387
+ input_ids = torch.cat([input_ids, input_id], dim=-1)
388
+ new_token += 1
389
+
390
+ if is_llama3:
391
+ if stop_token_id in input_ids[0, input_len:].tolist():
392
+ break
393
+
394
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
395
+ break
396
+ if new_token > max_new_tokens:
397
+ break
398
+ if input_ids.shape[1] > max_length:
399
+ break
400
+ if not log:
401
+ return input_ids
402
+ else:
403
+ return input_ids, new_token, idx
404
+
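naivegenerate runs the same sampling loop without drafting, so it doubles as a baseline. A rough comparison sketch, reusing the hypothetical model and input_ids from the previous sketch; with temperature=0.0 both paths should give the same greedy continuation (up to numerical ties), so only the wall time differs:

import time

t0 = time.time()
_, spec_new, _ = model.eagenerate(input_ids, temperature=0.0, max_new_tokens=128, log=True)
t1 = time.time()
_, base_new, _ = model.naivegenerate(input_ids, temperature=0.0, max_new_tokens=128, log=True)
t2 = time.time()
print(f"eagle: {spec_new} tokens in {t1 - t0:.2f}s | naive: {base_new} tokens in {t2 - t1:.2f}s")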
405
+ @torch.no_grad()
406
+ def ea_generate(
407
+ self,
408
+ input_ids,
409
+ temperature=0.0,
410
+ top_p=0.0,
411
+ top_k=0.0,
412
+ max_new_tokens=512,
413
+ max_length=2048,
414
+ log=False,
415
+ is_llama3=False,
416
+
417
+ ):
418
+ if is_llama3:
419
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
420
+
421
+
422
+ if temperature > 1e-5:
423
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
424
+ else:
425
+ logits_processor = None
426
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
427
+ # Avoid modifying the input_ids in-place
428
+
429
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
430
+ input_ids = input_ids.clone()
431
+ self.ea_layer.reset_kv()
432
+
433
+ # Initialize the past key and value states
434
+ if hasattr(self, "past_key_values"):
435
+ past_key_values = self.past_key_values
436
+ past_key_values_data = self.past_key_values_data
437
+ current_length_data = self.current_length_data
438
+ # Reset the past key and value states
439
+ current_length_data.zero_()
440
+ else:
441
+ (
442
+ past_key_values,
443
+ past_key_values_data,
444
+ current_length_data,
445
+ ) = initialize_past_key_values(self.base_model,max_length=max_length)
446
+ self.past_key_values = past_key_values
447
+ self.past_key_values_data = past_key_values_data
448
+ self.current_length_data = current_length_data
449
+
450
+ input_len = input_ids.shape[1]
451
+ reset_tree_mode(self)
452
+ draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
453
+ input_ids, self, past_key_values, logits_processor
454
+ )
455
+ new_token = 0
456
+ max_length = max_length - self.ea_layer.total_tokens - 10
457
+ for idx in range(max_length):
458
+ # with Timer("all"):
459
+ self.base_model.model.tree_mask = tree_mask
460
+
461
+ draft_tokens = draft_tokens.to(input_ids.device)
462
+ # with Timer("tree_decoding"):
463
+ logits, hidden_state_new, outputs = tree_decoding(
464
+ self,
465
+ draft_tokens,
466
+ past_key_values,
467
+ tree_position_ids,
468
+ input_ids,
469
+ retrieve_indices,
470
+ )
471
+ # retrieve_indices=tree_buffers["retrieve_indices"]
472
+ # logits = logits[0, retrieve_indices]
473
+ draft_tokens = torch.cat((draft_tokens, padding), dim=1)
474
+ candidates = draft_tokens[0, retrieve_indices]
475
+ best_candidate, accept_length, sample_p = evaluate_posterior(
476
+ logits, candidates, logits_processor
477
+ )
478
+ # print(accept_length)
479
+ # with Timer("update_inference_inputs"):
480
+ input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
481
+ input_ids,
482
+ candidates,
483
+ best_candidate,
484
+ accept_length,
485
+ retrieve_indices,
486
+ logits_processor,
487
+ new_token,
488
+ past_key_values_data,
489
+ current_length_data,
490
+ self,
491
+ hidden_state_new,
492
+ sample_p
493
+ )
494
+
495
+ yield input_ids
496
+
497
+ if is_llama3:
498
+ if stop_token_id in input_ids[0, input_len:].tolist():
499
+ break
500
+
501
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
502
+ break
503
+ if new_token > max_new_tokens:
504
+ break
505
+ if input_ids.shape[1] > max_length:
506
+ break
507
+
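ea_generate yields the full, growing input_ids after every accepted chunk, which is the hook the streaming chat path builds on. A minimal consumption sketch (same hypothetical model and input_ids as above); each yield can add several tokens at once, so only the new text is printed:

prompt_len = input_ids.shape[1]
shown = ""
for ids in model.ea_generate(input_ids, temperature=0.6, top_p=0.95, max_new_tokens=512):
    text = model.tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
    delta = text[len(shown):]
    if delta:
        print(delta, end="", flush=True)
        shown = text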
508
+ @torch.no_grad()
509
+ def naive_generate(
510
+ self,
511
+ input_ids,
512
+ temperature=0.0,
513
+ top_p=0.0,
514
+ top_k=0.0,
515
+ max_new_tokens=512,
516
+ max_length=2048,
517
+ log=False,
518
+ is_llama3=False,
519
+
520
+ ):
521
+ if is_llama3:
522
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
523
+
524
+
525
+ if temperature > 1e-5:
526
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
527
+ else:
528
+ logits_processor = None
529
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
530
+ # Avoid modifying the input_ids in-place
531
+
532
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
533
+ input_ids = input_ids.clone()
534
+ self.ea_layer.reset_kv()
535
+
536
+ # Initialize the past key and value states
537
+ if hasattr(self, "past_key_values"):
538
+ past_key_values = self.past_key_values
539
+ past_key_values_data = self.past_key_values_data
540
+ current_length_data = self.current_length_data
541
+ # Reset the past key and value states
542
+ current_length_data.zero_()
543
+ else:
544
+ (
545
+ past_key_values,
546
+ past_key_values_data,
547
+ current_length_data,
548
+ ) = initialize_past_key_values(self.base_model,max_length=max_length)
549
+ self.past_key_values = past_key_values
550
+ self.past_key_values_data = past_key_values_data
551
+ self.current_length_data = current_length_data
552
+
553
+ input_len = input_ids.shape[1]
554
+ reset_tree_mode(self)
555
+ outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
556
+ new_token = 0
557
+ max_length = max_length - self.ea_layer.total_tokens - 10
558
+ for idx in range(max_length):
559
+ if logits_processor is not None:
560
+ logits = outputs.logits[:, -1]
561
+ logits = logits_processor(None, logits)
562
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
563
+ input_id = torch.multinomial(probabilities, 1)
564
+ else:
565
+ input_id = outputs.logits[:, -1:].argmax(dim=-1)
566
+
567
+ outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
568
+ input_ids = torch.cat([input_ids, input_id], dim=-1)
569
+ new_token += 1
570
+
571
+ yield input_ids
572
+
573
+ if is_llama3:
574
+ if stop_token_id in input_ids[0, input_len:].tolist():
575
+ break
576
+
577
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
578
+ break
579
+ if new_token > max_new_tokens:
580
+ break
581
+ if input_ids.shape[1] > max_length:
582
+ break
eagle/model/kv_cache.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+
3
+
4
+ class KVCache:
5
+ """
6
+ A key-value cache for the model.
7
+
8
+ This class provides a mechanism to maintain a growing cache of keys and values,
9
+ particularly useful for models that benefit from caching previous states,
10
+ like transformers during autoregressive decoding.
11
+
12
+ Attributes:
13
+ data (torch.Tensor): The tensor storing keys and values.
14
+ current_length (int): Current length of the data being stored.
15
+ """
16
+
17
+ def __init__(self, data, current_length):
18
+ """
19
+ Initialize the KVCache.
20
+
21
+ Args:
22
+ data (torch.Tensor): Initial tensor to store the keys and values.
23
+ current_length (int): Initial length of the data.
24
+ """
25
+ self.data = data
26
+ self.current_length = current_length
27
+
28
+ @property
29
+ def shape(self):
30
+ """Return the shape of the data tensor with updated length."""
31
+ return (
32
+ self.data.shape[0],
33
+ self.data.shape[1],
34
+ self.current_length.item(),
35
+ self.data.shape[3],
36
+ )
37
+
38
+ def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
39
+ """
40
+ Copy values from the current data at specified indices to a new location.
41
+
42
+ Args:
43
+ indices (torch.Tensor): Indices of the data tensor to be copied.
44
+ prev_length (int): Previous length before adding new data.
45
+ dim (int, optional): Dimension along which copying should be performed. Default is 2.
46
+ """
47
+ tgt = self.data.index_select(dim, indices)
48
+ dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
49
+ dst.copy_(tgt, non_blocking=True)
50
+ self.current_length.fill_(prev_length + tgt.shape[dim])
51
+
52
+ def cat(self, tensor: torch.Tensor, dim: int = 2):
53
+ """
54
+ Concatenate the given tensor with the current data.
55
+
56
+ Args:
57
+ tensor (torch.Tensor): The tensor to be concatenated.
58
+ dim (int, optional): The dimension along which concatenation should be done. Default is 2.
59
+
60
+ Returns:
61
+ torch.Tensor: The data tensor after concatenation up to the current length.
62
+ """
63
+ dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
64
+ dst.copy_(tensor)
65
+ self.current_length.add_(tensor.shape[dim])
66
+ return torch.narrow(self.data, 2, 0, self.current_length)
67
+
68
+
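A small self-contained check of the KVCache semantics above (toy shapes, assuming the class is in scope): cat writes into the preallocated buffer in place, advances current_length, and returns a view narrowed to the filled prefix.

import torch

buf = torch.zeros(1, 4, 16, 8)                  # (batch, kv_heads, max_len, head_dim), toy sizes
length = torch.zeros((), dtype=torch.long)      # 0-dim counter, like a slot of current_length_data
cache = KVCache(buf, length)
chunk = torch.randn(1, 4, 5, 8)                 # five new key (or value) positions
view = cache.cat(chunk, dim=2)
print(view.shape, cache.current_length.item())  # torch.Size([1, 4, 5, 8]) 5
print(cache.shape)                              # (1, 4, 5, 8) via the shape property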
69
+ def initialize_past_key_values(model,max_length=2200):
70
+ """
71
+ Initialize past key and value states for a given transformer model.
72
+
73
+ This function prepares key-value cache structures for the model, allowing it to store and reuse
74
+ past key and value states during autoregressive decoding, which can improve efficiency.
75
+
76
+ Args:
77
+ model (nn.Module): The transformer model for which past key-value states need to be initialized.
78
+
79
+ Returns:
80
+ tuple:
81
+ - past_key_values (list): A list of KVCache objects for each layer in the model.
82
+ - past_key_values_data (torch.Tensor): The tensor that will store all keys and values.
83
+ - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache.
84
+ """
85
+ # Extracting configuration from the model
86
+ config = model.config
87
+ # Initializing the batch size to 1, this can be modified if different batch sizes are required
88
+ batch_size = 1
89
+ # Initializing a tensor to store past keys and values for all layers
90
+
91
+ devices=[]
92
+ for i in range(config.num_hidden_layers):
93
+ try:
94
+ device = model.model.layers[i].self_attn.q_proj.weight.device
95
+ except:
96
+ device=model.layers[i].self_attn.q_proj.weight.device
97
+ devices.append(device)
98
+ past_key_values_data_list=[]
99
+ startnum=0
100
+ startdevice=devices[0]
101
+ for id,i in enumerate(devices):
102
+ if startdevice!=i:
103
+ past_key_values_data = torch.zeros(
104
+ startnum * 2,
105
+ batch_size,
106
+ config.num_key_value_heads,
107
+ max_length,
108
+ config.hidden_size // config.num_attention_heads,
109
+ device=startdevice,
110
+ dtype=model.dtype,
111
+ )
112
+ past_key_values_data_list.append(past_key_values_data)
113
+ startdevice = i
114
+ startnum=0
115
+ startnum += 1
116
+ past_key_values_data = torch.zeros(
117
+ startnum * 2,
118
+ batch_size,
119
+ config.num_key_value_heads,
120
+ max_length,
121
+ config.hidden_size // config.num_attention_heads,
122
+ device=startdevice,
123
+ dtype=model.dtype,
124
+ )
125
+ past_key_values_data_list.append(past_key_values_data)
126
+ # Initialize tensor to store the current length of the cached data for all layers.
127
+ # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
128
+ current_length_data = torch.zeros(
129
+ config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
130
+ )
131
+ # Creating a KVCache for each pair of key and value in all layers
132
+ past_key_values = []  # one [key, value] KVCache pair per layer is appended below
133
+
134
+ bias=0
135
+ start_data_m=devices[0].index
136
+ for i in range(config.num_hidden_layers):
137
+ data_m=devices[i].index
138
+ if data_m!=start_data_m:
139
+ bias=0
140
+ start_data_m=data_m
141
+ try:
142
+ past_key_values.append(
143
+ [
144
+ KVCache(past_key_values_data_list[data_m-devices[0].index][2*bias + j], current_length_data[i * 2 + j])
145
+ for j in range(2)
146
+ ]
147
+ )
148
+ except:
149
+ past_key_values.append(
150
+ [
151
+ KVCache(past_key_values_data_list[0][2 * bias + j],
152
+ current_length_data[i * 2 + j])
153
+ for j in range(2)
154
+ ]
155
+ )
156
+ bias+=1
157
+ return past_key_values, past_key_values_data_list, current_length_data
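The buffers allocated above amount to two slabs per layer of shape (batch, num_key_value_heads, max_length, head_dim), so the cache footprint can be budgeted up front. A back-of-the-envelope sketch; the config numbers below are assumptions for illustration, not values read from the MiniCPM checkpoint:

num_layers, num_heads, num_kv_heads, hidden = 32, 32, 2, 4096   # hypothetical config
head_dim = hidden // num_heads
max_length, batch, bytes_per_elem = 4096, 1, 2                  # bf16
elems = num_layers * 2 * batch * num_kv_heads * max_length * head_dim
print(f"{elems * bytes_per_elem / 2**30:.3f} GiB preallocated for the KV cache")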
eagle/model/modeling_llama_kv.py ADDED
@@ -0,0 +1,1597 @@
1
+ # Source: https://github.com/huggingface/transformers/blob/v4.31-release/src/transformers/models/llama/modeling_llama.py
2
+ # Modifications are denoted by the symbol: [MODIFIED]
3
+
4
+
5
+ """ PyTorch LLaMA model."""
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ # [MODIFIED] Import from transformer library
16
+ from transformers.activations import ACT2FN
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ SequenceClassifierOutputWithPast,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import (
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
30
+ from transformers import LlamaConfig
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ _CONFIG_FOR_DOC = "LlamaConfig"
35
+
36
+
37
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
38
+ def _make_causal_mask(
39
+ input_ids_shape: torch.Size,
40
+ dtype: torch.dtype,
41
+ device: torch.device,
42
+ past_key_values_length: int = 0,
43
+ ):
44
+ """
45
+ Create the causal (lower-triangular) mask used for decoder self-attention.
46
+
47
+ Args:
48
+ input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
49
+ dtype (torch.dtype): The data type of the mask.
50
+ device (torch.device): The device on which the mask will be placed.
51
+ past_key_values_length (int, optional): The length of past key values. Default is 0.
52
+
53
+ Returns:
54
+ torch.Tensor: The causal mask tensor.
55
+ """
56
+ bsz, tgt_len = input_ids_shape
57
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
58
+ mask_cond = torch.arange(mask.size(-1), device=device)
59
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
60
+ mask = mask.to(dtype)
61
+
62
+ if past_key_values_length > 0:
63
+ mask = torch.cat(
64
+ [
65
+ torch.zeros(
66
+ tgt_len, past_key_values_length, dtype=dtype, device=device
67
+ ),
68
+ mask,
69
+ ],
70
+ dim=-1,
71
+ )
72
+ return mask[None, None, :, :].expand(
73
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
74
+ )
75
+
76
+
77
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
78
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
79
+ """
80
+ Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
81
+
82
+ Args:
83
+ mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`.
84
+ dtype (torch.dtype): The data type of the mask.
85
+ tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length.
86
+
87
+ Returns:
88
+ torch.Tensor: The expanded mask tensor.
89
+ """
90
+ bsz, src_len = mask.size()
91
+ tgt_len = tgt_len if tgt_len is not None else src_len
92
+
93
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
94
+
95
+ inverted_mask = 1.0 - expanded_mask
96
+
97
+ return inverted_mask.masked_fill(
98
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
99
+ )
100
+
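A quick inspection of the two mask helpers above on toy shapes (dtype and sizes are arbitrary):

import torch

causal = _make_causal_mask((1, 3), torch.float32, torch.device("cpu"), past_key_values_length=2)
print(causal.shape)   # torch.Size([1, 1, 3, 5]); zeros where attention is allowed

padding = _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32, tgt_len=3)
print(padding.shape)  # torch.Size([1, 1, 3, 3]); the masked column holds a large negative value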
101
+
102
+
103
+
104
+ class LlamaRMSNorm(nn.Module):
105
+ """
106
+ LlamaRMSNorm is equivalent to T5LayerNorm.
107
+
108
+ Args:
109
+ hidden_size (int): The size of the hidden states.
110
+ eps (float, optional): A small value to prevent division by zero. Default is 1e-6.
111
+ """
112
+
113
+ def __init__(self, hidden_size, eps=1e-6):
114
+ super().__init__()
115
+ self.weight = nn.Parameter(torch.ones(hidden_size))
116
+ self.variance_epsilon = eps
117
+
118
+ def forward(self, hidden_states):
119
+ """
120
+ Apply LlamaRMSNorm to the input hidden states.
121
+
122
+ Args:
123
+ hidden_states (torch.Tensor): Input hidden states.
124
+
125
+ Returns:
126
+ torch.Tensor: The normalized and scaled hidden states.
127
+ """
128
+ input_dtype = hidden_states.dtype
129
+ hidden_states = hidden_states.to(torch.float32)
130
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
131
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
132
+ return self.weight * hidden_states.to(input_dtype)
133
+
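A sanity sketch for the RMSNorm above: with the default weight of ones, the module output should match the explicit x * rsqrt(mean(x^2) + eps) computation (toy sizes):

import torch

norm = LlamaRMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 4, 8)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual, atol=1e-5))  # True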
134
+
135
+ class LlamaRotaryEmbedding(nn.Module):
136
+ """
137
+ Llama Rotary Positional Embedding Module.
138
+
139
+ Args:
140
+ dim (int): The dimension of the embedding.
141
+ max_position_embeddings (int, optional): The maximum position for embeddings. Default is 2048.
142
+ base (int, optional): The base value for rotational encoding. Default is 10000.
143
+ device (str, optional): The device on which the computation will be performed. Default is None.
144
+ """
145
+
146
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
147
+ super().__init__()
148
+
149
+ self.dim = dim
150
+ self.max_position_embeddings = max_position_embeddings
151
+ self.base = base
152
+ inv_freq = 1.0 / (
153
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
154
+ )
155
+ self.register_buffer("inv_freq", inv_freq)
156
+
157
+ # Build here to make `torch.jit.trace` work.
158
+ self._set_cos_sin_cache(
159
+ seq_len=max_position_embeddings,
160
+ device=self.inv_freq.device,
161
+ dtype=torch.get_default_dtype(),
162
+ )
163
+
164
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
165
+ """
166
+ Set the cosine and sine cache for positional embeddings.
167
+
168
+ Args:
169
+ seq_len (int): The sequence length.
170
+ device (str): The device on which the cache tensors will be stored.
171
+ dtype: The data type of the cache tensors.
172
+ """
173
+ self.max_seq_len_cached = seq_len
174
+ t = torch.arange(
175
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
176
+ )
177
+
178
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
179
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
180
+ emb = torch.cat((freqs, freqs), dim=-1)
181
+ self.register_buffer(
182
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
183
+ )
184
+ self.register_buffer(
185
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
186
+ )
187
+
188
+ def forward(self, x, seq_len=None):
189
+ """
190
+ Forward pass of the LlamaRotaryEmbedding module.
191
+
192
+ Args:
193
+ x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].
194
+ seq_len (int): The sequence length. If greater than the cached length, the cache will be updated.
195
+
196
+ Returns:
197
+ tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].
198
+ """
199
+ if seq_len > self.max_seq_len_cached:
200
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
201
+
202
+ return (
203
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
204
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
205
+ )
206
+
207
+
208
+ class LlamaRotaryEmbedding_L31(nn.Module):
209
+ def __init__(
210
+ self,
211
+ dim=None,
212
+ max_position_embeddings=2048,
213
+ base=10000,
214
+ device=None,
215
+ scaling_factor=1.0,
216
+ rope_type="default",
217
+ config: Optional[LlamaConfig] = None,
218
+ ):
219
+ super().__init__()
220
+ # TODO (joao): remove the `if` below, only used for BC
221
+ self.rope_kwargs = {}
222
+ if config is None:
223
+ logger.warning_once(
224
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
225
+ "`config` argument. All other arguments will be removed in v4.46"
226
+ )
227
+ self.rope_kwargs = {
228
+ "rope_type": rope_type,
229
+ "factor": scaling_factor,
230
+ "dim": dim,
231
+ "base": base,
232
+ "max_position_embeddings": max_position_embeddings,
233
+ }
234
+ self.rope_type = rope_type
235
+ self.max_seq_len_cached = max_position_embeddings
236
+ self.original_max_seq_len = max_position_embeddings
237
+ else:
238
+ # BC: "rope_type" was originally "type"
239
+ if config.rope_scaling is not None:
240
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
241
+ else:
242
+ self.rope_type = "default"
243
+ self.max_seq_len_cached = config.max_position_embeddings
244
+ self.original_max_seq_len = config.max_position_embeddings
245
+
246
+ self.config = config
247
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
248
+
249
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
250
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
251
+ self.original_inv_freq = self.inv_freq
252
+
253
+ def _dynamic_frequency_update(self, position_ids, device):
254
+ """
255
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
256
+ 1 - growing beyond the cached sequence length (allow scaling)
257
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
258
+ """
259
+ seq_len = torch.max(position_ids) + 1
260
+ if seq_len > self.max_seq_len_cached: # growth
261
+ inv_freq, self.attention_scaling = self.rope_init_fn(
262
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
263
+ )
264
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
265
+ self.max_seq_len_cached = seq_len
266
+
267
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
268
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
269
+ self.max_seq_len_cached = self.original_max_seq_len
270
+
271
+ @torch.no_grad()
272
+ def forward(self, x, position_ids):
273
+ if "dynamic" in self.rope_type:
274
+ self._dynamic_frequency_update(position_ids, device=x.device)
275
+
276
+ # Core RoPE block
277
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
278
+ position_ids_expanded = position_ids[:, None, :].float()
279
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
280
+ device_type = x.device.type
281
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
282
+ with torch.autocast(device_type=device_type, enabled=False):
283
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
284
+ emb = torch.cat((freqs, freqs), dim=-1)
285
+ cos = emb.cos()
286
+ sin = emb.sin()
287
+
288
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
289
+ cos = cos * self.attention_scaling
290
+ sin = sin * self.attention_scaling
291
+
292
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
293
+
294
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
295
+ """
296
+ LlamaRotaryEmbedding extended with linear scaling.
297
+
298
+ This class adds linear scaling to LlamaRotaryEmbedding. Credits to the Reddit user /u/kaiokendev.
299
+
300
+ Args:
301
+ dim (int): The dimension of the embedding.
302
+ max_position_embeddings (int, optional): The maximum number of position embeddings. Default is 2048.
303
+ base (int, optional): The base value for the rotational embeddings. Default is 10000.
304
+ device (str or torch.device, optional): The device where the embeddings should be stored. Default is None.
305
+ scaling_factor (float, optional): The scaling factor for the embeddings. Default is 1.0.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ dim,
311
+ max_position_embeddings=2048,
312
+ base=10000,
313
+ device=None,
314
+ scaling_factor=1.0,
315
+ ):
316
+ self.scaling_factor = scaling_factor
317
+ super().__init__(dim, max_position_embeddings, base, device)
318
+
319
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
320
+ """
321
+ Set the cosine and sine cache for the rotary embeddings.
322
+
323
+ Args:
324
+ seq_len (int): The sequence length.
325
+ device (str or torch.device): The device where the cache should be stored.
326
+ dtype: The data type for the cache.
327
+ """
328
+ self.max_seq_len_cached = seq_len
329
+ t = torch.arange(
330
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
331
+ )
332
+ t = t / self.scaling_factor
333
+
334
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
335
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
336
+ emb = torch.cat((freqs, freqs), dim=-1)
337
+ self.register_buffer(
338
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
339
+ )
340
+ self.register_buffer(
341
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
342
+ )
343
+
344
+
345
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
346
+ """
347
+ LlamaRotaryEmbedding extended with Dynamic NTK scaling.
348
+
349
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
350
+ """
351
+
352
+ def __init__(
353
+ self,
354
+ dim,
355
+ max_position_embeddings=2048,
356
+ base=10000,
357
+ device=None,
358
+ scaling_factor=1.0,
359
+ ):
360
+ """
361
+ Initialize the LlamaDynamicNTKScalingRotaryEmbedding.
362
+
363
+ Args:
364
+ dim (int): The dimensionality of the embedding.
365
+ max_position_embeddings (int, optional): Maximum number of position embeddings. Default is 2048.
366
+ base (int, optional): Base value for scaling calculations. Default is 10000.
367
+ device: The device to place tensors on. If None, uses the default device.
368
+ scaling_factor (float, optional): Scaling factor for NTK scaling. Default is 1.0.
369
+ """
370
+ self.scaling_factor = scaling_factor
371
+ super().__init__(dim, max_position_embeddings, base, device)
372
+
373
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
374
+ """
375
+ Set the cached values for cosine and sine.
376
+
377
+ Args:
378
+ seq_len (int): The sequence length.
379
+ device: The device to place tensors on.
380
+ dtype: The data type of tensors.
381
+ """
382
+ self.max_seq_len_cached = seq_len
383
+
384
+ if seq_len > self.max_position_embeddings:
385
+ base = self.base * (
386
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
387
+ - (self.scaling_factor - 1)
388
+ ) ** (self.dim / (self.dim - 2))
389
+ inv_freq = 1.0 / (
390
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
391
+ )
392
+ self.register_buffer("inv_freq", inv_freq)
393
+
394
+ t = torch.arange(
395
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
396
+ )
397
+
398
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
399
+ emb = torch.cat((freqs, freqs), dim=-1)
400
+ self.register_buffer(
401
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
402
+ )
403
+ self.register_buffer(
404
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
405
+ )
406
+
407
+
408
+ def rotate_half(x):
409
+ """
410
+ Rotates half the hidden dimensions of the input.
411
+
412
+ Args:
413
+ x (torch.Tensor): Input tensor.
414
+
415
+ Returns:
416
+ torch.Tensor: Tensor with half of its hidden dimensions rotated.
417
+ """
418
+ x1 = x[..., : x.shape[-1] // 2]
419
+ x2 = x[..., x.shape[-1] // 2:]
420
+ return torch.cat((-x2, x1), dim=-1)
421
+
422
+
423
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
424
+ """
425
+ Apply rotary position embeddings to query and key tensors.
426
+
427
+ Args:
428
+ q (torch.Tensor): Query tensor.
429
+ k (torch.Tensor): Key tensor.
430
+ cos (torch.Tensor): Cosine values.
431
+ sin (torch.Tensor): Sine values.
432
+ position_ids (torch.Tensor): Position IDs.
433
+
434
+ Returns:
435
+ torch.Tensor: Query and key tensors with rotary position embeddings applied.
436
+ """
437
+ cos = cos.squeeze(1).squeeze(0)
438
+ sin = sin.squeeze(1).squeeze(0)
439
+ cos = cos[position_ids].unsqueeze(1)
440
+ sin = sin[position_ids].unsqueeze(1)
441
+ q_embed = (q * cos) + (rotate_half(q) * sin)
442
+ k_embed = (k * cos) + (rotate_half(k) * sin)
443
+ return q_embed, k_embed
444
+
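Because apply_rotary_pos_emb pairs dimension i with i + dim/2 via rotate_half, it acts as a 2-D rotation per pair and should leave the per-position vector norms of q and k unchanged. A toy check using the cache-based embedding above (shapes are arbitrary):

import torch

head_dim, seq_len = 8, 6
rope = LlamaRotaryEmbedding(head_dim, max_position_embeddings=32)
q = torch.randn(1, 2, seq_len, head_dim)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, seq_len, head_dim)
cos, sin = rope(q, seq_len=seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True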
445
+
446
+ def apply_rotary_pos_emb_L31(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
447
+ """Applies Rotary Position Embedding to the query and key tensors.
448
+
449
+ Args:
450
+ q (`torch.Tensor`): The query tensor.
451
+ k (`torch.Tensor`): The key tensor.
452
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
453
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
454
+ position_ids (`torch.Tensor`, *optional*):
455
+ Deprecated and unused.
456
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
457
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
458
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
459
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
460
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
461
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
462
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
463
+ Returns:
464
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
465
+ """
466
+ cos = cos.unsqueeze(unsqueeze_dim)
467
+ sin = sin.unsqueeze(unsqueeze_dim)
468
+ q_embed = (q * cos) + (rotate_half(q) * sin)
469
+ k_embed = (k * cos) + (rotate_half(k) * sin)
470
+ return q_embed, k_embed
471
+
472
+
473
+ class LlamaMLP(nn.Module):
474
+ """
475
+ LlamaMLP is a multi-layer perceptron module used in the Llama model.
476
+
477
+ Args:
478
+ config: The configuration for the MLP.
479
+
480
+ Attributes:
481
+ pretraining_tp (int): The tensor parallelism degree used during pretraining (kept for exact reproducibility).
482
+ hidden_size (int): The size of the hidden layer.
483
+ intermediate_size (int): The size of the intermediate layer.
484
+ gate_proj (nn.Linear): The linear projection for gating.
485
+ up_proj (nn.Linear): The linear projection for the up projection.
486
+ down_proj (nn.Linear): The linear projection for the down projection.
487
+ act_fn: The activation function.
488
+
489
+ """
490
+
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.pretraining_tp = config.pretraining_tp
494
+ self.hidden_size = config.hidden_size
495
+ self.intermediate_size = config.intermediate_size
496
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
497
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
498
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
499
+ self.act_fn = ACT2FN[config.hidden_act]
500
+
501
+ def forward(self, x):
502
+ """
503
+ Forward pass of the MLP.
504
+
505
+ Args:
506
+ x: Input tensor.
507
+
508
+ Returns:
509
+ torch.Tensor: Output tensor.
510
+ """
511
+ if self.pretraining_tp > 1:
512
+ slice = self.intermediate_size // self.pretraining_tp
513
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
514
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
515
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
516
+
517
+ gate_proj = torch.cat(
518
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)],
519
+ dim=-1,
520
+ )
521
+ up_proj = torch.cat(
522
+ [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)],
523
+ dim=-1,
524
+ )
525
+
526
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
527
+ down_proj = [
528
+ F.linear(intermediate_states[i], down_proj_slices[i])
529
+ for i in range(self.pretraining_tp)
530
+ ]
531
+ down_proj = sum(down_proj)
532
+ else:
533
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
534
+
535
+ return down_proj
536
+
537
+
538
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
539
+ """
540
+ Repeat key and value tensors n times along the specified dimension.
541
+
542
+ Args:
543
+ hidden_states (torch.Tensor): Input tensor with shape (batch, num_key_value_heads, seqlen, head_dim).
544
+ n_rep (int): Number of times to repeat.
545
+
546
+ Returns:
547
+ torch.Tensor: Repeated tensor with shape (batch, num_key_value_heads * n_rep, seqlen, head_dim).
548
+ """
549
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
550
+ if n_rep == 1:
551
+ return hidden_states
552
+ hidden_states = hidden_states[:, :, None, :, :].expand(
553
+ batch, num_key_value_heads, n_rep, slen, head_dim
554
+ )
555
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
556
+
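repeat_kv expands each KV head n_rep times so grouped-query attention can reuse the dense head-by-head matmul below; it should agree with repeat_interleave on the head dimension (toy shapes):

import torch

kv = torch.randn(1, 2, 5, 4)      # (batch, kv_heads, seq, head_dim)
out = repeat_kv(kv, 4)            # -> (1, 8, 5, 4)
print(out.shape, torch.equal(out, kv.repeat_interleave(4, dim=1)))  # True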
557
+
558
+ class HackMiniCPMLongRoPE(LlamaRotaryEmbedding):
559
+ """LongRoPE cos/sin cache adapted from https://huggingface.co/openbmb/MiniCPM4.1-8B/blob/main/modeling_minicpm.py,
561
+ extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla."""
561
+
562
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, short_factor=None, long_factor=None, original_max_position_embeddings=None):
563
+ self.short_factor = short_factor
564
+ self.long_factor = long_factor
565
+ self.original_max_position_embeddings = original_max_position_embeddings
566
+ scale = (max_position_embeddings / self.original_max_position_embeddings)
567
+ self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
568
+ super().__init__(dim, max_position_embeddings, base, device)
569
+
570
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
571
+ self.max_seq_len_cached = seq_len
572
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
573
+ if seq_len > self.original_max_position_embeddings:
574
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
575
+ else:
576
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
577
+
578
+ freqs = torch.mul(
579
+ torch.outer(t, 1.0 / ext_factors).to(device=device),
580
+ self.inv_freq.to(device=device).to(dtype)
581
+ )
582
+ # # Different from paper, but it uses a different permutation in order to obtain the same calculation
583
+ # emb = torch.cat((freqs, freqs), dim=-1)
584
+ # self.register_buffer('cos_cached', emb.cos().to(dtype) * self.scaling_factor, persistent=False)
585
+ # self.register_buffer('sin_cached', emb.sin().to(dtype) * self.scaling_factor, persistent=False)
586
+
587
+
588
+ # t = t / ext_factors
589
+ # # t = t / self.scaling_factor
590
+
591
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
592
+ # # Different from paper, but it uses a different permutation in order to obtain the same calculation
593
+
594
+ # 250911
595
+ # [DIFF] shape
596
+ emb = torch.cat((freqs, freqs), dim=-1)
597
+ self.register_buffer(
598
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
599
+ )
600
+ self.register_buffer(
601
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
602
+ )
603
+
604
+ # # 250914 prev modification forgot to add scaling factor
605
+ # # [DIFF] shape
606
+ # emb = torch.cat((freqs, freqs), dim=-1)
607
+ # self.register_buffer(
608
+ # "cos_cached", emb.cos()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False
609
+ # )
610
+ # self.register_buffer(
611
+ # "sin_cached", emb.sin()[None, None, :, :].to(dtype) * self.scaling_factor, persistent=False
612
+ # )
613
+
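The core of HackMiniCPMLongRoPE._set_cos_sin_cache is that each inverse frequency is divided by a per-dimension extension factor, switching from short_factor to long_factor once the cached length exceeds original_max_position_embeddings. A stripped-down sketch of that frequency computation; the factor values and sizes below are made up, not the checkpoint's:

import torch

dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
original_max, seq_len = 16, 32
short_factor = [1.0, 1.0, 1.0, 1.0]          # hypothetical per-dimension factors
long_factor = [1.0, 2.0, 4.0, 8.0]
ext = torch.tensor(long_factor if seq_len > original_max else short_factor)
t = torch.arange(seq_len, dtype=inv_freq.dtype)
freqs = torch.mul(torch.outer(t, 1.0 / ext), inv_freq)   # same recipe as the method above
print(freqs.shape)                                       # torch.Size([32, 4])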
614
+ class LlamaAttention(nn.Module):
615
+ """
616
+ LlamaAttention is a multi-headed attention module based on the 'Attention Is All You Need' paper.
617
+
618
+ Args:
619
+ config (LlamaConfig): Configuration for the attention module.
620
+
621
+ Attributes:
622
+ config (LlamaConfig): Configuration for the attention module.
623
+ hidden_size (int): The size of the hidden layer.
624
+ num_heads (int): The number of attention heads.
625
+ head_dim (int): The dimension of each attention head.
626
+ num_key_value_heads (int): The number of key-value attention heads.
627
+ num_key_value_groups (int): The number of key-value groups.
628
+ pretraining_tp (int): The tensor parallelism degree used during pretraining (kept for exact reproducibility).
629
+ max_position_embeddings (int): The maximum position embeddings.
630
+
631
+ """
632
+
633
+ def __init__(self, config: LlamaConfig):
634
+ super().__init__()
635
+ self.config = config
636
+ self.hidden_size = config.hidden_size
637
+ self.num_heads = config.num_attention_heads
638
+ self.head_dim = self.hidden_size // self.num_heads
639
+ self.num_key_value_heads = config.num_key_value_heads
640
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
641
+ self.pretraining_tp = config.pretraining_tp
642
+ self.max_position_embeddings = config.max_position_embeddings
643
+
644
+ if (self.head_dim * self.num_heads) != self.hidden_size:
645
+ raise ValueError(
646
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
647
+ f" and `num_heads`: {self.num_heads})."
648
+ )
649
+ self.q_proj = nn.Linear(
650
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
651
+ )
652
+ self.k_proj = nn.Linear(
653
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
654
+ )
655
+ self.v_proj = nn.Linear(
656
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
657
+ )
658
+ self.o_proj = nn.Linear(
659
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
660
+ )
661
+ self._init_rope()
662
+
663
+ def _init_rope(self):
664
+ if self.config.rope_scaling is None:
665
+ self.rotary_emb = LlamaRotaryEmbedding(
666
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta
667
+ )
668
+ else:
669
+
670
+ # add: Support MiniCPM4.1-8B | JQZ 250910
671
+ try:
672
+ assert "rope_type" in self.config.rope_scaling.keys()
673
+ assert self.config.rope_scaling["rope_type"] == "longrope"
674
+ scaling_type = "longrope"
675
+ except:
676
+ scaling_type = self.config.rope_scaling["type"]
677
+ scaling_factor = self.config.rope_scaling["factor"]
678
+ # /add
679
+
680
+
681
+ # scaling_type == "longrope": # add: Support MiniCPM4.1-8B | JQZ 250910
682
+ self.rotary_emb = HackMiniCPMLongRoPE(
683
+ self.head_dim,
684
+ max_position_embeddings=self.max_position_embeddings,
685
+ short_factor=self.config.rope_scaling["short_factor"],
686
+ long_factor=self.config.rope_scaling["long_factor"],
687
+ original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
688
+ )
689
+
690
+ # try:
691
+ # scaling_type = self.config.rope_scaling["type"]
692
+ # scaling_factor = self.config.rope_scaling["factor"]
693
+ # if scaling_type == "linear":
694
+ # self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
695
+ # self.head_dim,
696
+ # max_position_embeddings=self.max_position_embeddings,
697
+ # scaling_factor=scaling_factor,
698
+ # base=self.config.rope_theta,
699
+ # )
700
+ # elif scaling_type == "dynamic":
701
+ # self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
702
+ # self.head_dim,
703
+ # max_position_embeddings=self.max_position_embeddings,
704
+ # scaling_factor=scaling_factor,
705
+ # base=self.config.rope_theta,
706
+ # )
707
+ # else:
708
+ # raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
709
+ # except:
710
+ # # print("For LLaMA 31")
711
+ # self.rotary_emb = LlamaRotaryEmbedding_L31(config=self.config)
712
+
713
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
714
+ return (
715
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
716
+ .transpose(1, 2)
717
+ .contiguous()
718
+ )
719
+
720
+ def forward(
721
+ self,
722
+ hidden_states: torch.Tensor,
723
+ attention_mask: Optional[torch.Tensor] = None,
724
+ position_ids: Optional[torch.LongTensor] = None,
725
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
726
+ output_attentions: bool = False,
727
+ use_cache: bool = False,
728
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
729
+ bsz, q_len, _ = hidden_states.size()
730
+
731
+ if self.pretraining_tp > 1:
732
+ key_value_slicing = (
733
+ self.num_key_value_heads * self.head_dim
734
+ ) // self.pretraining_tp
735
+ query_slices = self.q_proj.weight.split(
736
+ (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
737
+ )
738
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
739
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
740
+
741
+ query_states = [
742
+ F.linear(hidden_states, query_slices[i])
743
+ for i in range(self.pretraining_tp)
744
+ ]
745
+ query_states = torch.cat(query_states, dim=-1)
746
+
747
+ key_states = [
748
+ F.linear(hidden_states, key_slices[i])
749
+ for i in range(self.pretraining_tp)
750
+ ]
751
+ key_states = torch.cat(key_states, dim=-1)
752
+
753
+ value_states = [
754
+ F.linear(hidden_states, value_slices[i])
755
+ for i in range(self.pretraining_tp)
756
+ ]
757
+ value_states = torch.cat(value_states, dim=-1)
758
+
759
+ else:
760
+ query_states = self.q_proj(hidden_states)
761
+ key_states = self.k_proj(hidden_states)
762
+ value_states = self.v_proj(hidden_states)
763
+
764
+ query_states = query_states.view(
765
+ bsz, q_len, self.num_heads, self.head_dim
766
+ ).transpose(1, 2)
767
+ key_states = key_states.view(
768
+ bsz, q_len, self.num_key_value_heads, self.head_dim
769
+ ).transpose(1, 2)
770
+ value_states = value_states.view(
771
+ bsz, q_len, self.num_key_value_heads, self.head_dim
772
+ ).transpose(1, 2)
773
+
774
+ kv_seq_len = key_states.shape[-2]
775
+ if past_key_value is not None:
776
+ kv_seq_len += past_key_value[0].shape[-2]
777
+ if isinstance(self.rotary_emb, LlamaRotaryEmbedding_L31):
778
+ cos, sin = self.rotary_emb(query_states,position_ids)
779
+ query_states, key_states = apply_rotary_pos_emb_L31(query_states, key_states, cos, sin)
780
+ else:
781
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
782
+ query_states, key_states = apply_rotary_pos_emb(
783
+ query_states, key_states, cos, sin, position_ids
784
+ )
785
+
786
+ # [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization
787
+ # past_key_value is utilized to leverage previously computed key and value states.
788
+ # If past_key_value is available, reuse the states for k, v, and self_attention.
789
+ if past_key_value is not None:
790
+ key_states = past_key_value[0].cat(key_states, dim=2)
791
+ value_states = past_key_value[1].cat(value_states, dim=2)
792
+ # Reset past_key_value to avoid return past_key_value.
793
+ past_key_value = None
794
+
795
+ # repeat k/v heads if n_kv_heads < n_heads
796
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
797
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
798
+
799
+ attn_weights = torch.matmul(
800
+ query_states, key_states.transpose(2, 3)
801
+ ) / math.sqrt(self.head_dim)
802
+
803
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
804
+ raise ValueError(
805
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
806
+ f" {attn_weights.size()}"
807
+ )
808
+
809
+ if attention_mask is not None:
810
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
811
+ raise ValueError(
812
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
813
+ )
814
+ attn_weights = attn_weights + attention_mask
815
+
816
+ # upcast attention to fp32
817
+ attn_weights = nn.functional.softmax(
818
+ attn_weights, dim=-1, dtype=torch.float32
819
+ ).to(query_states.dtype)
820
+ attn_output = torch.matmul(attn_weights, value_states)
821
+
822
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
823
+ raise ValueError(
824
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
825
+ f" {attn_output.size()}"
826
+ )
827
+
828
+ attn_output = attn_output.transpose(1, 2).contiguous()
829
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
830
+
831
+ if self.pretraining_tp > 1:
832
+ attn_output = attn_output.split(
833
+ self.hidden_size // self.pretraining_tp, dim=2
834
+ )
835
+ o_proj_slices = self.o_proj.weight.split(
836
+ self.hidden_size // self.pretraining_tp, dim=1
837
+ )
838
+ attn_output = sum(
839
+ [
840
+ F.linear(attn_output[i], o_proj_slices[i])
841
+ for i in range(self.pretraining_tp)
842
+ ]
843
+ )
844
+ else:
845
+ attn_output = self.o_proj(attn_output)
846
+
847
+ if not output_attentions:
848
+ attn_weights = None
849
+
850
+ return attn_output, attn_weights, past_key_value
851
+
852
+
853
+ class LlamaDecoderLayer(nn.Module):
854
+ """
855
+ LlamaDecoderLayer represents a single layer of the Llama decoder.
856
+
857
+ Args:
858
+ config (LlamaConfig): Configuration for the decoder layer.
859
+
860
+ Attributes:
861
+ hidden_size (int): The size of the hidden layer.
862
+ self_attn (LlamaAttention): Multi-headed self-attention module.
863
+ mlp (LlamaMLP): Multi-layer perceptron module.
864
+ input_layernorm (LlamaRMSNorm): Layer normalization for input.
865
+ post_attention_layernorm (LlamaRMSNorm): Layer normalization after self-attention.
866
+ """
867
+
868
+ def __init__(self, config: LlamaConfig):
869
+ super().__init__()
870
+ self.hidden_size = config.hidden_size
871
+ self.self_attn = LlamaAttention(config=config)
872
+ self.mlp = LlamaMLP(config)
873
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
874
+ self.post_attention_layernorm = LlamaRMSNorm(
875
+ config.hidden_size, eps=config.rms_norm_eps
876
+ )
877
+
878
+ def forward(
879
+ self,
880
+ hidden_states: torch.Tensor,
881
+ attention_mask: Optional[torch.Tensor] = None,
882
+ position_ids: Optional[torch.LongTensor] = None,
883
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
884
+ output_attentions: Optional[bool] = False,
885
+ use_cache: Optional[bool] = False,
886
+ ) -> Tuple[
887
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
888
+ ]:
889
+ """
890
+ Forward pass for the LlamaDecoderLayer.
891
+
892
+ Args:
893
+ hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`.
894
+ attention_mask (torch.FloatTensor, optional): Attention mask of size
895
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
896
+ position_ids (torch.LongTensor, optional): Positional IDs tensor.
897
+ past_key_value (Tuple[torch.FloatTensor], optional): Cached past key and value projection states.
898
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
899
+ use_cache (bool, optional): If set to `True`, `past_key_values` key-value states are returned and can be
900
+ used to speed up decoding.
901
+
902
+ Returns:
903
+ Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing:
904
+ - hidden_states (torch.FloatTensor): Output tensor.
905
+ - self_attn_weights (Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]): Self-attention weights if
906
+ `output_attentions` is `True`.
907
+ - present_key_value (Optional[Tuple[torch.FloatTensor]]): Cached key and value projection states if
908
+ `use_cache` is `True`.
909
+ """
910
+
911
+ residual = hidden_states
912
+
913
+ hidden_states = self.input_layernorm(hidden_states)
914
+
915
+ # Self Attention
916
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
917
+ hidden_states=hidden_states,
918
+ attention_mask=attention_mask,
919
+ position_ids=position_ids,
920
+ past_key_value=past_key_value,
921
+ output_attentions=output_attentions,
922
+ use_cache=use_cache,
923
+ )
924
+ hidden_states = residual + hidden_states
925
+
926
+ # Fully Connected
927
+ residual = hidden_states
928
+ hidden_states = self.post_attention_layernorm(hidden_states)
929
+ hidden_states = self.mlp(hidden_states)
930
+ hidden_states = residual + hidden_states
931
+
932
+ outputs = (hidden_states,)
933
+
934
+ if output_attentions:
935
+ outputs += (self_attn_weights,)
936
+
937
+ if use_cache:
938
+ outputs += (present_key_value,)
939
+
940
+ return outputs
941
+
942
+
943
+ LLAMA_START_DOCSTRING = r"""
944
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
945
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
946
+ etc.)
947
+
948
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
949
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
950
+ and behavior.
951
+
952
+ Parameters:
953
+ config ([`LlamaConfig`]):
954
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
955
+ load the weights associated with the model, only the configuration. Check out the
956
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
957
+ """
958
+
959
+
960
+ @add_start_docstrings(
961
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
962
+ LLAMA_START_DOCSTRING,
963
+ )
964
+ class LlamaPreTrainedModel(PreTrainedModel):
965
+ config_class = LlamaConfig
966
+ base_model_prefix = "model"
967
+ supports_gradient_checkpointing = True
968
+ _no_split_modules = ["LlamaDecoderLayer"]
969
+ _skip_keys_device_placement = "past_key_values"
970
+
971
+ def _init_weights(self, module):
972
+ std = self.config.initializer_range
973
+ if isinstance(module, nn.Linear):
974
+ module.weight.data.normal_(mean=0.0, std=std)
975
+ if module.bias is not None:
976
+ module.bias.data.zero_()
977
+ elif isinstance(module, nn.Embedding):
978
+ module.weight.data.normal_(mean=0.0, std=std)
979
+ if module.padding_idx is not None:
980
+ module.weight.data[module.padding_idx].zero_()
981
+
982
+ def _set_gradient_checkpointing(self, module, value=False):
983
+ if isinstance(module, LlamaModel):
984
+ module.gradient_checkpointing = value
985
+
986
+
987
+ LLAMA_INPUTS_DOCSTRING = r"""
988
+ Args:
989
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
990
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
991
+ it.
992
+
993
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
994
+ [`PreTrainedTokenizer.__call__`] for details.
995
+
996
+ [What are input IDs?](../glossary#input-ids)
997
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
998
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
999
+
1000
+ - 1 for tokens that are **not masked**,
1001
+ - 0 for tokens that are **masked**.
1002
+
1003
+ [What are attention masks?](../glossary#attention-mask)
1004
+
1005
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1006
+ [`PreTrainedTokenizer.__call__`] for details.
1007
+
1008
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1009
+ `past_key_values`).
1010
+
1011
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1012
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1013
+ information on the default strategy.
1014
+
1015
+ - 1 indicates the head is **not masked**,
1016
+ - 0 indicates the head is **masked**.
1017
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1018
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1019
+ config.n_positions - 1]`.
1020
+
1021
+ [What are position IDs?](../glossary#position-ids)
1022
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1023
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1024
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1025
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1026
+
1027
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1028
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1029
+
1030
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1031
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1032
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1033
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1034
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1035
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1036
+ model's internal embedding lookup matrix.
1037
+ use_cache (`bool`, *optional*):
1038
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1039
+ `past_key_values`).
1040
+ output_attentions (`bool`, *optional*):
1041
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1042
+ tensors for more detail.
1043
+ output_hidden_states (`bool`, *optional*):
1044
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1045
+ more detail.
1046
+ return_dict (`bool`, *optional*):
1047
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1048
+ """
1049
+
1050
+
1051
+ @add_start_docstrings(
1052
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1053
+ LLAMA_START_DOCSTRING,
1054
+ )
1055
+ class LlamaModel(LlamaPreTrainedModel):
1056
+ """
1057
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1058
+
1059
+ Args:
1060
+ config: LlamaConfig
1061
+ """
1062
+
1063
+ def __init__(self, config: LlamaConfig):
1064
+ super().__init__(config)
1065
+ self.padding_idx = config.pad_token_id
1066
+ self.vocab_size = config.vocab_size
1067
+
1068
+ self.embed_tokens = nn.Embedding(
1069
+ config.vocab_size, config.hidden_size, self.padding_idx
1070
+ )
1071
+ self.layers = nn.ModuleList(
1072
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
1073
+ )
1074
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1075
+
1076
+ self.gradient_checkpointing = False
1077
+ # Initialize weights and apply final processing
1078
+ self.post_init()
1079
+
1080
+ def get_input_embeddings(self):
1081
+ return self.embed_tokens
1082
+
1083
+ def set_input_embeddings(self, value):
1084
+ self.embed_tokens = value
1085
+
1086
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1087
+ def _prepare_decoder_attention_mask(
1088
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
1089
+ ):
1090
+ # create causal mask
1091
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1092
+ combined_attention_mask = None
1093
+ if input_shape[-1] > 1:
1094
+ combined_attention_mask = _make_causal_mask(
1095
+ input_shape,
1096
+ # inputs_embeds.dtype,
1097
+ torch.float32, # [MODIFIED] force to cast to float32
1098
+ device=inputs_embeds.device,
1099
+ past_key_values_length=past_key_values_length,
1100
+ )
1101
+
1102
+ if attention_mask is not None:
1103
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1104
+ expanded_attn_mask = _expand_mask(
1105
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1106
+ ).to(inputs_embeds.device)
1107
+ combined_attention_mask = (
1108
+ expanded_attn_mask
1109
+ if combined_attention_mask is None
1110
+ else expanded_attn_mask + combined_attention_mask
1111
+ )
1112
+
1113
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
1114
+ tree_mask = self.tree_mask
1115
+ tree_len = tree_mask.size(-1)
1116
+ combined_attention_mask[:, :, -tree_len:, -tree_len:][
1117
+ tree_mask == 0
1118
+ ] = combined_attention_mask.min()
1119
+
1120
+ return combined_attention_mask
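The `tree_mask` branch above is the speculative-decoding change in this file: inside the block of draft-tree positions, attention is additionally restricted to each node's ancestors, on top of the usual causal mask. A minimal sketch of the idea with toy sizes and a hypothetical 3-node tree (illustrative only, not the exact layout the model builds):

```python
import torch

# 6 positions: 3 committed tokens followed by a 3-node draft tree
seq_len, tree_len = 6, 3
neg = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((seq_len, seq_len), neg), diagonal=1)[None, None]  # [1, 1, 6, 6]

# hypothetical tree: node 0 is the root, nodes 1 and 2 are siblings under it,
# so they must not attend to each other
tree_mask = torch.tensor([[1, 0, 0],
                          [1, 1, 0],
                          [1, 0, 1]])[None, None]
causal[:, :, -tree_len:, -tree_len:][tree_mask == 0] = causal.min()
print((causal[0, 0] == 0).int())  # 1 = may attend, 0 = blocked
```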
1121
+
1122
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1123
+ def forward(
1124
+ self,
1125
+ input_ids: torch.LongTensor = None,
1126
+ attention_mask: Optional[torch.Tensor] = None,
1127
+ position_ids: Optional[torch.LongTensor] = None,
1128
+ past_key_values=None, # [MODIFIED] past_key_value is KVCache class
1129
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1130
+ use_cache: Optional[bool] = None,
1131
+ output_attentions: Optional[bool] = None,
1132
+ output_hidden_states: Optional[bool] = None,
1133
+ return_dict: Optional[bool] = None,
1134
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1135
+ output_attentions = (
1136
+ output_attentions
1137
+ if output_attentions is not None
1138
+ else self.config.output_attentions
1139
+ )
1140
+ output_hidden_states = (
1141
+ output_hidden_states
1142
+ if output_hidden_states is not None
1143
+ else self.config.output_hidden_states
1144
+ )
1145
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1146
+
1147
+ return_dict = (
1148
+ return_dict if return_dict is not None else self.config.use_return_dict
1149
+ )
1150
+
1151
+ # retrieve input_ids and inputs_embeds
1152
+ if input_ids is not None and inputs_embeds is not None:
1153
+ raise ValueError(
1154
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1155
+ )
1156
+ elif input_ids is not None:
1157
+ batch_size, seq_length = input_ids.shape
1158
+ elif inputs_embeds is not None:
1159
+ batch_size, seq_length, _ = inputs_embeds.shape
1160
+ else:
1161
+ raise ValueError(
1162
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1163
+ )
1164
+
1165
+ seq_length_with_past = seq_length
1166
+ past_key_values_length = 0
1167
+
1168
+ if past_key_values is not None:
1169
+ past_key_values_length = past_key_values[0][0].shape[2]
1170
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1171
+
1172
+ if position_ids is None:
1173
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1174
+ position_ids = torch.arange(
1175
+ past_key_values_length,
1176
+ seq_length + past_key_values_length,
1177
+ dtype=torch.long,
1178
+ device=device,
1179
+ )
1180
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1181
+ else:
1182
+ position_ids = position_ids.view(-1, seq_length).long()
1183
+
1184
+ if inputs_embeds is None:
1185
+ inputs_embeds = self.embed_tokens(input_ids)
1186
+ # embed positions
1187
+ if attention_mask is None:
1188
+ attention_mask = torch.ones(
1189
+ (batch_size, seq_length_with_past),
1190
+ dtype=torch.bool,
1191
+ device=inputs_embeds.device,
1192
+ )
1193
+ attention_mask = self._prepare_decoder_attention_mask(
1194
+ attention_mask,
1195
+ (batch_size, seq_length),
1196
+ inputs_embeds,
1197
+ past_key_values_length,
1198
+ )
1199
+
1200
+ hidden_states = inputs_embeds
1201
+
1202
+ if self.gradient_checkpointing and self.training:
1203
+ if use_cache:
1204
+ logger.warning_once(
1205
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1206
+ )
1207
+ use_cache = False
1208
+
1209
+ # decoder layers
1210
+ all_hidden_states = ()  # [MODIFIED] always collected; selected intermediate layers are tapped below
1211
+ all_self_attns = () if output_attentions else None
1212
+ next_decoder_cache = () if use_cache else None
1213
+
1214
+ for idx, decoder_layer in enumerate(self.layers):
1215
+ if idx == len(self.layers) - 3 or idx == len(self.layers) // 2 or idx == 2:  # [MODIFIED] tap an early, a middle, and a late layer
1216
+ all_hidden_states += (hidden_states,)
1217
+
1218
+ past_key_value = (
1219
+ past_key_values[idx] if past_key_values is not None else None
1220
+ )
1221
+
1222
+ if self.gradient_checkpointing and self.training:
1223
+
1224
+ def create_custom_forward(module):
1225
+ def custom_forward(*inputs):
1226
+ # None for past_key_value
1227
+ return module(*inputs, output_attentions, None)
1228
+
1229
+ return custom_forward
1230
+
1231
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1232
+ create_custom_forward(decoder_layer),
1233
+ hidden_states,
1234
+ attention_mask,
1235
+ position_ids,
1236
+ None,
1237
+ )
1238
+ else:
1239
+ layer_outputs = decoder_layer(
1240
+ hidden_states,
1241
+ attention_mask=attention_mask,
1242
+ position_ids=position_ids,
1243
+ past_key_value=past_key_value,
1244
+ output_attentions=output_attentions,
1245
+ use_cache=use_cache,
1246
+ )
1247
+
1248
+ hidden_states = layer_outputs[0]
1249
+
1250
+ if use_cache:
1251
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1252
+
1253
+ if output_attentions:
1254
+ all_self_attns += (layer_outputs[1],)
1255
+
1256
+ hidden_states = self.norm(hidden_states)
1257
+
1258
+ # add hidden states from the last decoder layer
1259
+ if output_hidden_states:
1260
+ all_hidden_states += (hidden_states,)
1261
+
1262
+ # [MODIFIED] note: the post-norm hidden state is not appended to all_hidden_states here
1263
+ # all_hidden_states += (hidden_states,)
1264
+
1265
+ next_cache = next_decoder_cache if use_cache else None
1266
+ if not return_dict:
1267
+ return tuple(
1268
+ v
1269
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1270
+ if v is not None
1271
+ )
1272
+ return BaseModelOutputWithPast(
1273
+ last_hidden_state=hidden_states,
1274
+ past_key_values=next_cache,
1275
+ hidden_states=all_hidden_states,
1276
+ attentions=all_self_attns,
1277
+ )
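Because `all_hidden_states` is filled unconditionally at three fixed depths (layer 2, the middle layer, and the third-from-last layer), a caller receives exactly the low/mid/high features that an EAGLE-3 style draft head consumes. A hedged sketch of one way to combine them; the helper name and the concatenation are illustrative, not this repo's API:

```python
import torch

def concat_eagle3_features(hidden_states_tuple):
    """Concatenate the three tapped layers along the feature dimension.

    hidden_states_tuple: tuple of 3 tensors, each [batch, seq_len, hidden].
    Returns [batch, seq_len, 3 * hidden]; illustrative helper only.
    """
    assert len(hidden_states_tuple) == 3
    return torch.cat(hidden_states_tuple, dim=-1)

low = mid = high = torch.randn(1, 5, 4096)  # hidden size is illustrative
features = concat_eagle3_features((low, mid, high))
print(features.shape)  # torch.Size([1, 5, 12288])
```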
1278
+
1279
+
1280
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1281
+ _tied_weights_keys = ["lm_head.weight"]
1282
+
1283
+ def __init__(self, config):
1284
+ super().__init__(config)
1285
+ self.model = LlamaModel(config)
1286
+ self.pretraining_tp = config.pretraining_tp
1287
+ self.vocab_size = config.vocab_size
1288
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1289
+
1290
+ # Initialize weights and apply final processing
1291
+ self.post_init()
1292
+
1293
+ def get_input_embeddings(self):
1294
+ return self.model.embed_tokens
1295
+
1296
+ def set_input_embeddings(self, value):
1297
+ self.model.embed_tokens = value
1298
+
1299
+ def get_output_embeddings(self):
1300
+ return self.lm_head
1301
+
1302
+ def set_output_embeddings(self, new_embeddings):
1303
+ self.lm_head = new_embeddings
1304
+
1305
+ def set_decoder(self, decoder):
1306
+ self.model = decoder
1307
+
1308
+ def get_decoder(self):
1309
+ return self.model
1310
+
1311
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1312
+ @replace_return_docstrings(
1313
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1314
+ )
1315
+ def forward(
1316
+ self,
1317
+ input_ids: torch.LongTensor = None,
1318
+ attention_mask: Optional[torch.Tensor] = None,
1319
+ position_ids: Optional[torch.LongTensor] = None,
1320
+ past_key_values=None, # [MODIFIED] past_key_value is KVCache class
1321
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ use_cache: Optional[bool] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1328
+ r"""
1329
+ Args:
1330
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1331
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1332
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1333
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1334
+
1335
+ Returns:
1336
+
1337
+ Example:
1338
+
1339
+ ```python
1340
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1341
+
1342
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1343
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1344
+
1345
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1346
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1347
+
1348
+ >>> # Generate
1349
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1350
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1351
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1352
+ ```"""
1353
+
1354
+ output_attentions = (
1355
+ output_attentions
1356
+ if output_attentions is not None
1357
+ else self.config.output_attentions
1358
+ )
1359
+ output_hidden_states = (
1360
+ output_hidden_states
1361
+ if output_hidden_states is not None
1362
+ else self.config.output_hidden_states
1363
+ )
1364
+ return_dict = (
1365
+ return_dict if return_dict is not None else self.config.use_return_dict
1366
+ )
1367
+
1368
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1369
+ outputs = self.model(
1370
+ input_ids=input_ids,
1371
+ attention_mask=attention_mask,
1372
+ position_ids=position_ids,
1373
+ past_key_values=past_key_values,
1374
+ inputs_embeds=inputs_embeds,
1375
+ use_cache=use_cache,
1376
+ output_attentions=output_attentions,
1377
+ output_hidden_states=output_hidden_states,
1378
+ return_dict=return_dict,
1379
+ )
1380
+
1381
+ hidden_states = outputs[0]
1382
+ if self.pretraining_tp > 1:
1383
+ lm_head_slices = self.lm_head.weight.split(
1384
+ self.vocab_size // self.pretraining_tp, dim=0
1385
+ )
1386
+ logits = [
1387
+ F.linear(hidden_states, lm_head_slices[i])
1388
+ for i in range(self.pretraining_tp)
1389
+ ]
1390
+ logits = torch.cat(logits, dim=-1)
1391
+ else:
1392
+ logits = self.lm_head(hidden_states)
1393
+ logits = logits.float()
1394
+
1395
+ loss = None
1396
+ if labels is not None:
1397
+ # Shift so that tokens < n predict n
1398
+ shift_logits = logits[..., :-1, :].contiguous()
1399
+ shift_labels = labels[..., 1:].contiguous()
1400
+ # Flatten the tokens
1401
+ loss_fct = CrossEntropyLoss()
1402
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1403
+ shift_labels = shift_labels.view(-1)
1404
+ # Enable model parallelism
1405
+ shift_labels = shift_labels.to(shift_logits.device)
1406
+ loss = loss_fct(shift_logits, shift_labels)
1407
+
1408
+ if not return_dict:
1409
+ output = (logits,) + outputs[1:]
1410
+ return (loss,) + output if loss is not None else output
1411
+
1412
+ return CausalLMOutputWithPast(
1413
+ loss=loss,
1414
+ logits=logits,
1415
+ past_key_values=outputs.past_key_values,
1416
+ hidden_states=outputs.hidden_states,
1417
+ attentions=outputs.attentions,
1418
+ )
1419
+
1420
+ def prepare_inputs_for_generation(
1421
+ self,
1422
+ input_ids,
1423
+ past_key_values=None,
1424
+ attention_mask=None,
1425
+ inputs_embeds=None,
1426
+ **kwargs,
1427
+ ):
1428
+ if past_key_values:
1429
+ input_ids = input_ids[:, -1:]
1430
+
1431
+ position_ids = kwargs.get("position_ids", None)
1432
+ if attention_mask is not None and position_ids is None:
1433
+ # create position_ids on the fly for batch generation
1434
+ position_ids = attention_mask.long().cumsum(-1) - 1
1435
+ position_ids.masked_fill_(attention_mask == 0, 1)
1436
+ if past_key_values:
1437
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1438
+
1439
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1440
+ if inputs_embeds is not None and past_key_values is None:
1441
+ model_inputs = {"inputs_embeds": inputs_embeds}
1442
+ else:
1443
+ model_inputs = {"input_ids": input_ids}
1444
+
1445
+ model_inputs.update(
1446
+ {
1447
+ "position_ids": position_ids,
1448
+ "past_key_values": past_key_values,
1449
+ "use_cache": kwargs.get("use_cache"),
1450
+ "attention_mask": attention_mask,
1451
+ }
1452
+ )
1453
+ return model_inputs
1454
+
1455
+ @staticmethod
1456
+ def _reorder_cache(past_key_values, beam_idx):
1457
+ reordered_past = ()
1458
+ for layer_past in past_key_values:
1459
+ reordered_past += (
1460
+ tuple(
1461
+ past_state.index_select(0, beam_idx.to(past_state.device))
1462
+ for past_state in layer_past
1463
+ ),
1464
+ )
1465
+ return reordered_past
1466
+
1467
+
1468
+ @add_start_docstrings(
1469
+ """
1470
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1471
+
1472
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1473
+ (e.g. GPT-2) do.
1474
+
1475
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1476
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1477
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1478
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1479
+ each row of the batch).
1480
+ """,
1481
+ LLAMA_START_DOCSTRING,
1482
+ )
1483
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1484
+ def __init__(self, config):
1485
+ super().__init__(config)
1486
+ self.num_labels = config.num_labels
1487
+ self.model = LlamaModel(config)
1488
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1489
+
1490
+ # Initialize weights and apply final processing
1491
+ self.post_init()
1492
+
1493
+ def get_input_embeddings(self):
1494
+ return self.model.embed_tokens
1495
+
1496
+ def set_input_embeddings(self, value):
1497
+ self.model.embed_tokens = value
1498
+
1499
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1500
+ def forward(
1501
+ self,
1502
+ input_ids: torch.LongTensor = None,
1503
+ attention_mask: Optional[torch.Tensor] = None,
1504
+ position_ids: Optional[torch.LongTensor] = None,
1505
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1506
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1507
+ labels: Optional[torch.LongTensor] = None,
1508
+ use_cache: Optional[bool] = None,
1509
+ output_attentions: Optional[bool] = None,
1510
+ output_hidden_states: Optional[bool] = None,
1511
+ return_dict: Optional[bool] = None,
1512
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1513
+ r"""
1514
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1515
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1516
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1517
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1518
+ """
1519
+ return_dict = (
1520
+ return_dict if return_dict is not None else self.config.use_return_dict
1521
+ )
1522
+
1523
+ transformer_outputs = self.model(
1524
+ input_ids,
1525
+ attention_mask=attention_mask,
1526
+ position_ids=position_ids,
1527
+ past_key_values=past_key_values,
1528
+ inputs_embeds=inputs_embeds,
1529
+ use_cache=use_cache,
1530
+ output_attentions=output_attentions,
1531
+ output_hidden_states=output_hidden_states,
1532
+ return_dict=return_dict,
1533
+ )
1534
+ hidden_states = transformer_outputs[0]
1535
+ logits = self.score(hidden_states)
1536
+
1537
+ if input_ids is not None:
1538
+ batch_size = input_ids.shape[0]
1539
+ else:
1540
+ batch_size = inputs_embeds.shape[0]
1541
+
1542
+ if self.config.pad_token_id is None and batch_size != 1:
1543
+ raise ValueError(
1544
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1545
+ )
1546
+ if self.config.pad_token_id is None:
1547
+ sequence_lengths = -1
1548
+ else:
1549
+ if input_ids is not None:
1550
+ sequence_lengths = (
1551
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1552
+ ).to(logits.device)
1553
+ else:
1554
+ sequence_lengths = -1
1555
+
1556
+ pooled_logits = logits[
1557
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1558
+ ]
1559
+
1560
+ loss = None
1561
+ if labels is not None:
1562
+ labels = labels.to(logits.device)
1563
+ if self.config.problem_type is None:
1564
+ if self.num_labels == 1:
1565
+ self.config.problem_type = "regression"
1566
+ elif self.num_labels > 1 and (
1567
+ labels.dtype == torch.long or labels.dtype == torch.int
1568
+ ):
1569
+ self.config.problem_type = "single_label_classification"
1570
+ else:
1571
+ self.config.problem_type = "multi_label_classification"
1572
+
1573
+ if self.config.problem_type == "regression":
1574
+ loss_fct = MSELoss()
1575
+ if self.num_labels == 1:
1576
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1577
+ else:
1578
+ loss = loss_fct(pooled_logits, labels)
1579
+ elif self.config.problem_type == "single_label_classification":
1580
+ loss_fct = CrossEntropyLoss()
1581
+ loss = loss_fct(
1582
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1583
+ )
1584
+ elif self.config.problem_type == "multi_label_classification":
1585
+ loss_fct = BCEWithLogitsLoss()
1586
+ loss = loss_fct(pooled_logits, labels)
1587
+ if not return_dict:
1588
+ output = (pooled_logits,) + transformer_outputs[1:]
1589
+ return ((loss,) + output) if loss is not None else output
1590
+
1591
+ return SequenceClassifierOutputWithPast(
1592
+ loss=loss,
1593
+ logits=pooled_logits,
1594
+ past_key_values=transformer_outputs.past_key_values,
1595
+ hidden_states=transformer_outputs.hidden_states,
1596
+ attentions=transformer_outputs.attentions,
1597
+ )
eagle/model/modeling_minicpm_kv.py ADDED
The diff for this file is too large to render. See raw diff
 
eagle/model/modeling_mixtral_kv.py ADDED
@@ -0,0 +1,1199 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mixtral model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+ from .kv_cache import KVCache
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+
33
+ # [MODIFIED] Import from transformer library
34
+ from transformers.activations import ACT2FN
35
+
36
+ from transformers.modeling_outputs import (
37
+ MoeCausalLMOutputWithPast,
38
+ MoeModelOutputWithPast,
39
+ SequenceClassifierOutputWithPast,
40
+ )
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers import MixtralConfig
49
+
50
+
51
+
52
+
53
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
54
+ # It means that the function will not be traced through and simply appear as a node in the graph.
55
+
56
+
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+ _CONFIG_FOR_DOC = "MixtralConfig"
61
+
62
+
63
+ def _make_causal_mask(
64
+ input_ids_shape: torch.Size,
65
+ dtype: torch.dtype,
66
+ device: torch.device,
67
+ past_key_values_length: int = 0,
68
+ ):
69
+ """
70
+ Create a causal mask for bi-directional self-attention.
71
+
72
+ Args:
73
+ input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
74
+ dtype (torch.dtype): The data type of the mask.
75
+ device (torch.device): The device on which the mask will be placed.
76
+ past_key_values_length (int, optional): The length of past key values. Default is 0.
77
+
78
+ Returns:
79
+ torch.Tensor: The causal mask tensor.
80
+ """
81
+ bsz, tgt_len = input_ids_shape
82
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
83
+ mask_cond = torch.arange(mask.size(-1), device=device)
84
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
85
+ mask = mask.to(dtype)
86
+
87
+ if past_key_values_length > 0:
88
+ mask = torch.cat(
89
+ [
90
+ torch.zeros(
91
+ tgt_len, past_key_values_length, dtype=dtype, device=device
92
+ ),
93
+ mask,
94
+ ],
95
+ dim=-1,
96
+ )
97
+ return mask[None, None, :, :].expand(
98
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
99
+ )
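For intuition, with 3 new query positions and 2 cached positions the mask is fully open to the past and lower-triangular over the new block (import path assumes this commit's package layout):

```python
import torch
from eagle.model.modeling_mixtral_kv import _make_causal_mask

mask = _make_causal_mask(torch.Size([1, 3]), torch.float32,
                         device=torch.device("cpu"), past_key_values_length=2)
print(mask.shape)              # torch.Size([1, 1, 3, 5])
print((mask == 0).int()[0, 0])
# [[1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0],
#  [1, 1, 1, 1, 1]]
```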
100
+
101
+
102
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
103
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
104
+ """
105
+ Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
106
+
107
+ Args:
108
+ mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`.
109
+ dtype (torch.dtype): The data type of the mask.
110
+ tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length.
111
+
112
+ Returns:
113
+ torch.Tensor: The expanded mask tensor.
114
+ """
115
+ bsz, src_len = mask.size()
116
+ tgt_len = tgt_len if tgt_len is not None else src_len
117
+
118
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
119
+
120
+ inverted_mask = 1.0 - expanded_mask
121
+
122
+ return inverted_mask.masked_fill(
123
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
124
+ )
125
+
126
+
127
+ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
128
+ r"""
129
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
130
+
131
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
132
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
133
+ experts is too unbalanced.
134
+
135
+ Args:
136
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
137
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
138
+ num_experts (`int`, *optional*):
139
+ Number of experts
140
+
141
+ Returns:
142
+ The auxiliary loss.
143
+ """
144
+ if gate_logits is None:
145
+ return 0
146
+
147
+ if isinstance(gate_logits, tuple):
148
+ # cat along the layers?
149
+ compute_device = gate_logits[0].device
150
+ gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
151
+
152
+ routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
153
+ routing_weights = routing_weights.softmax(dim=-1)
154
+
155
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
156
+ if selected_experts.dtype != torch.int64:
157
+ selected_experts = selected_experts.to(torch.int64)
158
+
159
+ if len(selected_experts.shape) == 2:
160
+ selected_experts = selected_experts.unsqueeze(2)
161
+
162
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
163
+
164
+ # For a given token, determine if it was routed to a given expert.
165
+ expert_mask = torch.max(expert_mask, axis=-2).values
166
+
167
+ # cast to float32 otherwise mean will fail
168
+ expert_mask = expert_mask.to(torch.float32)
169
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
170
+
171
+ router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
172
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2)
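A quick way to exercise the auxiliary loss on random router logits, assuming the repo is importable as a package under the path added in this commit (toy shapes):

```python
import torch
from eagle.model.modeling_mixtral_kv import load_balancing_loss_func

num_experts, top_k = 8, 2
# router logits from two layers, each for 4 sequences of length 16, flattened per token
gate_logits = tuple(torch.randn(4 * 16, num_experts) for _ in range(2))
aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # scalar tensor
```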
173
+
174
+
175
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
176
+ def _get_unpad_data(attention_mask):
177
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
178
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
179
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
180
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
181
+ return (
182
+ indices,
183
+ cu_seqlens,
184
+ max_seqlen_in_batch,
185
+ )
186
+
187
+
188
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
189
+ class MixtralRMSNorm(nn.Module):
190
+ def __init__(self, hidden_size, eps=1e-6):
191
+ """
192
+ MixtralRMSNorm is equivalent to T5LayerNorm
193
+ """
194
+ super().__init__()
195
+ self.weight = nn.Parameter(torch.ones(hidden_size))
196
+ self.variance_epsilon = eps
197
+
198
+ def forward(self, hidden_states):
199
+ input_dtype = hidden_states.dtype
200
+ hidden_states = hidden_states.to(torch.float32)
201
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
202
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
203
+ return self.weight * hidden_states.to(input_dtype)
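A one-line check that this forward matches the plain RMSNorm formula `x * rsqrt(mean(x^2) + eps)` scaled by the learned weight (initialised to ones); the import path assumes this commit's package layout:

```python
import torch
from eagle.model.modeling_mixtral_kv import MixtralRMSNorm

x = torch.randn(2, 3, 16)
norm = MixtralRMSNorm(16, eps=1e-6)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)  # weight is all ones at init
print(torch.allclose(norm(x), ref, atol=1e-6))  # True
```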
204
+
205
+
206
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
207
+ class MixtralRotaryEmbedding(nn.Module):
208
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
209
+ super().__init__()
210
+
211
+ self.dim = dim
212
+ self.max_position_embeddings = max_position_embeddings
213
+ self.base = base
214
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
215
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
216
+
217
+ # Build here to make `torch.jit.trace` work.
218
+ self._set_cos_sin_cache(
219
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
220
+ )
221
+
222
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
223
+ self.max_seq_len_cached = seq_len
224
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
225
+
226
+ freqs = torch.outer(t, self.inv_freq)
227
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
228
+ emb = torch.cat((freqs, freqs), dim=-1)
229
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
230
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
231
+
232
+ def forward(self, x, seq_len=None):
233
+ # x: [bs, num_attention_heads, seq_len, head_size]
234
+ if seq_len > self.max_seq_len_cached:
235
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
236
+
237
+ return (
238
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
239
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
240
+ )
241
+
242
+
243
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
244
+ def rotate_half(x):
245
+ """Rotates half the hidden dims of the input."""
246
+ x1 = x[..., : x.shape[-1] // 2]
247
+ x2 = x[..., x.shape[-1] // 2 :]
248
+ return torch.cat((-x2, x1), dim=-1)
249
+
250
+
251
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
252
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
253
+ """Applies Rotary Position Embedding to the query and key tensors.
254
+
255
+ Args:
256
+ q (`torch.Tensor`): The query tensor.
257
+ k (`torch.Tensor`): The key tensor.
258
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
259
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
260
+ position_ids (`torch.Tensor`):
261
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
262
+ used to pass offsetted position ids when working with a KV-cache.
263
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
264
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
265
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
266
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
267
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
268
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
269
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
270
+ Returns:
271
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
272
+ """
273
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
274
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
275
+ q_embed = (q * cos) + (rotate_half(q) * sin)
276
+ k_embed = (k * cos) + (rotate_half(k) * sin)
277
+ return q_embed, k_embed
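Since RoPE only rotates pairs of feature dimensions, it preserves per-head vector norms; a small sanity check using the classes above (import path assumes this commit's package layout):

```python
import torch
from eagle.model.modeling_mixtral_kv import MixtralRotaryEmbedding, apply_rotary_pos_emb

bsz, heads, seq_len, head_dim = 1, 4, 6, 8
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)

rope = MixtralRotaryEmbedding(head_dim, max_position_embeddings=32)
cos, sin = rope(q, seq_len=seq_len)            # each [seq_len, head_dim]
position_ids = torch.arange(seq_len)[None, :]  # [1, seq_len]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# the rotation leaves per-position query norms unchanged
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```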
278
+
279
+
280
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
281
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
282
+ """
283
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
284
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
285
+ """
286
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
287
+ if n_rep == 1:
288
+ return hidden_states
289
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
290
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
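As the docstring says, `repeat_kv` is equivalent to `torch.repeat_interleave` along the head dimension; a quick check (import path assumes this commit's package layout):

```python
import torch
from eagle.model.modeling_mixtral_kv import repeat_kv

kv = torch.randn(2, 2, 5, 8)       # [batch, num_kv_heads, seq, head_dim]
expanded = repeat_kv(kv, n_rep=4)  # -> [batch, 8, seq, head_dim]
print(torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1)))  # True
```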
291
+
292
+
293
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
294
+ class MixtralAttention(nn.Module):
295
+ """
296
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
297
+ and "Generating Long Sequences with Sparse Transformers".
298
+ """
299
+
300
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
301
+ super().__init__()
302
+ self.config = config
303
+ self.layer_idx = layer_idx
304
+ if layer_idx is None:
305
+ logger.warning_once(
306
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
307
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
308
+ "when creating this class."
309
+ )
310
+
311
+ self.hidden_size = config.hidden_size
312
+ self.num_heads = config.num_attention_heads
313
+ self.head_dim = self.hidden_size // self.num_heads
314
+ self.num_key_value_heads = config.num_key_value_heads
315
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
316
+ self.max_position_embeddings = config.max_position_embeddings
317
+ self.rope_theta = config.rope_theta
318
+ self.is_causal = True
319
+ self.attention_dropout = config.attention_dropout
320
+
321
+ if (self.head_dim * self.num_heads) != self.hidden_size:
322
+ raise ValueError(
323
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
324
+ f" and `num_heads`: {self.num_heads})."
325
+ )
326
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
327
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
328
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
329
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
330
+
331
+ self.rotary_emb = MixtralRotaryEmbedding(
332
+ self.head_dim,
333
+ max_position_embeddings=self.max_position_embeddings,
334
+ base=self.rope_theta,
335
+ )
336
+
337
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
338
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
339
+
340
+ def forward(
341
+ self,
342
+ hidden_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_value: Optional[Tuple[KVCache]] = None,
346
+ output_attentions: bool = False,
347
+ use_cache: bool = False,
348
+ **kwargs,
349
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
350
+ if "padding_mask" in kwargs:
351
+ warnings.warn(
352
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
353
+ )
354
+ bsz, q_len, _ = hidden_states.size()
355
+
356
+ query_states = self.q_proj(hidden_states)
357
+ key_states = self.k_proj(hidden_states)
358
+ value_states = self.v_proj(hidden_states)
359
+
360
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
361
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
362
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
363
+
364
+ kv_seq_len = key_states.shape[-2]
365
+ if past_key_value is not None:
366
+ if self.layer_idx is None:
367
+ raise ValueError(
368
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
369
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
370
+ "with a layer index."
371
+ )
372
+ kv_seq_len += past_key_value[0].shape[-2]
373
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
374
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
375
+
376
+ if past_key_value is not None:
377
+ key_states = past_key_value[0].cat(key_states, dim=2)
378
+ value_states = past_key_value[1].cat(value_states, dim=2)
379
+
380
+ # repeat k/v heads if n_kv_heads < n_heads
381
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
382
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
383
+
384
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
385
+
386
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
387
+ raise ValueError(
388
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
389
+ f" {attn_weights.size()}"
390
+ )
391
+
392
+ if attention_mask is not None:
393
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
394
+ raise ValueError(
395
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
396
+ )
397
+
398
+ attn_weights = attn_weights + attention_mask
399
+
400
+ # upcast attention to fp32
401
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
402
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
403
+ attn_output = torch.matmul(attn_weights, value_states)
404
+
405
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
406
+ raise ValueError(
407
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
408
+ f" {attn_output.size()}"
409
+ )
410
+
411
+ attn_output = attn_output.transpose(1, 2).contiguous()
412
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
413
+
414
+ attn_output = self.o_proj(attn_output)
415
+
416
+ if not output_attentions:
417
+ attn_weights = None
418
+
419
+ return attn_output, attn_weights, past_key_value
420
+
421
+
422
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
423
+
424
+
425
+
426
+ class MixtralBLockSparseTop2MLP(nn.Module):
427
+ def __init__(self, config: MixtralConfig):
428
+ super().__init__()
429
+ self.ffn_dim = config.intermediate_size
430
+ self.hidden_dim = config.hidden_size
431
+
432
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
433
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
434
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
435
+
436
+ self.act_fn = ACT2FN[config.hidden_act]
437
+
438
+ def forward(self, hidden_states, routing_weights):
439
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
440
+ current_hidden_states = self.w2(current_hidden_states)
441
+ return routing_weights * current_hidden_states
442
+
443
+
444
+ MISTRAL_ATTENTION_CLASSES = {
445
+ "eager": MixtralAttention,
446
+ }
447
+
448
+
449
+ class MixtralSparseMoeBlock(nn.Module):
450
+ """
451
+ This implementation is
452
+ strictly equivalent to standard MoE with full capacity (no
453
+ dropped tokens). It's faster since it formulates MoE operations
454
+ in terms of block-sparse operations to accommodate imbalanced
455
+ assignments of tokens to experts, whereas standard MoE either
456
+ (1) drops tokens at the cost of reduced performance or (2) sets the
457
+ capacity factor to number of experts and thus waste computation
458
+ and memory on padding.
459
+ """
460
+
461
+ def __init__(self, config):
462
+ super().__init__()
463
+ self.hidden_dim = config.hidden_size
464
+ self.ffn_dim = config.intermediate_size
465
+ self.num_experts = config.num_local_experts
466
+ self.top_k = config.num_experts_per_tok
467
+
468
+ # gating
469
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
470
+
471
+ self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
472
+
473
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
474
+ """ """
475
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
476
+ hidden_states = hidden_states.view(-1, hidden_dim)
477
+ # router_logits: (batch * sequence_length, n_experts)
478
+ router_logits = self.gate(hidden_states)
479
+
480
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
481
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
482
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
483
+ # we cast back to the input dtype
484
+ routing_weights = routing_weights.to(hidden_states.dtype)
485
+
486
+ final_hidden_states = torch.zeros(
487
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
488
+ )
489
+
490
+ # One hot encode the selected experts to create an expert mask
491
+ # this will be used to easily index which expert is going to be solicited
492
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
493
+
494
+ # Loop over all available experts in the model and perform the computation on each expert
495
+ for expert_idx in range(self.num_experts):
496
+ expert_layer = self.experts[expert_idx]
497
+ idx, top_x = torch.where(expert_mask[expert_idx])
498
+
499
+ if top_x.shape[0] == 0:
500
+ continue
501
+
502
+ # in torch it is faster to index using lists than torch tensors
503
+ top_x_list = top_x.tolist()
504
+ idx_list = idx.tolist()
505
+
506
+ # Index the correct hidden states and compute the expert hidden state for
507
+ # the current expert. We need to make sure to multiply the output hidden
508
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
509
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
510
+ current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
511
+
512
+ # However `index_add_` only support torch tensors for indexing so we'll use
513
+ # the `top_x` tensor here.
514
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
515
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
516
+ return final_hidden_states, router_logits
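The routing above is a softmax over the gate logits followed by top-k selection and renormalisation so the kept weights sum to 1. The selection step in isolation, on toy shapes:

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2
router_logits = torch.randn(5, num_experts)                   # 5 tokens
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalise over the kept experts
print(selected_experts)         # [5, 2] expert ids per token
print(routing_weights.sum(-1))  # all ones
```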
517
+
518
+
519
+ class MixtralDecoderLayer(nn.Module):
520
+ def __init__(self, config: MixtralConfig, layer_idx: int):
521
+ super().__init__()
522
+ self.hidden_size = config.hidden_size
523
+
524
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
525
+
526
+ self.block_sparse_moe = MixtralSparseMoeBlock(config)
527
+ self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
528
+ self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
529
+
530
+ def forward(
531
+ self,
532
+ hidden_states: torch.Tensor,
533
+ attention_mask: Optional[torch.Tensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_value: Optional[Tuple[KVCache]] = None,
536
+ output_attentions: Optional[bool] = False,
537
+ output_router_logits: Optional[bool] = False,
538
+ use_cache: Optional[bool] = False,
539
+ **kwargs,
540
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
541
+ if "padding_mask" in kwargs:
542
+ warnings.warn(
543
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
544
+ )
545
+ """
546
+ Args:
547
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
548
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
549
+ `(batch, sequence_length)` where padding elements are indicated by 0.
550
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
553
+ returned tensors for more detail.
554
+ output_router_logits (`bool`, *optional*):
555
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
556
+ should not be returned during inference.
557
+ use_cache (`bool`, *optional*):
558
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
559
+ (see `past_key_values`).
560
+ """
561
+
562
+ residual = hidden_states
563
+
564
+ hidden_states = self.input_layernorm(hidden_states)
565
+
566
+ # Self Attention
567
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
568
+ hidden_states=hidden_states,
569
+ attention_mask=attention_mask,
570
+ position_ids=position_ids,
571
+ past_key_value=past_key_value,
572
+ output_attentions=output_attentions,
573
+ use_cache=use_cache,
574
+ )
575
+ hidden_states = residual + hidden_states
576
+
577
+ # Fully Connected
578
+ residual = hidden_states
579
+ hidden_states = self.post_attention_layernorm(hidden_states)
580
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
581
+ hidden_states = residual + hidden_states
582
+
583
+ outputs = (hidden_states,)
584
+
585
+ if output_attentions:
586
+ outputs += (self_attn_weights,)
587
+
588
+ if use_cache:
589
+ outputs += (present_key_value,)
590
+
591
+ if output_router_logits:
592
+ outputs += (router_logits,)
593
+
594
+ return outputs
595
+
596
+
597
+ MIXTRAL_START_DOCSTRING = r"""
598
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
599
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
600
+ etc.)
601
+
602
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
603
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
604
+ and behavior.
605
+
606
+ Parameters:
607
+ config ([`MixtralConfig`]):
608
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
609
+ load the weights associated with the model, only the configuration. Check out the
610
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
611
+ """
612
+
613
+
614
+ @add_start_docstrings(
615
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
616
+ MIXTRAL_START_DOCSTRING,
617
+ )
618
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
619
+ class MixtralPreTrainedModel(PreTrainedModel):
620
+ config_class = MixtralConfig
621
+ base_model_prefix = "model"
622
+ supports_gradient_checkpointing = True
623
+ _no_split_modules = ["MixtralDecoderLayer"]
624
+ _skip_keys_device_placement = "past_key_values"
625
+ _supports_flash_attn_2 = True
626
+ _supports_cache_class = True
627
+
628
+ def _init_weights(self, module):
629
+ std = self.config.initializer_range
630
+ if isinstance(module, nn.Linear):
631
+ module.weight.data.normal_(mean=0.0, std=std)
632
+ if module.bias is not None:
633
+ module.bias.data.zero_()
634
+ elif isinstance(module, nn.Embedding):
635
+ module.weight.data.normal_(mean=0.0, std=std)
636
+ if module.padding_idx is not None:
637
+ module.weight.data[module.padding_idx].zero_()
638
+
639
+
640
+ MIXTRAL_INPUTS_DOCSTRING = r"""
641
+ Args:
642
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
643
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
644
+ it.
645
+
646
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
647
+ [`PreTrainedTokenizer.__call__`] for details.
648
+
649
+ [What are input IDs?](../glossary#input-ids)
650
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
651
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
652
+
653
+ - 1 for tokens that are **not masked**,
654
+ - 0 for tokens that are **masked**.
655
+
656
+ [What are attention masks?](../glossary#attention-mask)
657
+
658
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
659
+ [`PreTrainedTokenizer.__call__`] for details.
660
+
661
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
662
+ `past_key_values`).
663
+
664
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
665
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
666
+ information on the default strategy.
667
+
668
+ - 1 indicates the head is **not masked**,
669
+ - 0 indicates the head is **masked**.
670
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
671
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
672
+ config.n_positions - 1]`.
673
+
674
+ [What are position IDs?](../glossary#position-ids)
675
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
676
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
677
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
678
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
679
+
680
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
681
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
682
+
683
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
684
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
685
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
686
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
687
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
688
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
689
+ model's internal embedding lookup matrix.
690
+ use_cache (`bool`, *optional*):
691
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
692
+ `past_key_values`).
693
+ output_attentions (`bool`, *optional*):
694
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
695
+ tensors for more detail.
696
+ output_hidden_states (`bool`, *optional*):
697
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
698
+ more detail.
699
+ output_router_logits (`bool`, *optional*):
700
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
701
+ should not be returned during inference.
702
+ return_dict (`bool`, *optional*):
703
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
704
+ """
705
+
706
+
707
+ @add_start_docstrings(
708
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
709
+ MIXTRAL_START_DOCSTRING,
710
+ )
711
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
712
+ class MixtralModel(MixtralPreTrainedModel):
713
+ """
714
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
715
+
716
+ Args:
717
+ config: MixtralConfig
718
+ """
719
+
720
+ def __init__(self, config: MixtralConfig):
721
+ super().__init__(config)
722
+ self.padding_idx = config.pad_token_id
723
+ self.vocab_size = config.vocab_size
724
+
725
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
726
+ self.layers = nn.ModuleList(
727
+ [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
728
+ )
729
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
730
+ self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
731
+
732
+ self.gradient_checkpointing = False
733
+ # Initialize weights and apply final processing
734
+ self.post_init()
735
+
736
+ def get_input_embeddings(self):
737
+ return self.embed_tokens
738
+
739
+ def set_input_embeddings(self, value):
740
+ self.embed_tokens = value
741
+
742
+ def _prepare_decoder_attention_mask(
743
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
744
+ ):
745
+ # create causal mask
746
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
747
+ combined_attention_mask = None
748
+ if input_shape[-1] > 1:
749
+ combined_attention_mask = _make_causal_mask(
750
+ input_shape,
751
+ # inputs_embeds.dtype,
752
+ torch.float32, # [MODIFIED] force to cast to float32
753
+ device=inputs_embeds.device,
754
+ past_key_values_length=past_key_values_length,
755
+ )
756
+
757
+ if attention_mask is not None:
758
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
759
+ expanded_attn_mask = _expand_mask(
760
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
761
+ ).to(inputs_embeds.device)
762
+ combined_attention_mask = (
763
+ expanded_attn_mask
764
+ if combined_attention_mask is None
765
+ else expanded_attn_mask + combined_attention_mask
766
+ )
767
+
768
+
769
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
770
+ tree_mask = self.tree_mask
771
+ tree_len = tree_mask.size(-1)
772
+ combined_attention_mask[:, :, -tree_len:, -tree_len:][
773
+ tree_mask == 0
774
+ ] = combined_attention_mask.min()
775
+
776
+ return combined_attention_mask
777
+
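The `tree_mask` branch above is the EAGLE-specific addition: within the last `tree_len` positions (the speculative draft tokens), attention between tokens that do not lie on the same branch of the draft tree is pushed to the mask's minimum value. A minimal standalone sketch of that masking step, with a made-up 3-node tree (shapes and values are illustrative only, not taken from the repository):

import torch

seq_len, tree_len = 6, 3
neg = torch.finfo(torch.float32).min
# standard causal mask: 0 where attention is allowed, a large negative value above the diagonal
combined = torch.triu(torch.full((seq_len, seq_len), neg), diagonal=1)[None, None].clone()

# ancestor table for a 3-node draft tree: node 0 is the shared root, nodes 1 and 2 are
# alternative continuations, so each may attend to the root and itself but not to its sibling
tree_mask = torch.tensor([[1, 0, 0],
                          [1, 1, 0],
                          [1, 0, 1]]).bool()

block = combined[:, :, -tree_len:, -tree_len:]   # view over the speculative tail positions
block[:, :, ~tree_mask] = neg                    # disable cross-branch attention in place
print(combined[0, 0, -tree_len:, -tree_len:])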
778
+ # Ignore copy
779
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
780
+ def forward(
781
+ self,
782
+ input_ids: torch.LongTensor = None,
783
+ attention_mask: Optional[torch.Tensor] = None,
784
+ position_ids: Optional[torch.LongTensor] = None,
785
+ past_key_values: Optional[List[Tuple[KVCache]]] = None,
786
+ inputs_embeds: Optional[torch.FloatTensor] = None,
787
+ use_cache: Optional[bool] = None,
788
+ output_attentions: Optional[bool] = None,
789
+ output_hidden_states: Optional[bool] = None,
790
+ output_router_logits: Optional[bool] = None,
791
+ return_dict: Optional[bool] = None,
792
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
793
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
794
+ output_router_logits = (
795
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
796
+ )
797
+ output_hidden_states = (
798
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
799
+ )
800
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
801
+
802
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
803
+
804
+ # retrieve input_ids and inputs_embeds
805
+ if input_ids is not None and inputs_embeds is not None:
806
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
807
+ elif input_ids is not None:
808
+ batch_size, seq_length = input_ids.shape
809
+ elif inputs_embeds is not None:
810
+ batch_size, seq_length, _ = inputs_embeds.shape
811
+ else:
812
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
813
+
814
+ past_key_values_length = 0
815
+
816
+ if self.gradient_checkpointing and self.training:
817
+ if use_cache:
818
+ logger.warning_once(
819
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
820
+ )
821
+ use_cache = False
822
+
823
+ if past_key_values is not None:
824
+ past_key_values_length = past_key_values[0][0].shape[2]
825
+
826
+ if position_ids is None:
827
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
828
+ position_ids = torch.arange(
829
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
830
+ )
831
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
832
+ else:
833
+ position_ids = position_ids.view(-1, seq_length).long()
834
+
835
+ if inputs_embeds is None:
836
+ inputs_embeds = self.embed_tokens(input_ids)
837
+
838
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
839
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
840
+ if is_padding_right:
841
+ raise ValueError(
842
+ "You are attempting to perform batched generation with padding_side='right'"
843
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
844
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
845
+ )
846
+
847
+ # if self._use_flash_attention_2:
848
+ # # 2d mask is passed through the layers
849
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
850
+ # else:
851
+ # 4d mask is passed through the layers
852
+ attention_mask = self._prepare_decoder_attention_mask(
853
+ attention_mask,
854
+ (batch_size, seq_length),
855
+ inputs_embeds,
856
+ past_key_values_length,
857
+ )
858
+
859
+ hidden_states = inputs_embeds
860
+
861
+ # decoder layers
862
+ all_hidden_states = () if output_hidden_states else None
863
+ all_self_attns = () if output_attentions else None
864
+ all_router_logits = () if output_router_logits else None
865
+ next_decoder_cache = None
866
+
867
+ for idx, decoder_layer in enumerate(self.layers):
868
+ if output_hidden_states:
869
+ all_hidden_states += (hidden_states,)
870
+
871
+ past_key_value = (
872
+ past_key_values[idx] if past_key_values is not None else None
873
+ )
874
+
875
+ if self.gradient_checkpointing and self.training:
876
+ layer_outputs = self._gradient_checkpointing_func(
877
+ decoder_layer.__call__,
878
+ hidden_states,
879
+ attention_mask,
880
+ position_ids,
881
+ past_key_value,
882
+ output_attentions,
883
+ output_router_logits,
884
+ use_cache,
885
+ )
886
+ else:
887
+ layer_outputs = decoder_layer(
888
+ hidden_states,
889
+ attention_mask=attention_mask,
890
+ position_ids=position_ids,
891
+ past_key_value=past_key_value,
892
+ output_attentions=output_attentions,
893
+ output_router_logits=output_router_logits,
894
+ use_cache=use_cache,
895
+ )
896
+
897
+ hidden_states = layer_outputs[0]
898
+
899
+ if use_cache:
900
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
901
+
902
+ if output_attentions:
903
+ all_self_attns += (layer_outputs[1],)
904
+
905
+ if output_router_logits:
906
+ all_router_logits += (layer_outputs[-1],)
907
+
908
+ hidden_states = self.norm(hidden_states)
909
+
910
+ # add hidden states from the last decoder layer
911
+ if output_hidden_states:
912
+ all_hidden_states += (hidden_states,)
913
+
914
+
915
+ next_cache = next_decoder_cache if use_cache else None
916
+ # if use_cache:
917
+ # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
918
+
919
+ if not return_dict:
920
+ return tuple(
921
+ v
922
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
923
+ if v is not None
924
+ )
925
+ return MoeModelOutputWithPast(
926
+ last_hidden_state=hidden_states,
927
+ past_key_values=next_cache,
928
+ hidden_states=all_hidden_states,
929
+ attentions=all_self_attns,
930
+ router_logits=all_router_logits,
931
+ )
932
+
933
+
934
+ class MixtralForCausalLM(MixtralPreTrainedModel):
935
+ _tied_weights_keys = ["lm_head.weight"]
936
+
937
+ def __init__(self, config):
938
+ super().__init__(config)
939
+ self.model = MixtralModel(config)
940
+ self.vocab_size = config.vocab_size
941
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
942
+ self.router_aux_loss_coef = config.router_aux_loss_coef
943
+ self.num_experts = config.num_local_experts
944
+ self.num_experts_per_tok = config.num_experts_per_tok
945
+ # Initialize weights and apply final processing
946
+ self.post_init()
947
+
948
+ def get_input_embeddings(self):
949
+ return self.model.embed_tokens
950
+
951
+ def set_input_embeddings(self, value):
952
+ self.model.embed_tokens = value
953
+
954
+ def get_output_embeddings(self):
955
+ return self.lm_head
956
+
957
+ def set_output_embeddings(self, new_embeddings):
958
+ self.lm_head = new_embeddings
959
+
960
+ def set_decoder(self, decoder):
961
+ self.model = decoder
962
+
963
+ def get_decoder(self):
964
+ return self.model
965
+
966
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
967
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
968
+ # Ignore copy
969
+ def forward(
970
+ self,
971
+ input_ids: torch.LongTensor = None,
972
+ attention_mask: Optional[torch.Tensor] = None,
973
+ position_ids: Optional[torch.LongTensor] = None,
974
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
975
+ inputs_embeds: Optional[torch.FloatTensor] = None,
976
+ labels: Optional[torch.LongTensor] = None,
977
+ use_cache: Optional[bool] = None,
978
+ output_attentions: Optional[bool] = None,
979
+ output_hidden_states: Optional[bool] = None,
980
+ output_router_logits: Optional[bool] = None,
981
+ return_dict: Optional[bool] = None,
982
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
983
+ r"""
984
+ Args:
985
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
986
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
987
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
988
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
989
+
990
+ Returns:
991
+
992
+ Example:
993
+
994
+ ```python
995
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
996
+
997
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
998
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
999
+
1000
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1001
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1002
+
1003
+ >>> # Generate
1004
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1005
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1006
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1007
+ ```"""
1008
+
1009
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1010
+ output_router_logits = (
1011
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1012
+ )
1013
+
1014
+ output_hidden_states = (
1015
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1016
+ )
1017
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1018
+
1019
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1020
+ outputs = self.model(
1021
+ input_ids=input_ids,
1022
+ attention_mask=attention_mask,
1023
+ position_ids=position_ids,
1024
+ past_key_values=past_key_values,
1025
+ inputs_embeds=inputs_embeds,
1026
+ use_cache=use_cache,
1027
+ output_attentions=output_attentions,
1028
+ output_hidden_states=output_hidden_states,
1029
+ output_router_logits=output_router_logits,
1030
+ return_dict=return_dict,
1031
+ )
1032
+
1033
+ hidden_states = outputs[0]
1034
+ logits = self.lm_head(hidden_states)
1035
+ logits = logits.float()
1036
+
1037
+ loss = None
1038
+ if labels is not None:
1039
+ # Shift so that tokens < n predict n
1040
+ shift_logits = logits[..., :-1, :].contiguous()
1041
+ shift_labels = labels[..., 1:].contiguous()
1042
+ # Flatten the tokens
1043
+ loss_fct = CrossEntropyLoss()
1044
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1045
+ shift_labels = shift_labels.view(-1)
1046
+ # Enable model parallelism
1047
+ shift_labels = shift_labels.to(shift_logits.device)
1048
+ loss = loss_fct(shift_logits, shift_labels)
1049
+
1050
+ aux_loss = None
1051
+ if output_router_logits:
1052
+ aux_loss = load_balancing_loss_func(
1053
+ outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
1054
+ )
1055
+ if labels is not None:
1056
+ loss += self.router_aux_loss_coef * aux_loss
1057
+
1058
+ if not return_dict:
1059
+ output = (logits,) + outputs[1:]
1060
+ if output_router_logits:
1061
+ output = (aux_loss,) + output
1062
+ return (loss,) + output if loss is not None else output
1063
+
1064
+ return MoeCausalLMOutputWithPast(
1065
+ loss=loss,
1066
+ aux_loss=aux_loss,
1067
+ logits=logits,
1068
+ past_key_values=outputs.past_key_values,
1069
+ hidden_states=outputs.hidden_states,
1070
+ attentions=outputs.attentions,
1071
+ router_logits=outputs.router_logits,
1072
+ )
1073
+
1074
+
1075
+
1076
+
1077
+
1078
+
1079
+ @add_start_docstrings(
1080
+ """
1081
+ The Mixtral Model transformer with a sequence classification head on top (linear layer).
1082
+
1083
+ [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1084
+ (e.g. GPT-2) do.
1085
+
1086
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1087
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1088
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1089
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1090
+ each row of the batch).
1091
+ """,
1092
+ MIXTRAL_START_DOCSTRING,
1093
+ )
1094
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
1095
+ class MixtralForSequenceClassification(MixtralPreTrainedModel):
1096
+ def __init__(self, config):
1097
+ super().__init__(config)
1098
+ self.num_labels = config.num_labels
1099
+ self.model = MixtralModel(config)
1100
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1101
+
1102
+ # Initialize weights and apply final processing
1103
+ self.post_init()
1104
+
1105
+ def get_input_embeddings(self):
1106
+ return self.model.embed_tokens
1107
+
1108
+ def set_input_embeddings(self, value):
1109
+ self.model.embed_tokens = value
1110
+
1111
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1112
+ def forward(
1113
+ self,
1114
+ input_ids: torch.LongTensor = None,
1115
+ attention_mask: Optional[torch.Tensor] = None,
1116
+ position_ids: Optional[torch.LongTensor] = None,
1117
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1118
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1119
+ labels: Optional[torch.LongTensor] = None,
1120
+ use_cache: Optional[bool] = None,
1121
+ output_attentions: Optional[bool] = None,
1122
+ output_hidden_states: Optional[bool] = None,
1123
+ return_dict: Optional[bool] = None,
1124
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1125
+ r"""
1126
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1127
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1128
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1129
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1130
+ """
1131
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1132
+
1133
+ transformer_outputs = self.model(
1134
+ input_ids,
1135
+ attention_mask=attention_mask,
1136
+ position_ids=position_ids,
1137
+ past_key_values=past_key_values,
1138
+ inputs_embeds=inputs_embeds,
1139
+ use_cache=use_cache,
1140
+ output_attentions=output_attentions,
1141
+ output_hidden_states=output_hidden_states,
1142
+ return_dict=return_dict,
1143
+ )
1144
+ hidden_states = transformer_outputs[0]
1145
+ logits = self.score(hidden_states)
1146
+
1147
+ if input_ids is not None:
1148
+ batch_size = input_ids.shape[0]
1149
+ else:
1150
+ batch_size = inputs_embeds.shape[0]
1151
+
1152
+ if self.config.pad_token_id is None and batch_size != 1:
1153
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1154
+ if self.config.pad_token_id is None:
1155
+ sequence_lengths = -1
1156
+ else:
1157
+ if input_ids is not None:
1158
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1159
+ logits.device
1160
+ )
1161
+ else:
1162
+ sequence_lengths = -1
1163
+
1164
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1165
+
1166
+ loss = None
1167
+ if labels is not None:
1168
+ labels = labels.to(logits.device)
1169
+ if self.config.problem_type is None:
1170
+ if self.num_labels == 1:
1171
+ self.config.problem_type = "regression"
1172
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1173
+ self.config.problem_type = "single_label_classification"
1174
+ else:
1175
+ self.config.problem_type = "multi_label_classification"
1176
+
1177
+ if self.config.problem_type == "regression":
1178
+ loss_fct = MSELoss()
1179
+ if self.num_labels == 1:
1180
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1181
+ else:
1182
+ loss = loss_fct(pooled_logits, labels)
1183
+ elif self.config.problem_type == "single_label_classification":
1184
+ loss_fct = CrossEntropyLoss()
1185
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1186
+ elif self.config.problem_type == "multi_label_classification":
1187
+ loss_fct = BCEWithLogitsLoss()
1188
+ loss = loss_fct(pooled_logits, labels)
1189
+ if not return_dict:
1190
+ output = (pooled_logits,) + transformer_outputs[1:]
1191
+ return ((loss,) + output) if loss is not None else output
1192
+
1193
+ return SequenceClassifierOutputWithPast(
1194
+ loss=loss,
1195
+ logits=pooled_logits,
1196
+ past_key_values=transformer_outputs.past_key_values,
1197
+ hidden_states=transformer_outputs.hidden_states,
1198
+ attentions=transformer_outputs.attentions,
1199
+ )
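For reference, the last-non-padding-token pooling used by `MixtralForSequenceClassification` above can be reproduced in isolation. This is a small illustrative sketch with made-up token ids and an assumed `pad_token_id` of 0, not code from the repository:

import torch

pad_id = 0
input_ids = torch.tensor([[11, 12, 13, pad_id, pad_id],
                          [21, 22, 23, 24, 25]])
# first pad position minus one == last real token; a row with no padding gives
# argmax == 0, so the -1 wraps around to the final column
sequence_lengths = torch.eq(input_ids, pad_id).int().argmax(-1) - 1   # tensor([2, -1])

logits = torch.randn(2, 5, 3)                          # [batch, seq_len, num_labels]
pooled = logits[torch.arange(2), sequence_lengths]     # [batch, num_labels]
print(pooled.shape)                                    # torch.Size([2, 3])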
eagle/model/modeling_qwen2_kv.py ADDED
@@ -0,0 +1,1513 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2 model."""
21
+
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ is_flash_attn_2_available,
46
+ is_flash_attn_greater_or_equal_2_10,
47
+ is_torchdynamo_compiling,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
52
+
53
+ if is_flash_attn_2_available():
54
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
+ _CONFIG_FOR_DOC = "Qwen2Config"
60
+
61
+
62
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
63
+ class Qwen2RMSNorm(nn.Module):
64
+ def __init__(self, hidden_size, eps=1e-6):
65
+ """
66
+ Qwen2RMSNorm is equivalent to T5LayerNorm
67
+ """
68
+ super().__init__()
69
+ self.weight = nn.Parameter(torch.ones(hidden_size))
70
+ self.variance_epsilon = eps
71
+
72
+ def forward(self, hidden_states):
73
+ input_dtype = hidden_states.dtype
74
+ hidden_states = hidden_states.to(torch.float32)
75
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
76
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
77
+ return self.weight * hidden_states.to(input_dtype)
78
+
79
+ def extra_repr(self):
80
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
81
+
82
+
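A quick numerical restatement of what `Qwen2RMSNorm` computes (normalization in float32, a learned per-channel scale, then a cast back to the input dtype); tensor shapes here are arbitrary:

import torch

x = torch.randn(2, 4, 8, dtype=torch.float16)
weight, eps = torch.ones(8, dtype=torch.float16), 1e-6

x32 = x.float()
y = weight * (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
print(y.shape, y.dtype)   # torch.Size([2, 4, 8]) torch.float16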
83
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
84
+ class Qwen2RotaryEmbedding(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim=None,
88
+ max_position_embeddings=2048,
89
+ base=10000,
90
+ device=None,
91
+ scaling_factor=1.0,
92
+ rope_type="default",
93
+ config: Optional[Qwen2Config] = None,
94
+ ):
95
+ super().__init__()
96
+ # TODO (joao): remove the `if` below, only used for BC
97
+ self.rope_kwargs = {}
98
+ if config is None:
99
+ logger.warning_once(
100
+ "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
101
+ "`config` argument. All other arguments will be removed in v4.46"
102
+ )
103
+ self.rope_kwargs = {
104
+ "rope_type": rope_type,
105
+ "factor": scaling_factor,
106
+ "dim": dim,
107
+ "base": base,
108
+ "max_position_embeddings": max_position_embeddings,
109
+ }
110
+ self.rope_type = rope_type
111
+ self.max_seq_len_cached = max_position_embeddings
112
+ self.original_max_seq_len = max_position_embeddings
113
+ else:
114
+ # BC: "rope_type" was originally "type"
115
+ if config.rope_scaling is not None:
116
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
117
+ else:
118
+ self.rope_type = "default"
119
+ self.max_seq_len_cached = config.max_position_embeddings
120
+ self.original_max_seq_len = config.max_position_embeddings
121
+
122
+ self.config = config
123
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
124
+
125
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
126
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
127
+ self.original_inv_freq = self.inv_freq
128
+
129
+ def _dynamic_frequency_update(self, position_ids, device):
130
+ """
131
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
132
+ 1 - growing beyond the cached sequence length (allow scaling)
133
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
134
+ """
135
+ seq_len = torch.max(position_ids) + 1
136
+ if seq_len > self.max_seq_len_cached: # growth
137
+ inv_freq, self.attention_scaling = self.rope_init_fn(
138
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
139
+ )
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
141
+ self.max_seq_len_cached = seq_len
142
+
143
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
144
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
145
+ self.max_seq_len_cached = self.original_max_seq_len
146
+
147
+ @torch.no_grad()
148
+ def forward(self, x, position_ids):
149
+ if "dynamic" in self.rope_type:
150
+ self._dynamic_frequency_update(position_ids, device=x.device)
151
+
152
+ # Core RoPE block
153
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
154
+ position_ids_expanded = position_ids[:, None, :].float()
155
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
156
+ device_type = x.device.type
157
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
158
+ with torch.autocast(device_type=device_type, enabled=False):
159
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ cos = emb.cos()
162
+ sin = emb.sin()
163
+
164
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
165
+ cos = cos * self.attention_scaling
166
+ sin = sin * self.attention_scaling
167
+
168
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
169
+
170
+
171
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
172
+ def rotate_half(x):
173
+ """Rotates half the hidden dims of the input."""
174
+ x1 = x[..., : x.shape[-1] // 2]
175
+ x2 = x[..., x.shape[-1] // 2:]
176
+ return torch.cat((-x2, x1), dim=-1)
177
+
178
+
179
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
180
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
181
+ """Applies Rotary Position Embedding to the query and key tensors.
182
+
183
+ Args:
184
+ q (`torch.Tensor`): The query tensor.
185
+ k (`torch.Tensor`): The key tensor.
186
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
187
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
188
+ position_ids (`torch.Tensor`, *optional*):
189
+ Deprecated and unused.
190
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
191
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
192
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
193
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
194
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
195
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
196
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
197
+ Returns:
198
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
199
+ """
200
+ cos = cos.unsqueeze(unsqueeze_dim)
201
+ sin = sin.unsqueeze(unsqueeze_dim)
202
+ q_embed = (q * cos) + (rotate_half(q) * sin)
203
+ k_embed = (k * cos) + (rotate_half(k) * sin)
204
+ return q_embed, k_embed
205
+
206
+
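A self-contained sketch of the rotary-embedding path: build `inv_freq`, `cos`, and `sin` the way `Qwen2RotaryEmbedding` does for the default rope type (base 10000), apply them as in `apply_rotary_pos_emb`, and check that the rotation preserves per-position vector norms. The shapes are illustrative only:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(seq).float()
freqs = torch.outer(positions, inv_freq)               # [seq, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                # [seq, head_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q_rot = q * cos + rotate_half(q) * sin                 # same as apply_rotary_pos_emb
k_rot = k * cos + rotate_half(k) * sin
# rotating each (x1_j, x2_j) pair by its angle leaves per-position norms unchanged
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True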
207
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
208
+ class Qwen2MLP(nn.Module):
209
+ def __init__(self, config):
210
+ super().__init__()
211
+ self.hidden_size = config.hidden_size
212
+ self.intermediate_size = config.intermediate_size
213
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
214
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
215
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
216
+ self.act_fn = ACT2FN[config.hidden_act]
217
+
218
+ def forward(self, hidden_state):
219
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
220
+
221
+
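`Qwen2MLP` is a gated (SwiGLU-style) feed-forward block: `down(silu(gate(x)) * up(x))`. A standalone sketch with arbitrary sizes:

import torch
from torch import nn

hidden_size, intermediate_size = 16, 44
gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)
down = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)
y = down(nn.functional.silu(gate(x)) * up(x))   # silu(W_g x) * (W_u x), then W_d
print(y.shape)   # torch.Size([2, 5, 16])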
222
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
223
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
224
+ """
225
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
226
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
227
+ """
228
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
229
+ if n_rep == 1:
230
+ return hidden_states
231
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
232
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
233
+
234
+
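`repeat_kv` is the grouped-query-attention expansion: each key/value head is shared by several query heads. A tiny shape check with made-up numbers, mirroring the expand-then-reshape above:

import torch

# 2 key/value heads shared by 8 query heads -> each KV head is repeated n_rep = 4 times
batch, num_key_value_heads, slen, head_dim, n_rep = 1, 2, 5, 64, 4
kv = torch.randn(batch, num_key_value_heads, slen, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
out = expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

print(out.shape)                          # torch.Size([1, 8, 5, 64])
print(torch.equal(out[0, 0], out[0, 3]))  # True: query heads 0-3 all read KV head 0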
235
+ class Qwen2Attention(nn.Module):
236
+ """
237
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
238
+ and "Generating Long Sequences with Sparse Transformers".
239
+ """
240
+
241
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
242
+ super().__init__()
243
+ self.config = config
244
+ self.layer_idx = layer_idx
245
+ if layer_idx is None:
246
+ logger.warning_once(
247
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
248
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
249
+ "when creating this class."
250
+ )
251
+
252
+ self.hidden_size = config.hidden_size
253
+ self.num_heads = config.num_attention_heads
254
+ self.head_dim = self.hidden_size // self.num_heads
255
+ self.num_key_value_heads = config.num_key_value_heads
256
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
257
+ self.max_position_embeddings = config.max_position_embeddings
258
+ self.rope_theta = config.rope_theta
259
+ self.is_causal = True
260
+ self.attention_dropout = config.attention_dropout
261
+
262
+ if (self.head_dim * self.num_heads) != self.hidden_size:
263
+ raise ValueError(
264
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
265
+ f" and `num_heads`: {self.num_heads})."
266
+ )
267
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
268
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
269
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
270
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
271
+
272
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
273
+
274
+ def forward(
275
+ self,
276
+ hidden_states: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ position_ids: Optional[torch.LongTensor] = None,
279
+ past_key_value: Optional[Cache] = None,
280
+ output_attentions: bool = False,
281
+ use_cache: bool = False,
282
+ cache_position: Optional[torch.LongTensor] = None,
283
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
284
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
285
+ bsz, q_len, _ = hidden_states.size()
286
+
287
+ query_states = self.q_proj(hidden_states)
288
+ key_states = self.k_proj(hidden_states)
289
+ value_states = self.v_proj(hidden_states)
290
+
291
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
292
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
293
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
294
+
295
+ if position_embeddings is None:
296
+ logger.warning_once(
297
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
298
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
299
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
300
+ "removed and `position_embeddings` will be mandatory."
301
+ )
302
+ cos, sin = self.rotary_emb(value_states, position_ids)
303
+ else:
304
+ cos, sin = position_embeddings
305
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
306
+
307
+ if past_key_value is not None:
308
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
309
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
310
+ # key_states, value_states = past_key_value.cat(key_states, value_states, self.layer_idx)
311
+ past_key, past_value = past_key_value[self.layer_idx]
312
+ key_states = past_key.cat(key_states)
313
+ value_states = past_value.cat(value_states)
314
+ past_key_value = None
315
+
316
+ # repeat k/v heads if n_kv_heads < n_heads
317
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
318
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
319
+
320
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
321
+ if attention_mask is not None: # no matter the length, we just slice it
322
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
323
+ attn_weights = attn_weights + causal_mask
324
+
325
+ # upcast attention to fp32
326
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
327
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
328
+ attn_output = torch.matmul(attn_weights, value_states)
329
+
330
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
331
+ raise ValueError(
332
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
333
+ f" {attn_output.size()}"
334
+ )
335
+
336
+ attn_output = attn_output.transpose(1, 2).contiguous()
337
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
338
+
339
+ attn_output = self.o_proj(attn_output)
340
+
341
+ if not output_attentions:
342
+ attn_weights = None
343
+
344
+ return attn_output, attn_weights, past_key_value
345
+
346
+
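The eager forward above boils down to standard scaled dot-product attention: scores = QKᵀ/√d, add the additive mask, softmax in float32, then multiply by V. A compact standalone sketch with toy shapes (no KV cache, no GQA):

import math
import torch

bsz, heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# additive causal mask: 0 where attention is allowed, a large negative value elsewhere
causal = torch.triu(torch.full((q_len, q_len), torch.finfo(torch.float32).min), diagonal=1)

scores = q @ k.transpose(2, 3) / math.sqrt(head_dim) + causal
weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
out = (weights @ v).transpose(1, 2).reshape(bsz, q_len, heads * head_dim)
print(out.shape)   # torch.Size([1, 4, 16])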
347
+ class Qwen2FlashAttention2(Qwen2Attention):
348
+ """
349
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
350
+ as the weights of the module stay untouched. The only required change would be on the forward pass
351
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
352
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
353
+ config.max_window_layers layers.
354
+ """
355
+
356
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
357
+ def __init__(self, *args, **kwargs):
358
+ super().__init__(*args, **kwargs)
359
+
360
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
361
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
362
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
363
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
364
+
365
+ def forward(
366
+ self,
367
+ hidden_states: torch.Tensor,
368
+ attention_mask: Optional[torch.Tensor] = None,
369
+ position_ids: Optional[torch.LongTensor] = None,
370
+ past_key_value: Optional[Cache] = None,
371
+ output_attentions: bool = False,
372
+ use_cache: bool = False,
373
+ cache_position: Optional[torch.LongTensor] = None,
374
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
375
+ ):
376
+ bsz, q_len, _ = hidden_states.size()
377
+
378
+ query_states = self.q_proj(hidden_states)
379
+ key_states = self.k_proj(hidden_states)
380
+ value_states = self.v_proj(hidden_states)
381
+
382
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
383
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
384
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
385
+
386
+ if position_embeddings is None:
387
+ logger.warning_once(
388
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
389
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
390
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
391
+ "removed and `position_embeddings` will be mandatory."
392
+ )
393
+ cos, sin = self.rotary_emb(value_states, position_ids)
394
+ else:
395
+ cos, sin = position_embeddings
396
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
397
+
398
+ if past_key_value is not None:
399
+ # Activate cache slicing only if the config defines a `sliding_window` attribute
400
+ cache_has_contents = past_key_value[self.layer_idx][0].current_length.item() > 0  # custom KVCache list: read the cached key length directly
401
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
402
+ if (
403
+ getattr(self.config, "sliding_window", None) is not None
404
+ and kv_seq_len > self.config.sliding_window
405
+ and cache_has_contents
406
+ ):
407
+ slicing_tokens = 1 - self.config.sliding_window
408
+
409
+ past_key = past_key_value[self.layer_idx][0]
410
+ past_value = past_key_value[self.layer_idx][1]
411
+
412
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
413
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
414
+
415
+ if past_key.shape[-2] != self.config.sliding_window - 1:
416
+ raise ValueError(
417
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
418
+ f" {past_key.shape}"
419
+ )
420
+
421
+ if attention_mask is not None:
422
+ attention_mask = attention_mask[:, slicing_tokens:]
423
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
424
+
425
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
426
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
427
+
428
+ # repeat k/v heads if n_kv_heads < n_heads
429
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
430
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
431
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
432
+
433
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
434
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
435
+ # cast them back to float16 just to be sure everything works as expected.
436
+ input_dtype = query_states.dtype
437
+ if input_dtype == torch.float32:
438
+ if torch.is_autocast_enabled():
439
+ target_dtype = torch.get_autocast_gpu_dtype()
440
+ # Handle the case where the model is quantized
441
+ elif hasattr(self.config, "_pre_quantization_dtype"):
442
+ target_dtype = self.config._pre_quantization_dtype
443
+ else:
444
+ target_dtype = self.q_proj.weight.dtype
445
+
446
+ logger.warning_once(
447
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
448
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
449
+ f" {target_dtype}."
450
+ )
451
+
452
+ query_states = query_states.to(target_dtype)
453
+ key_states = key_states.to(target_dtype)
454
+ value_states = value_states.to(target_dtype)
455
+
456
+ # Reshape to the expected shape for Flash Attention
457
+ query_states = query_states.transpose(1, 2)
458
+ key_states = key_states.transpose(1, 2)
459
+ value_states = value_states.transpose(1, 2)
460
+
461
+ if (
462
+ self.config.use_sliding_window
463
+ and getattr(self.config, "sliding_window", None) is not None
464
+ and self.layer_idx >= self.config.max_window_layers
465
+ ):
466
+ sliding_window = self.config.sliding_window
467
+ else:
468
+ sliding_window = None
469
+
470
+ attn_output = _flash_attention_forward(
471
+ query_states,
472
+ key_states,
473
+ value_states,
474
+ attention_mask,
475
+ q_len,
476
+ position_ids=position_ids,
477
+ dropout=dropout_rate,
478
+ sliding_window=sliding_window,
479
+ is_causal=self.is_causal,
480
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
481
+ )
482
+
483
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
484
+ attn_output = self.o_proj(attn_output)
485
+
486
+ if not output_attentions:
487
+ attn_weights = None
488
+
489
+ return attn_output, attn_weights, past_key_value
490
+
491
+
492
+ class Qwen2SdpaAttention(Qwen2Attention):
493
+ """
494
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
495
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
496
+ SDPA API.
497
+ """
498
+
499
+ # Adapted from Qwen2Attention.forward
500
+ def forward(
501
+ self,
502
+ hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_value: Optional[Cache] = None,
506
+ output_attentions: bool = False,
507
+ use_cache: bool = False,
508
+ cache_position: Optional[torch.LongTensor] = None,
509
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
510
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
511
+ if output_attentions:
512
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
513
+ logger.warning_once(
514
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
515
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
516
+ )
517
+ return super().forward(
518
+ hidden_states=hidden_states,
519
+ attention_mask=attention_mask,
520
+ position_ids=position_ids,
521
+ past_key_value=past_key_value,
522
+ output_attentions=output_attentions,
523
+ use_cache=use_cache,
524
+ )
525
+
526
+ bsz, q_len, _ = hidden_states.size()
527
+
528
+ query_states = self.q_proj(hidden_states)
529
+ key_states = self.k_proj(hidden_states)
530
+ value_states = self.v_proj(hidden_states)
531
+
532
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
533
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
534
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
535
+
536
+ if position_embeddings is None:
537
+ logger.warning_once(
538
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
539
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
540
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
541
+ "removed and `position_embeddings` will be mandatory."
542
+ )
543
+ cos, sin = self.rotary_emb(value_states, position_ids)
544
+ else:
545
+ cos, sin = position_embeddings
546
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
547
+
548
+ if past_key_value is not None:
549
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
550
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
551
+ # key_states, value_states = past_key_value.cat(key_states, value_states, self.layer_idx)
552
+ past_key, past_value = past_key_value[self.layer_idx]
553
+ key_states = past_key.cat(key_states)
554
+ value_states = past_value.cat(value_states)
555
+ past_key_value = None
556
+
557
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
558
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
559
+
560
+ causal_mask = attention_mask
561
+ if attention_mask is not None: # no matter the length, we just slice it
562
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
563
+
564
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
565
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
566
+ if query_states.device.type == "cuda" and attention_mask is not None:
567
+ query_states = query_states.contiguous()
568
+ key_states = key_states.contiguous()
569
+ value_states = value_states.contiguous()
570
+
571
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
572
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
573
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
574
+ is_causal = True if causal_mask is None and q_len > 1 else False
575
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
576
+ query_states,
577
+ key_states,
578
+ value_states,
579
+ attn_mask=causal_mask,
580
+ dropout_p=self.attention_dropout if self.training else 0.0,
581
+ is_causal=is_causal,
582
+ )
583
+
584
+ attn_output = attn_output.transpose(1, 2).contiguous()
585
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
586
+
587
+ attn_output = self.o_proj(attn_output)
588
+
589
+ return attn_output, None, past_key_value
590
+
591
+
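For the SDPA variant, `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True` matches the manual computation used by the eager path. A small equivalence check with toy tensors (illustrative values only):

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
manual = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(8) + causal, dim=-1) @ v
print(torch.allclose(sdpa, manual, atol=1e-5))   # True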
592
+ QWEN2_ATTENTION_CLASSES = {
593
+ "eager": Qwen2Attention,
594
+ "flash_attention_2": Qwen2FlashAttention2,
595
+ "sdpa": Qwen2SdpaAttention,
596
+ }
597
+
598
+
599
+ class Qwen2DecoderLayer(nn.Module):
600
+ def __init__(self, config: Qwen2Config, layer_idx: int):
601
+ super().__init__()
602
+ self.hidden_size = config.hidden_size
603
+
604
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
605
+ logger.warning_once(
606
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
607
+ "unexpected results may be encountered."
608
+ )
609
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
610
+ # self.self_attn = QWEN2_ATTENTION_CLASSES["eager"](config, layer_idx)
611
+
612
+ self.mlp = Qwen2MLP(config)
613
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
614
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
615
+
616
+ def forward(
617
+ self,
618
+ hidden_states: torch.Tensor,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ position_ids: Optional[torch.LongTensor] = None,
621
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
622
+ output_attentions: Optional[bool] = False,
623
+ use_cache: Optional[bool] = False,
624
+ cache_position: Optional[torch.LongTensor] = None,
625
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
626
+ **kwargs,
627
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
628
+ """
629
+ Args:
630
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
631
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
632
+ `(batch, sequence_length)` where padding elements are indicated by 0.
633
+ output_attentions (`bool`, *optional*):
634
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
635
+ returned tensors for more detail.
636
+ use_cache (`bool`, *optional*):
637
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
638
+ (see `past_key_values`).
639
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
640
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
641
+ Indices depicting the position of the input sequence tokens in the sequence.
642
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
643
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
644
+ with `head_dim` being the embedding dimension of each attention head.
645
+ kwargs (`dict`, *optional*):
646
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
647
+ into the model
648
+ """
649
+
650
+ residual = hidden_states
651
+
652
+ hidden_states = self.input_layernorm(hidden_states)
653
+
654
+ # Self Attention
655
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
656
+ hidden_states=hidden_states,
657
+ attention_mask=attention_mask,
658
+ position_ids=position_ids,
659
+ past_key_value=past_key_value,
660
+ output_attentions=output_attentions,
661
+ use_cache=use_cache,
662
+ cache_position=cache_position,
663
+ position_embeddings=position_embeddings,
664
+ )
665
+ hidden_states = residual + hidden_states
666
+
667
+ # Fully Connected
668
+ residual = hidden_states
669
+ hidden_states = self.post_attention_layernorm(hidden_states)
670
+ hidden_states = self.mlp(hidden_states)
671
+ hidden_states = residual + hidden_states
672
+
673
+ outputs = (hidden_states,)
674
+
675
+ if output_attentions:
676
+ outputs += (self_attn_weights,)
677
+
678
+ if use_cache:
679
+ outputs += (present_key_value,)
680
+
681
+ return outputs
682
+
683
+
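The decoder layer above follows the usual pre-norm residual pattern: `x = x + attn(norm1(x))`, then `x = x + mlp(norm2(x))`. A minimal sketch with stand-in PyTorch modules (LayerNorm and nn.MultiheadAttention substitute here for the RMSNorm and Qwen2Attention actually used; sizes are arbitrary):

import torch
from torch import nn

hidden = torch.randn(1, 5, 16)
norm1, norm2 = nn.LayerNorm(16), nn.LayerNorm(16)
attn = nn.MultiheadAttention(16, 4, batch_first=True)
mlp = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))

residual = hidden
h = norm1(hidden)
h, _ = attn(h, h, h, need_weights=False)   # self-attention sub-block
hidden = residual + h

residual = hidden
hidden = residual + mlp(norm2(hidden))     # feed-forward sub-block
print(hidden.shape)   # torch.Size([1, 5, 16])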
684
+ QWEN2_START_DOCSTRING = r"""
685
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
686
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
687
+ etc.)
688
+
689
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
690
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
691
+ and behavior.
692
+
693
+ Parameters:
694
+ config ([`Qwen2Config`]):
695
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
696
+ load the weights associated with the model, only the configuration. Check out the
697
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
698
+ """
699
+
700
+
701
+ @add_start_docstrings(
702
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
703
+ QWEN2_START_DOCSTRING,
704
+ )
705
+ class Qwen2PreTrainedModel(PreTrainedModel):
706
+ config_class = Qwen2Config
707
+ base_model_prefix = "model"
708
+ supports_gradient_checkpointing = True
709
+ _no_split_modules = ["Qwen2DecoderLayer"]
710
+ _skip_keys_device_placement = "past_key_values"
711
+ _supports_flash_attn_2 = True
712
+ _supports_sdpa = True
713
+ _supports_cache_class = True
714
+ _supports_quantized_cache = True
715
+ _supports_static_cache = True
716
+
717
+ def _init_weights(self, module):
718
+ std = self.config.initializer_range
719
+ if isinstance(module, nn.Linear):
720
+ module.weight.data.normal_(mean=0.0, std=std)
721
+ if module.bias is not None:
722
+ module.bias.data.zero_()
723
+ elif isinstance(module, nn.Embedding):
724
+ module.weight.data.normal_(mean=0.0, std=std)
725
+ if module.padding_idx is not None:
726
+ module.weight.data[module.padding_idx].zero_()
727
+
728
+
729
+ QWEN2_INPUTS_DOCSTRING = r"""
730
+ Args:
731
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
732
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
733
+ it.
734
+
735
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
736
+ [`PreTrainedTokenizer.__call__`] for details.
737
+
738
+ [What are input IDs?](../glossary#input-ids)
739
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
741
+
742
+ - 1 for tokens that are **not masked**,
743
+ - 0 for tokens that are **masked**.
744
+
745
+ [What are attention masks?](../glossary#attention-mask)
746
+
747
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
748
+ [`PreTrainedTokenizer.__call__`] for details.
749
+
750
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
751
+ `past_key_values`).
752
+
753
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
754
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
755
+ information on the default strategy.
756
+
757
+ - 1 indicates the head is **not masked**,
758
+ - 0 indicates the head is **masked**.
759
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
760
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
761
+ config.n_positions - 1]`.
762
+
763
+ [What are position IDs?](../glossary#position-ids)
764
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
765
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
766
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
767
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
768
+
769
+ Two formats are allowed:
770
+ - a [`~cache_utils.Cache`] instance, see our
771
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
772
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
773
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
774
+ cache format.
775
+
776
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
777
+ legacy cache format will be returned.
778
+
779
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
780
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
781
+ of shape `(batch_size, sequence_length)`.
782
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
783
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
784
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
785
+ model's internal embedding lookup matrix.
786
+ use_cache (`bool`, *optional*):
787
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
788
+ `past_key_values`).
789
+ output_attentions (`bool`, *optional*):
790
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
791
+ tensors for more detail.
792
+ output_hidden_states (`bool`, *optional*):
793
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
794
+ more detail.
795
+ return_dict (`bool`, *optional*):
796
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
797
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
798
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
799
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
800
+ the complete sequence length.
801
+ """
802
+
803
+
804
+ @add_start_docstrings(
805
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
806
+ QWEN2_START_DOCSTRING,
807
+ )
808
+ class Qwen2Model(Qwen2PreTrainedModel):
809
+ """
810
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
811
+
812
+ Args:
813
+ config: Qwen2Config
814
+ """
815
+
816
+ def __init__(self, config: Qwen2Config):
817
+ super().__init__(config)
818
+ self.padding_idx = config.pad_token_id
819
+ self.vocab_size = config.vocab_size
820
+
821
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
822
+ self.layers = nn.ModuleList(
823
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
824
+ )
825
+ self._attn_implementation = config._attn_implementation
826
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
827
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
828
+
829
+ self.gradient_checkpointing = False
830
+ # Initialize weights and apply final processing
831
+ self.post_init()
832
+
833
+ def get_input_embeddings(self):
834
+ return self.embed_tokens
835
+
836
+ def set_input_embeddings(self, value):
837
+ self.embed_tokens = value
838
+
839
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
840
+ def forward(
841
+ self,
842
+ input_ids: torch.LongTensor = None,
843
+ attention_mask: Optional[torch.Tensor] = None,
844
+ position_ids: Optional[torch.LongTensor] = None,
845
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
846
+ inputs_embeds: Optional[torch.FloatTensor] = None,
847
+ use_cache: Optional[bool] = None,
848
+ output_attentions: Optional[bool] = None,
849
+ output_hidden_states: Optional[bool] = None,
850
+ return_dict: Optional[bool] = None,
851
+ cache_position: Optional[torch.LongTensor] = None,
852
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
853
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
854
+ output_hidden_states = (
855
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
856
+ )
857
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
858
+
859
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
860
+
861
+ if (input_ids is None) ^ (inputs_embeds is not None):
862
+ raise ValueError(
863
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
864
+ )
865
+
866
+ if self.gradient_checkpointing and self.training:
867
+ if use_cache:
868
+ logger.warning_once(
869
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
870
+ )
871
+ use_cache = False
872
+
873
+ # kept for BC (non `Cache` `past_key_values` inputs)
874
+ return_legacy_cache = False
875
+ # if use_cache and not isinstance(past_key_values, Cache):
876
+ # return_legacy_cache = True
877
+ # if past_key_values is None:
878
+ # past_key_values = DynamicCache()
879
+ # else:
880
+ # past_key_values = DynamicCache.from_legacy_cache(past_key_values)
881
+ # logger.warning_once(
882
+ # "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
883
+ # "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
884
+ # "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
885
+ # )
886
+
887
+ if inputs_embeds is None:
888
+ inputs_embeds = self.embed_tokens(input_ids)
889
+
890
+ if cache_position is None:
891
+ past_seen_tokens = past_key_values[0][0].current_length.item() if past_key_values is not None else 0
892
+ cache_position = torch.arange(
893
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
894
+ )
895
+
896
+ if position_ids is None:
897
+ position_ids = cache_position.unsqueeze(0)
898
+ causal_mask = self._update_causal_mask(
899
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
900
+ )
901
+
902
+ hidden_states = inputs_embeds
903
+
904
+ # create position embeddings to be shared across the decoder layers
905
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
906
+
907
+ # decoder layers
908
+ all_hidden_states = () if output_hidden_states else None
909
+ all_self_attns = () if output_attentions else None
910
+ next_decoder_cache = None
911
+
912
+ for decoder_layer in self.layers:
913
+ if output_hidden_states:
914
+ all_hidden_states += (hidden_states,)
915
+
916
+ if self.gradient_checkpointing and self.training:
917
+ layer_outputs = self._gradient_checkpointing_func(
918
+ decoder_layer.__call__,
919
+ hidden_states,
920
+ causal_mask,
921
+ position_ids,
922
+ past_key_values,
923
+ output_attentions,
924
+ use_cache,
925
+ cache_position,
926
+ position_embeddings,
927
+ )
928
+ else:
929
+ layer_outputs = decoder_layer(
930
+ hidden_states,
931
+ attention_mask=causal_mask,
932
+ position_ids=position_ids,
933
+ past_key_value=past_key_values,
934
+ output_attentions=output_attentions,
935
+ use_cache=use_cache,
936
+ cache_position=cache_position,
937
+ position_embeddings=position_embeddings,
938
+ )
939
+
940
+ hidden_states = layer_outputs[0]
941
+
942
+ if use_cache:
943
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
944
+
945
+ if output_attentions:
946
+ all_self_attns += (layer_outputs[1],)
947
+
948
+ hidden_states = self.norm(hidden_states)
949
+
950
+ # add hidden states from the last decoder layer
951
+ if output_hidden_states:
952
+ all_hidden_states += (hidden_states,)
953
+
954
+ next_cache = next_decoder_cache if use_cache else None
955
+ if return_legacy_cache:
956
+ next_cache = next_cache.to_legacy_cache()
957
+
958
+ if not return_dict:
959
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
960
+ return BaseModelOutputWithPast(
961
+ last_hidden_state=hidden_states,
962
+ past_key_values=next_cache,
963
+ hidden_states=all_hidden_states,
964
+ attentions=all_self_attns,
965
+ )
966
+
967
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
968
+ def _prepare_4d_causal_attention_mask_with_cache_position(
969
+ self,
970
+ attention_mask: torch.Tensor,
971
+ sequence_length: int,
972
+ target_length: int,
973
+ dtype: torch.dtype,
974
+ device: torch.device,
975
+ min_dtype: float,
976
+ cache_position: torch.Tensor,
977
+ batch_size: int,
978
+ ):
979
+ """
980
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
981
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
982
+
983
+ Args:
984
+ attention_mask (`torch.Tensor`):
985
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
986
+ sequence_length (`int`):
987
+ The sequence length being processed.
988
+ target_length (`int`):
989
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
990
+ dtype (`torch.dtype`):
991
+ The dtype to use for the 4D attention mask.
992
+ device (`torch.device`):
993
+ The device to place the 4D attention mask on.
994
+ min_dtype (`float`):
995
+ The minimum value representable with the dtype `dtype`.
996
+ cache_position (`torch.Tensor`):
997
+ Indices depicting the position of the input sequence tokens in the sequence.
998
+ batch_size (`torch.Tensor`):
999
+ Batch size.
1000
+ """
1001
+ if attention_mask is not None and attention_mask.dim() == 4:
1002
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1003
+ causal_mask = attention_mask
1004
+ else:
1005
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1006
+ if sequence_length != 1:
1007
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1008
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1009
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1010
+ if attention_mask is not None:
1011
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1012
+ mask_length = attention_mask.shape[-1]
1013
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1014
+ padding_mask = padding_mask == 0
1015
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1016
+ padding_mask, min_dtype
1017
+ )
1018
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
1019
+ tree_mask = self.tree_mask
1020
+ tree_len = tree_mask.size(-1)
1021
+ causal_mask[:, :, -tree_len:, -tree_len:][
1022
+ tree_mask == 0
1023
+ ] = causal_mask.min()
1024
+ # causal_mask[:, :, -tree_len:, -tree_len:][
1025
+ # tree_mask == 1
1026
+ # ] = 0
1027
+
1028
+ return causal_mask
1029
+
1030
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1031
+ def _update_causal_mask(
1032
+ self,
1033
+ attention_mask: torch.Tensor,
1034
+ input_tensor: torch.Tensor,
1035
+ cache_position: torch.Tensor,
1036
+ past_key_values: Cache,
1037
+ output_attentions: bool,
1038
+ ):
1039
+ if self.config._attn_implementation == "flash_attention_2":
1040
+ if attention_mask is not None and 0.0 in attention_mask:
1041
+ return attention_mask
1042
+ return None
1043
+
1044
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1045
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1046
+ # to infer the attention mask.
1047
+
1048
+ past_seen_tokens = past_key_values[0][0].current_length.item() if past_key_values is not None else 0
1049
+ using_static_cache = isinstance(past_key_values, StaticCache)
1050
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1051
+
1052
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1053
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1054
+ attention_mask,
1055
+ inputs_embeds=input_tensor,
1056
+ past_key_values_length=past_seen_tokens,
1057
+ is_training=self.training,
1058
+ ):
1059
+ return None
1060
+
1061
+ dtype, device = input_tensor.dtype, input_tensor.device
1062
+ min_dtype = torch.finfo(dtype).min
1063
+ sequence_length = input_tensor.shape[1]
1064
+ if using_static_cache:
1065
+ target_length = past_key_values.get_max_length()
1066
+ else:
1067
+ target_length = (
1068
+ attention_mask.shape[-1]
1069
+ if isinstance(attention_mask, torch.Tensor)
1070
+ else past_seen_tokens + sequence_length
1071
+ )
1072
+
1073
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1074
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1075
+ attention_mask,
1076
+ sequence_length=sequence_length,
1077
+ target_length=target_length,
1078
+ dtype=dtype,
1079
+ device=device,
1080
+ min_dtype=min_dtype,
1081
+ cache_position=cache_position,
1082
+ batch_size=input_tensor.shape[0],
1083
+ )
1084
+
1085
+ if (
1086
+ self.config._attn_implementation == "sdpa"
1087
+ and attention_mask is not None
1088
+ and attention_mask.device.type == "cuda"
1089
+ and not output_attentions
1090
+ ):
1091
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1092
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1093
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1094
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1095
+
1096
+ return causal_mask
1097
+
1098
+
1099
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
1100
+ _tied_weights_keys = ["lm_head.weight"]
1101
+
1102
+ def __init__(self, config):
1103
+ super().__init__(config)
1104
+ self.model = Qwen2Model(config)
1105
+ self.vocab_size = config.vocab_size
1106
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1107
+
1108
+ # Initialize weights and apply final processing
1109
+ self.post_init()
1110
+
1111
+ def get_input_embeddings(self):
1112
+ return self.model.embed_tokens
1113
+
1114
+ def set_input_embeddings(self, value):
1115
+ self.model.embed_tokens = value
1116
+
1117
+ def get_output_embeddings(self):
1118
+ return self.lm_head
1119
+
1120
+ def set_output_embeddings(self, new_embeddings):
1121
+ self.lm_head = new_embeddings
1122
+
1123
+ def set_decoder(self, decoder):
1124
+ self.model = decoder
1125
+
1126
+ def get_decoder(self):
1127
+ return self.model
1128
+
1129
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1130
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1131
+ def forward(
1132
+ self,
1133
+ input_ids: torch.LongTensor = None,
1134
+ attention_mask: Optional[torch.Tensor] = None,
1135
+ position_ids: Optional[torch.LongTensor] = None,
1136
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1137
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1138
+ labels: Optional[torch.LongTensor] = None,
1139
+ use_cache: Optional[bool] = None,
1140
+ output_attentions: Optional[bool] = None,
1141
+ output_hidden_states: Optional[bool] = None,
1142
+ return_dict: Optional[bool] = None,
1143
+ cache_position: Optional[torch.LongTensor] = None,
1144
+ num_logits_to_keep: int = 0,
1145
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1146
+ r"""
1147
+ Args:
1148
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1149
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1150
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1151
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1152
+
1153
+ num_logits_to_keep (`int`, *optional*):
1154
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1155
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1156
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1157
+
1158
+ Returns:
1159
+
1160
+ Example:
1161
+
1162
+ ```python
1163
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1164
+
1165
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1166
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1167
+
1168
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1169
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1170
+
1171
+ >>> # Generate
1172
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1173
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1174
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1175
+ ```"""
1176
+
1177
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1178
+ output_hidden_states = (
1179
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1180
+ )
1181
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1182
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1183
+ outputs = self.model(
1184
+ input_ids=input_ids,
1185
+ attention_mask=attention_mask,
1186
+ position_ids=position_ids,
1187
+ past_key_values=past_key_values,
1188
+ inputs_embeds=inputs_embeds,
1189
+ use_cache=use_cache,
1190
+ output_attentions=output_attentions,
1191
+ output_hidden_states=output_hidden_states,
1192
+ return_dict=return_dict,
1193
+ cache_position=cache_position,
1194
+ )
1195
+
1196
+ hidden_states = outputs[0]
1197
+ if labels is None and not is_torchdynamo_compiling():
1198
+ logger.warning_once(
1199
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
1200
+ )
1201
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1202
+ # TODO: remove the float() operation in v4.46
1203
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1204
+
1205
+ loss = None
1206
+ if labels is not None:
1207
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1208
+ logits = logits.float()
1209
+ # Shift so that tokens < n predict n
1210
+ shift_logits = logits[..., :-1, :].contiguous()
1211
+ shift_labels = labels[..., 1:].contiguous()
1212
+ # Flatten the tokens
1213
+ loss_fct = CrossEntropyLoss()
1214
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1215
+ shift_labels = shift_labels.view(-1)
1216
+ # Enable model parallelism
1217
+ shift_labels = shift_labels.to(shift_logits.device)
1218
+ loss = loss_fct(shift_logits, shift_labels)
1219
+
1220
+ if not return_dict:
1221
+ output = (logits,) + outputs[1:]
1222
+ return (loss,) + output if loss is not None else output
1223
+
1224
+ return CausalLMOutputWithPast(
1225
+ loss=loss,
1226
+ logits=logits,
1227
+ past_key_values=outputs.past_key_values,
1228
+ hidden_states=outputs.hidden_states,
1229
+ attentions=outputs.attentions,
1230
+ )
1231
+
1232
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1233
+ def prepare_inputs_for_generation(
1234
+ self,
1235
+ input_ids,
1236
+ past_key_values=None,
1237
+ attention_mask=None,
1238
+ inputs_embeds=None,
1239
+ cache_position=None,
1240
+ position_ids=None,
1241
+ use_cache=True,
1242
+ num_logits_to_keep=None,
1243
+ **kwargs,
1244
+ ):
1245
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1246
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1247
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1248
+ if past_key_values is not None:
1249
+ if inputs_embeds is not None: # Exception 1
1250
+ input_ids = input_ids[:, -cache_position.shape[0]:]
1251
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1252
+ input_ids = input_ids[:, cache_position]
1253
+
1254
+ if attention_mask is not None and position_ids is None:
1255
+ # create position_ids on the fly for batch generation
1256
+ position_ids = attention_mask.long().cumsum(-1) - 1
1257
+ position_ids.masked_fill_(attention_mask == 0, 1)
1258
+ if past_key_values:
1259
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1260
+
1261
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1262
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1263
+
1264
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1265
+ if inputs_embeds is not None and cache_position[0] == 0:
1266
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1267
+ else:
1268
+ # The clone here is for the same reason as for `position_ids`.
1269
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1270
+
1271
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1272
+ if model_inputs["inputs_embeds"] is not None:
1273
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1274
+ device = model_inputs["inputs_embeds"].device
1275
+ else:
1276
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1277
+ device = model_inputs["input_ids"].device
1278
+
1279
+ dtype = self.lm_head.weight.dtype
1280
+ min_dtype = torch.finfo(dtype).min
1281
+
1282
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1283
+ attention_mask,
1284
+ sequence_length=sequence_length,
1285
+ target_length=past_key_values.get_max_length(),
1286
+ dtype=dtype,
1287
+ device=device,
1288
+ min_dtype=min_dtype,
1289
+ cache_position=cache_position,
1290
+ batch_size=batch_size,
1291
+ )
1292
+
1293
+ if num_logits_to_keep is not None:
1294
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
1295
+
1296
+ model_inputs.update(
1297
+ {
1298
+ "position_ids": position_ids,
1299
+ "cache_position": cache_position,
1300
+ "past_key_values": past_key_values,
1301
+ "use_cache": use_cache,
1302
+ "attention_mask": attention_mask,
1303
+ }
1304
+ )
1305
+ return model_inputs
1306
+
1307
+
1308
+ @add_start_docstrings(
1309
+ """
1310
+ The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1311
+
1312
+ [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1313
+ (e.g. GPT-2) do.
1314
+
1315
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1316
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1317
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1318
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1319
+ each row of the batch).
1320
+ """,
1321
+ QWEN2_START_DOCSTRING,
1322
+ )
1323
+ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1324
+ def __init__(self, config):
1325
+ super().__init__(config)
1326
+ self.num_labels = config.num_labels
1327
+ self.model = Qwen2Model(config)
1328
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1329
+
1330
+ # Initialize weights and apply final processing
1331
+ self.post_init()
1332
+
1333
+ def get_input_embeddings(self):
1334
+ return self.model.embed_tokens
1335
+
1336
+ def set_input_embeddings(self, value):
1337
+ self.model.embed_tokens = value
1338
+
1339
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1340
+ def forward(
1341
+ self,
1342
+ input_ids: torch.LongTensor = None,
1343
+ attention_mask: Optional[torch.Tensor] = None,
1344
+ position_ids: Optional[torch.LongTensor] = None,
1345
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1346
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1347
+ labels: Optional[torch.LongTensor] = None,
1348
+ use_cache: Optional[bool] = None,
1349
+ output_attentions: Optional[bool] = None,
1350
+ output_hidden_states: Optional[bool] = None,
1351
+ return_dict: Optional[bool] = None,
1352
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1353
+ r"""
1354
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1355
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1356
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1357
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1358
+ """
1359
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360
+
1361
+ transformer_outputs = self.model(
1362
+ input_ids,
1363
+ attention_mask=attention_mask,
1364
+ position_ids=position_ids,
1365
+ past_key_values=past_key_values,
1366
+ inputs_embeds=inputs_embeds,
1367
+ use_cache=use_cache,
1368
+ output_attentions=output_attentions,
1369
+ output_hidden_states=output_hidden_states,
1370
+ return_dict=return_dict,
1371
+ )
1372
+ hidden_states = transformer_outputs[0]
1373
+ logits = self.score(hidden_states)
1374
+
1375
+ if input_ids is not None:
1376
+ batch_size = input_ids.shape[0]
1377
+ else:
1378
+ batch_size = inputs_embeds.shape[0]
1379
+
1380
+ if self.config.pad_token_id is None and batch_size != 1:
1381
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1382
+ if self.config.pad_token_id is None:
1383
+ sequence_lengths = -1
1384
+ else:
1385
+ if input_ids is not None:
1386
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1387
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1388
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1389
+ sequence_lengths = sequence_lengths.to(logits.device)
1390
+ else:
1391
+ sequence_lengths = -1
1392
+
1393
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1394
+
1395
+ loss = None
1396
+ if labels is not None:
1397
+ labels = labels.to(logits.device)
1398
+ if self.config.problem_type is None:
1399
+ if self.num_labels == 1:
1400
+ self.config.problem_type = "regression"
1401
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1402
+ self.config.problem_type = "single_label_classification"
1403
+ else:
1404
+ self.config.problem_type = "multi_label_classification"
1405
+
1406
+ if self.config.problem_type == "regression":
1407
+ loss_fct = MSELoss()
1408
+ if self.num_labels == 1:
1409
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1410
+ else:
1411
+ loss = loss_fct(pooled_logits, labels)
1412
+ elif self.config.problem_type == "single_label_classification":
1413
+ loss_fct = CrossEntropyLoss()
1414
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1415
+ elif self.config.problem_type == "multi_label_classification":
1416
+ loss_fct = BCEWithLogitsLoss()
1417
+ loss = loss_fct(pooled_logits, labels)
1418
+ if not return_dict:
1419
+ output = (pooled_logits,) + transformer_outputs[1:]
1420
+ return ((loss,) + output) if loss is not None else output
1421
+
1422
+ return SequenceClassifierOutputWithPast(
1423
+ loss=loss,
1424
+ logits=pooled_logits,
1425
+ past_key_values=transformer_outputs.past_key_values,
1426
+ hidden_states=transformer_outputs.hidden_states,
1427
+ attentions=transformer_outputs.attentions,
1428
+ )
1429
+
1430
+
1431
+ @add_start_docstrings(
1432
+ """
1433
+ The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1434
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1435
+ """,
1436
+ QWEN2_START_DOCSTRING,
1437
+ )
1438
+ # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
1439
+ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
1440
+ def __init__(self, config):
1441
+ super().__init__(config)
1442
+ self.num_labels = config.num_labels
1443
+ self.model = Qwen2Model(config)
1444
+ if getattr(config, "classifier_dropout", None) is not None:
1445
+ classifier_dropout = config.classifier_dropout
1446
+ elif getattr(config, "hidden_dropout", None) is not None:
1447
+ classifier_dropout = config.hidden_dropout
1448
+ else:
1449
+ classifier_dropout = 0.1
1450
+ self.dropout = nn.Dropout(classifier_dropout)
1451
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1452
+
1453
+ # Initialize weights and apply final processing
1454
+ self.post_init()
1455
+
1456
+ def get_input_embeddings(self):
1457
+ return self.model.embed_tokens
1458
+
1459
+ def set_input_embeddings(self, value):
1460
+ self.model.embed_tokens = value
1461
+
1462
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1463
+ def forward(
1464
+ self,
1465
+ input_ids: Optional[torch.LongTensor] = None,
1466
+ attention_mask: Optional[torch.Tensor] = None,
1467
+ position_ids: Optional[torch.LongTensor] = None,
1468
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1469
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1470
+ labels: Optional[torch.LongTensor] = None,
1471
+ use_cache: Optional[bool] = None,
1472
+ output_attentions: Optional[bool] = None,
1473
+ output_hidden_states: Optional[bool] = None,
1474
+ return_dict: Optional[bool] = None,
1475
+ ) -> Union[Tuple, TokenClassifierOutput]:
1476
+ r"""
1477
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1478
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1479
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1480
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1481
+ """
1482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1483
+
1484
+ outputs = self.model(
1485
+ input_ids,
1486
+ attention_mask=attention_mask,
1487
+ position_ids=position_ids,
1488
+ past_key_values=past_key_values,
1489
+ inputs_embeds=inputs_embeds,
1490
+ use_cache=use_cache,
1491
+ output_attentions=output_attentions,
1492
+ output_hidden_states=output_hidden_states,
1493
+ return_dict=return_dict,
1494
+ )
1495
+ sequence_output = outputs[0]
1496
+ sequence_output = self.dropout(sequence_output)
1497
+ logits = self.score(sequence_output)
1498
+
1499
+ loss = None
1500
+ if labels is not None:
1501
+ loss_fct = CrossEntropyLoss()
1502
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1503
+
1504
+ if not return_dict:
1505
+ output = (logits,) + outputs[2:]
1506
+ return ((loss,) + output) if loss is not None else output
1507
+
1508
+ return TokenClassifierOutput(
1509
+ loss=loss,
1510
+ logits=logits,
1511
+ hidden_states=outputs.hidden_states,
1512
+ attentions=outputs.attentions,
1513
+ )
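
Note: the EAGLE-specific piece of the mask logic in this file is the `tree_mask` overlay inside `_prepare_4d_causal_attention_mask_with_cache_position`: when a whole tree of draft tokens is scored in one forward pass, each draft position may attend only to its ancestors, so the bottom-right `tree_len × tree_len` corner of the causal mask is re-masked with the tree structure. The sketch below is illustrative commentary, not part of the uploaded file; the toy tree, the 2D shapes, and the variable names are assumptions (the real code works on a 4D `(batch, 1, q_len, kv_len)` mask).

```python
# Minimal sketch (assumptions: toy sizes, 2D mask) of the tree_mask overlay.
import torch

seq_len, tree_len = 6, 4                         # 2 committed tokens + 4 tree (draft) slots
min_val = torch.finfo(torch.float32).min

# Standard causal mask: min_val strictly above the diagonal, 0 elsewhere.
causal = torch.triu(torch.full((seq_len, seq_len), min_val), diagonal=1)

# Toy draft tree: root -> {a, b}, a -> c. Row i lists which draft slots slot i may see.
tree_mask = torch.tensor([
    [1, 0, 0, 0],    # root sees itself
    [1, 1, 0, 0],    # a sees root, a
    [1, 0, 1, 0],    # b sees root, b (not its sibling a)
    [1, 1, 0, 1],    # c sees root, a, c
])

# Overlay, mirroring `causal_mask[:, :, -tree_len:, -tree_len:][tree_mask == 0] = min`:
# inside the tree block, anything the tree forbids is masked even though it is "causal"
# in the flattened sequence order.
causal[-tree_len:, -tree_len:][tree_mask == 0] = min_val
print(causal)
```
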
eagle/model/utils.py ADDED
@@ -0,0 +1,481 @@
1
+ import copy
2
+ import random
3
+
4
+ # typing
5
+ from typing import List, Tuple
6
+ import time
7
+ import torch
8
+
9
+ # TODO
10
+ # from transformers import LlamaTokenizer
11
+ # tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/")
12
+
13
+ TOPK = 10 # topk for sparse tree
14
+
15
+ from transformers.generation.logits_process import (
16
+ LogitsProcessorList,
17
+ RepetitionPenaltyLogitsProcessor,
18
+ TemperatureLogitsWarper,
19
+ TopKLogitsWarper,
20
+ TopPLogitsWarper,
21
+ )
22
+
23
+
24
+ class Timer:
25
+ def __init__(self,name):
26
+ self.name = name
27
+ def __enter__(self):
28
+ torch.cuda.synchronize()
29
+ self.start = time.perf_counter()
30
+
31
+
32
+ def __exit__(self, exc_type, exc_value, traceback):
33
+ torch.cuda.synchronize()
34
+ elapsed = time.perf_counter() - self.start
35
+ print(f'{self.name} took {elapsed} seconds')
36
+
37
+
38
+ def prepare_logits_processor(
39
+ temperature: float = 0.0,
40
+ repetition_penalty: float = 0.0,
41
+ top_p: float = 0.0,
42
+ top_k: int = 0
43
+ ) -> LogitsProcessorList:
44
+ processor_list = LogitsProcessorList()
45
+ if temperature > 1e-5:
46
+ if temperature >= 1e-5 and temperature != 1.0:
47
+ processor_list.append(TemperatureLogitsWarper(temperature))
48
+ if repetition_penalty > 1.0:
49
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
50
+ if 1e-8 <= top_p < 1.0:
51
+ processor_list.append(TopPLogitsWarper(top_p))
52
+ if top_k > 0:
53
+ processor_list.append(TopKLogitsWarper(top_k))
54
+ return processor_list
55
+
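
Note: `prepare_logits_processor` composes the standard Transformers warpers (temperature, repetition penalty, top-p, top-k) into one `LogitsProcessorList`, which the sampling paths later invoke as `logits_processor(None, logits)`. A minimal usage sketch with made-up parameter values (not taken from the upload) could look like this; passing `None` for `input_ids` is fine because none of these warpers inspect the prompt.

```python
# Usage sketch for a composed logits-processor list (illustrative values only).
import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

processors = LogitsProcessorList([
    TemperatureLogitsWarper(0.6),
    TopPLogitsWarper(0.95),
    TopKLogitsWarper(50),
])

logits = torch.randn(1, 128)                  # (batch, vocab) toy logits
warped = processors(None, logits)             # these warpers ignore input_ids, hence None
token = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)
print(token)
```
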
56
+
57
+ # test_processor = prepare_logits_processor(
58
+ # 0.0, 0.0, -1, 1
59
+ # )
60
+
61
+
62
+ def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]:
63
+ """
64
+ Pad the given path list with a specific value up to a specified length.
65
+
66
+ Parameters:
67
+ - path (list): The original list that needs padding.
68
+ - length (int): The desired length of the padded list.
69
+ - pad_value (optional, default=-2): The value to use for padding.
70
+
71
+ Returns:
72
+ - list: A new list based on the original path but padded to the desired length.
73
+
74
+ Example:
75
+ >>> pad_path([1,2,3], 5)
76
+ [1, 2, 3, -2, -2]
77
+
78
+ Note:
79
+ If the given path is already longer than the specified length,
80
+ then no padding occurs, and the original path is returned.
81
+ """
82
+
83
+ # Calculate the number of padding values needed by subtracting the length
84
+ # of the path from the desired length.
85
+ # Append the padding values to the original path and return the new list.
86
+ return path + [pad_value] * (length - len(path))
87
+
88
+
89
+ def generate_tree_buffers(tree_choices, device="cuda"):
90
+ def custom_sort(lst):
91
+ # sort_keys=[len(list)]
92
+ sort_keys = []
93
+ for i in range(len(lst)):
94
+ sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
95
+ return sort_keys
96
+ with Timer("sort"):
97
+
98
+ sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
99
+ tree_len = len(sorted_tree_choices) + 1
100
+
101
+ # Initialize depth_counts to keep track of how many choices have a particular depth
102
+ depth_counts = []
103
+ prev_depth = 0
104
+ for path in sorted_tree_choices:
105
+ depth = len(path)
106
+ if depth != prev_depth:
107
+ depth_counts.append(0)
108
+ depth_counts[depth - 1] += 1
109
+ prev_depth = depth
110
+
111
+ tree_attn_mask = torch.eye(tree_len, tree_len)
112
+ tree_attn_mask[:, 0] = 1
113
+ start = 0
114
+ for i in range(len(depth_counts)):
115
+ for j in range(depth_counts[i]):
116
+ cur_tree_choice = sorted_tree_choices[start + j]
117
+ # retrieve ancestor position
118
+ if len(cur_tree_choice) == 1:
119
+ continue
120
+ ancestor_idx = []
121
+ for c in range(len(cur_tree_choice) - 1):
122
+ ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
123
+ tree_attn_mask[j + start + 1, ancestor_idx] = 1
124
+ start += depth_counts[i]
125
+
126
+ tree_indices = torch.zeros(tree_len, dtype=torch.long)
127
+ p_indices = [0 for _ in range(tree_len - 1)]
128
+ b_indices = [[] for _ in range(tree_len - 1)]
129
+ tree_indices[0] = 0
130
+ start = 0
131
+ bias = 0
132
+ for i in range(len(depth_counts)):
133
+ inlayer_bias = 0
134
+ b = []
135
+ for j in range(depth_counts[i]):
136
+ cur_tree_choice = sorted_tree_choices[start + j]
137
+ cur_parent = cur_tree_choice[:-1]
138
+ if j != 0:
139
+ if cur_parent != parent:
140
+ bias += 1
141
+ inlayer_bias += 1
142
+ parent = cur_parent
143
+ b = []
144
+ else:
145
+ parent = cur_parent
146
+ tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1
147
+ p_indices[start + j] = inlayer_bias
148
+ if len(b) > 0:
149
+ b_indices[start + j] = copy.deepcopy(b)
150
+ else:
151
+ b_indices[start + j] = []
152
+ b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1)
153
+ start += depth_counts[i]
154
+
155
+ p_indices = [-1] + p_indices
156
+ tree_position_ids = torch.zeros(tree_len, dtype=torch.long)
157
+ start = 0
158
+ for i in range(len(depth_counts)):
159
+ tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
160
+ start += depth_counts[i]
161
+
162
+ retrieve_indices_nest = []
163
+ retrieve_paths = []
164
+ for i in range(len(sorted_tree_choices)):
165
+ cur_tree_choice = sorted_tree_choices[-i - 1]
166
+ retrieve_indice = []
167
+ if cur_tree_choice in retrieve_paths:
168
+ continue
169
+ else:
170
+ for c in range(len(cur_tree_choice)):
171
+ retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]))
172
+ retrieve_paths.append(cur_tree_choice[:c + 1])
173
+ retrieve_indices_nest.append(retrieve_indice)
174
+ max_length = max([len(x) for x in retrieve_indices_nest])
175
+ retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
176
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
177
+ retrieve_indices = retrieve_indices + 1
178
+ retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices],
179
+ dim=1)
180
+
181
+ maxitem = retrieve_indices.max().item() + 5
182
+
183
+
184
+
185
+ retrieve_indices = retrieve_indices.tolist()
186
+ retrieve_indices = sorted(retrieve_indices, key=custom_sort)
187
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
188
+
189
+
190
+
191
+ # Aggregate the generated buffers into a dictionary
192
+ tree_buffers = {
193
+ "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0),
194
+ "tree_indices": tree_indices,
195
+ "tree_position_ids": tree_position_ids,
196
+ "retrieve_indices": retrieve_indices,
197
+ }
198
+
199
+ # Move the tensors in the dictionary to the specified device
200
+ tree_buffers = {
201
+ k: v.clone().to(device)
202
+ if isinstance(v, torch.Tensor)
203
+ else torch.tensor(v, device=device)
204
+ for k, v in tree_buffers.items()
205
+ }
206
+
207
+ return tree_buffers
208
+
209
+
210
+ def initialize_tree0(input_ids, model, past_key_values, logits_processor):
211
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, logits, hidden_state, sample_token = model(
212
+ input_ids, past_key_values=past_key_values, output_orig=True, logits_processor=logits_processor
213
+ )
214
+
215
+ # if logits_processor is not None:
216
+ # logits = orig[:, -1]
217
+ # logits = logits_processor(None, logits)
218
+ # probabilities = torch.nn.functional.softmax(logits, dim=1)
219
+ # token = torch.multinomial(probabilities, 1)
220
+ # else:
221
+ # token = torch.argmax(orig[:, -1])
222
+ # token = token[None, None]
223
+ # input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
224
+ # # Clone the output hidden states
225
+ #
226
+ # draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
227
+ # if output_orig:
228
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token
229
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token
230
+ return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token
231
+
232
+ def initialize_tree(input_ids, model, past_key_values, logits_processor):
233
+ outputs, orig, hidden_states = model(
234
+ input_ids, past_key_values=past_key_values, output_orig=True
235
+ )
236
+
237
+ if logits_processor is not None:
238
+ logits = orig[:, -1]
239
+ logits = logits_processor(None, logits)
240
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
241
+ token = torch.multinomial(probabilities, 1)
242
+ else:
243
+ token = torch.argmax(orig[:, -1])
244
+ token = token[None, None]
245
+ input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
246
+
247
+ # Clone the output hidden states
248
+ if model.use_eagle3:
249
+ ea_device = model.ea_layer.lm_head.weight.device
250
+ if outputs["hidden_states"][0].device != ea_device:
251
+ outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]]
252
+ hidden_states=torch.cat(outputs["hidden_states"],dim=-1)
253
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
254
+ return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token
255
+
256
+
257
+ def reset_tree_mode(
258
+ model,
259
+ ):
260
+ model.base_model.model.tree_mask = None
261
+ model.base_model.model.tree_mode = None
262
+
263
+
264
+ def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]:
265
+ """
266
+ Resets the current lengths in the passed key-values to zero.
267
+
268
+ This function is designed to be used during the evaluation of a baseline model.
269
+ It iterates through each layer's key-values and sets their current lengths to zero,
270
+ effectively resetting their state.
271
+
272
+ Args:
273
+ - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
274
+
275
+ Returns:
276
+ - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
277
+ """
278
+ for i in range(len(passed_key_values)):
279
+ for j in range(2):
280
+ passed_key_values[i][j].current_length.fill_(0)
281
+ return passed_key_values
282
+
283
+
284
+ def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor):
285
+ sample_token = sample_token.to(tree_indices.device)
286
+
287
+ candidates_logit = sample_token[0]
288
+
289
+ candidates_tree_logits = tree_logits
290
+
291
+ candidates = torch.cat([candidates_logit, candidates_tree_logits.view(-1)], dim=-1)
292
+
293
+ tree_candidates = candidates[tree_indices]
294
+
295
+ tree_candidates_ext = torch.cat(
296
+ [tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device) - 1], dim=0)
297
+
298
+ cart_candidates = tree_candidates_ext[retrieve_indices]
299
+
300
+
301
+ # Unsqueeze the tree candidates for dimension consistency.
302
+ tree_candidates = tree_candidates.unsqueeze(0)
303
+ return cart_candidates, tree_candidates
304
+
305
+
306
+ def tree_decoding(
307
+ model,
308
+ tree_candidates,
309
+ past_key_values,
310
+ tree_position_ids,
311
+ input_ids,
312
+ retrieve_indices,
313
+ ):
314
+ position_ids = tree_position_ids + input_ids.shape[1]
315
+ if position_ids is not None and position_ids.dim() == 1:
316
+ position_ids = position_ids.unsqueeze(0)
317
+ outputs, tree_logits, hidden_state = model(
318
+ tree_candidates,
319
+ output_orig=True,
320
+ past_key_values=past_key_values,
321
+ position_ids=position_ids,
322
+ )
323
+
324
+ if model.use_eagle3:
325
+ ea_device = model.ea_layer.lm_head.weight.device
326
+ if outputs["hidden_states"][0].device != ea_device:
327
+ outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]]
328
+ hidden_state = torch.cat(outputs["hidden_states"], dim=-1)
329
+
330
+ logits = tree_logits[0, retrieve_indices]
331
+ return logits, hidden_state, outputs
332
+
333
+
334
+
335
+
336
+
337
+ def evaluate_posterior(
338
+ logits: torch.Tensor,
339
+ candidates: torch.Tensor,
340
+ logits_processor,
341
+ ):
342
+ """
343
+ Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
344
+
345
+ Depending on the temperature value, the function either uses greedy decoding or evaluates posterior
346
+ probabilities to select the best candidate.
347
+
348
+ Args:
349
+ - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
350
+ - candidates (torch.Tensor): Candidate token sequences.
351
+ - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
352
+ - posterior_threshold (float): Threshold for posterior probability.
353
+ - posterior_alpha (float): Scaling factor for the threshold.
354
+
355
+ Returns:
356
+ - best_candidate (torch.Tensor): Index of the chosen best candidate.
357
+ - accept_length (int): Length of the accepted candidate sequence.
358
+ """
359
+ # Greedy decoding based on temperature value
360
+ if logits_processor is None:
361
+ # Find the tokens that match the maximum logits for each position in the sequence
362
+ posterior_mask = (
363
+ candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1)
364
+ ).int()
365
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
366
+ accept_length = candidates_accept_length.max()
367
+ # Choose the best candidate
368
+ if accept_length == 0:
369
+ # Default to the first candidate if none are accepted
370
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
371
+ else:
372
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
373
+ return best_candidate, accept_length, logits[best_candidate, accept_length]
374
+
375
+ else:
376
+ accept_length = 1
377
+ accept_cand = candidates[0][:1]
378
+ best_candidate = 0
379
+ for i in range(1, candidates.shape[1]):
380
+ if i != accept_length:
381
+ break
382
+ adjustflag = False
383
+ is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
384
+ fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
385
+ gt_logits = logits[fi, i - 1][None]
386
+ gt_logits = logits_processor(None, gt_logits)[0]
387
+ gtp = torch.softmax(gt_logits, dim=0)
388
+ candidates_set = []
389
+ for j in range(candidates.shape[0]):
390
+ if is_eq[j]:
391
+ x = candidates[j, i]
392
+ xi = x.item()
393
+ if xi in candidates_set or xi == -1:
394
+ continue
395
+ candidates_set.append(xi)
396
+ r = random.random()
397
+ px = gtp[xi]
398
+ qx = 1.0
399
+ acp = px / qx
400
+ if r <= acp:
401
+ accept_cand = torch.cat((accept_cand, x[None]), dim=0)
402
+ accept_length += 1
403
+ best_candidate = j
404
+ break
405
+ else:
406
+ gtp[xi] = 0
407
+ gtp = gtp / gtp.sum()
408
+ adjustflag = True
409
+ if adjustflag and accept_length != candidates.shape[1]:
410
+ sample_p = gtp
411
+ else:
412
+ gt_logits = logits[best_candidate, accept_length - 1][None]
413
+ gt_logits = logits_processor(None, gt_logits)[0]
414
+ sample_p = torch.softmax(gt_logits, dim=0)
415
+ return torch.tensor(best_candidate), accept_length - 1, sample_p
416
+
417
+
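
Note: in the greedy branch above (`logits_processor is None`), a drafted path is accepted up to the first position where the base model's argmax disagrees with the draft; the `cumprod` turns the per-position match mask into exactly that prefix length. The toy example below reproduces the same computation with invented values (it is not part of the uploaded file).

```python
# Greedy acceptance-length sketch (toy values, illustrative only).
import torch

vocab = 5
# Two candidate paths of 4 tokens each; token 0 is the already-sampled root token.
candidates = torch.tensor([
    [2, 4, 1, 3],
    [2, 4, 0, 3],
])
# Base-model logits along each path; the argmax at position t verifies token t+1.
logits = torch.zeros(2, 4, vocab)
logits[0, 0, 4] = 1.0   # agrees with candidates[0, 1] == 4
logits[0, 1, 1] = 1.0   # agrees with candidates[0, 2] == 1
logits[0, 2, 2] = 1.0   # disagrees with candidates[0, 3] == 3 -> path 0 stops here
logits[1, 0, 4] = 1.0   # agrees with candidates[1, 1] == 4
logits[1, 1, 3] = 1.0   # disagrees with candidates[1, 2] == 0

posterior_mask = (candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)).int()
accept_lengths = torch.cumprod(posterior_mask, dim=1).sum(dim=1)   # tensor([2, 1])
accept_length = accept_lengths.max()                               # 2 draft tokens accepted
best_candidate = torch.argmax(accept_lengths)                      # path 0 wins
print(best_candidate.item(), accept_length.item())
```
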
418
+ @torch.no_grad()
419
+ def update_inference_inputs(
420
+ input_ids,
421
+ candidates,
422
+ best_candidate,
423
+ accept_length,
424
+ retrieve_indices,
425
+ logits_processor,
426
+ new_token,
427
+ past_key_values_data_list,
428
+ current_length_data,
429
+ model,
430
+ hidden_state_new,
431
+ sample_p
432
+ ):
433
+ prev_input_len = input_ids.shape[1]
434
+ # Map the best candidate indices to the original indices in the sequence
435
+ select_indices = (
436
+ retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
437
+ )
438
+ # Append the tokens from the best candidate to the input sequence
439
+ input_ids = torch.cat(
440
+ [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
441
+ )
442
+ # Update the past key values based on the selected tokens
443
+ # Source tensor that contains relevant past information based on the selected candidate
444
+ for past_key_values_data in past_key_values_data_list:
445
+ tgt = past_key_values_data[..., select_indices.to(past_key_values_data.device), :]
446
+ # Destination tensor where the relevant past information will be stored
447
+ dst = past_key_values_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :]
448
+ # Copy relevant past information from the source to the destination
449
+ dst.copy_(tgt, non_blocking=True)
450
+
451
+ # Update the current length tensor (currently only support batch size is 1)
452
+ current_length_data.fill_(prev_input_len + tgt.shape[-2])
453
+
454
+ retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
455
+ accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
456
+ # token=model.base_model.lm_head(accept_hidden_state_new[:,-1]).argmax()
457
+ # token=token[None,None]
458
+ prob = sample_p
459
+ if logits_processor is not None:
460
+ token = torch.multinomial(prob, 1)
461
+ token = token[None]
462
+ else:
463
+ token = torch.argmax(prob)
464
+ token = token[None, None]
465
+ # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
466
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
467
+ input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
468
+ head=model.base_model.lm_head,logits_processor=logits_processor)
469
+
470
+
471
+ new_token += accept_length + 1
472
+
473
+ return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token
474
+
475
+
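
Note: `update_inference_inputs` keeps the preallocated KV cache consistent after verification: the cache rows written for the scattered tree slots of the accepted path are gathered via `select_indices` and copied forward so the cache stays contiguous, then `current_length_data` is reset to the new end. A toy gather-and-copy sketch follows; shapes, values, and names are illustrative assumptions, not the uploaded code.

```python
# KV-cache compaction sketch after accepting a verified draft path (toy tensors).
import torch

prev_len = 3                                                    # tokens already committed
kv = torch.arange(10, dtype=torch.float32).view(1, 1, 10, 1)    # (batch, heads, max_len, head_dim)
# Suppose the verified path occupied scattered tree slots 3, 5 and 6 after the prefix.
select_indices = torch.tensor([3, 5, 6])

tgt = kv[..., select_indices, :]                      # gather the accepted cache rows (a copy)
kv[..., prev_len:prev_len + tgt.shape[-2], :] = tgt   # write them back contiguously (dst.copy_(tgt))
current_length = prev_len + tgt.shape[-2]             # the cache now holds 6 valid rows
print(kv.view(-1)[:current_length].tolist(), current_length)
```
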
476
+ if __name__ == "__main__":
477
+ logits = torch.randn(1, 5)
478
+ tp = prepare_logits_processor(0.9, 0, 0.9, 0)
479
+ l = tp(None, logits)
480
+ if tp is None:
481
+ print(tp)
eagle/model/utils_c.py ADDED
@@ -0,0 +1,206 @@
+ import torch
+
+ # typing
+ from typing import List
+
+ TOPK = 10  # topk for sparse tree
+
+
+ def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]:
+     """
+     Pad the given path list with a specific value up to a specified length.
+
+     Parameters:
+     - path (list): The original list that needs padding.
+     - length (int): The desired length of the padded list.
+     - pad_value (optional, default=-2): The value to use for padding.
+
+     Returns:
+     - list: A new list based on the original path but padded to the desired length.
+
+     Example:
+     >>> pad_path([1, 2, 3], 5)
+     [1, 2, 3, -2, -2]
+
+     Note:
+     If the given path is already longer than the specified length,
+     then no padding occurs, and the original path is returned.
+     """
+
+     # Calculate the number of padding values needed by subtracting the length
+     # of the path from the desired length.
+     # Append the padding values to the original path and return the new list.
+     return path + [pad_value] * (length - len(path))
+
+
+ class node:
+     def __init__(self, parent=None, value=None, dict_key=None):
+         self.parent = parent
+         self.value = value
+         if parent:
+             self.depth = parent.depth + 1
+             parent.children.append(self)
+         else:
+             self.depth = 0
+         self.children = []
+         self.dict_key = dict_key
+
+     def is_leaf(self):
+         return len(self.children) == 0
+
+     def all_index(self):
+         if not self.parent.parent:
+             return [self.index]
+         else:
+             return self.parent.all_index() + [self.index]
+
+
+ class Tree:
+     def __init__(self, tree_list):
+         sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x))
+         self.root = node()
+         self.node_dic = {}
+         for tree_node in sorted_tree_list:
+             cur_value = tree_node[-1]
+             if len(tree_node) == 1:
+                 cur_node = node(parent=self.root, value=cur_value, dict_key=tuple(tree_node))
+             else:
+                 cur_parent = self.node_dic[tuple(tree_node[:-1])]
+                 cur_node = node(parent=cur_parent, value=cur_value, dict_key=tuple(tree_node))
+             self.node_dic[tuple(tree_node)] = cur_node
+         self.indexnode()
+
+     def max_depth(self):
+         return max([item.depth for item in self.node_dic.values()])
+
+     def num_node_wchild(self):
+         num_c = 0
+         for item in self.node_dic.values():
+             if not item.is_leaf():
+                 num_c += 1
+         return num_c
+
+     def get_node_wchild(self):
+         ns = []
+         for item in self.node_dic.values():
+             if not item.is_leaf():
+                 ns.append(item)
+         return ns
+
+     def indexnode(self):
+         cur_index = 0
+         for key in self.node_dic:
+             cur_node = self.node_dic[key]
+             if not cur_node.is_leaf():
+                 cur_node.index = cur_index
+                 cur_index += 1
+
+
+ def generate_tree_buffers(tree_choices, device="cuda"):
+     tree = Tree(tree_choices)
+     sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
+     tree_len = tree.num_node_wchild()
+
+     max_depth = tree.max_depth()
+     nodes_wc = tree.get_node_wchild()
+
+     depth_counts = [0 for _ in range(max_depth - 1)]
+     for x in nodes_wc:
+         depth_counts[x.depth - 1] += 1
+     depth_counts_sum = [sum(depth_counts[:i + 1]) for i in range(len(depth_counts))]
+
+     tree_attn_mask = torch.eye(tree_len, tree_len)
+
+     for id, x in enumerate(nodes_wc):
+         tree_attn_mask[id, x.all_index()] = 1
+
+     tree_attn_mask_list0 = [tree_attn_mask[:ml, :ml] for ml in depth_counts_sum]
+     tree_attn_mask_list = []
+     for id, x in enumerate(tree_attn_mask_list0):
+         x = x[-depth_counts[id]:]
+         tree_attn_mask_list.append(x)
+
+     tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]
+     repeat_nums = [[] for _ in depth_counts]
+     start = 0
+     bias = 0
+     for i in range(len(depth_counts)):
+         bias = 0
+         repeat_j = 0
+         for j in range(depth_counts[i]):
+             cur_node = nodes_wc[start + j]
+             cur_parent = cur_node.parent
+
+             if j != 0:
+                 if cur_parent != parent:
+                     bias += 1
+                     parent = cur_parent
+                     repeat_nums[i].append(j - repeat_j)
+                     repeat_j = j
+             else:
+                 parent = cur_parent
+             tree_indices_list[i][j] = cur_node.value + TOPK * (bias)
+         repeat_nums[i].append(j - repeat_j + 1)
+         start += depth_counts[i]
+
+     position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]
+
+     # start = 0
+     # for i in range(len(depth_counts)):
+     #     position_ids[start: start + depth_counts[i]] = i
+     #     start += depth_counts[i]
+
+     tree_buffers = {
+         "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list],
+         "tree_indices": tree_indices_list,
+         "position_ids": position_ids,
+         "repeat_nums": repeat_nums,
+     }
+
+     # Move the tensors in the dictionary to the specified device
+     tree_buffers = {
+         k: [i.clone().to(device) for i in v]
+         if isinstance(v[0], torch.Tensor)
+         else (
+             torch.tensor(v, device=device)
+             if isinstance(v, torch.Tensor)
+             else v
+         )
+         for k, v in tree_buffers.items()
+     }
+     return tree_buffers
+
+
+ def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]:
+     """
+     Resets the current lengths in the passed key-values to zero.
+
+     This function is designed to be used during the evaluation of a baseline model.
+     It iterates through each layer's key-values and sets their current lengths to zero,
+     effectively resetting their state.
+
+     Args:
+     - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
+
+     Returns:
+     - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
+     """
+     for i in range(len(passed_key_values)):
+         for j in range(2):
+             passed_key_values[i][j].current_length.fill_(0)
+     return passed_key_values
+
+
+ if __name__ == "__main__":
+     from choices import mc_sim_7b_63
+
+     a = generate_tree_buffers(mc_sim_7b_63)
+     print(a)
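
As a rough usage sketch (not part of the committed files): `generate_tree_buffers` consumes a list of root-to-node paths over the drafter's top-k choices and returns per-depth attention masks, token indices, position ids, and repeat counts. The tiny tree below is a made-up example rather than the real `mc_sim_7b_63` configuration, and the import path assumes the Space's repository root is on `sys.path`.

```python
import torch
from eagle.model.utils_c import generate_tree_buffers

# A tiny, hypothetical draft tree: each tuple is a root-to-node path of top-k
# indices, e.g. (0, 1) means "the 2nd-ranked child of the top-ranked first draft token".
tiny_tree = [(0,), (1,), (0, 0), (0, 1), (1, 0), (0, 0, 0)]

buffers = generate_tree_buffers(tiny_tree, device="cpu")
for name, per_depth in buffers.items():
    # attn_mask / tree_indices / position_ids are per-depth lists of tensors;
    # repeat_nums stays a plain list of lists.
    print(name, [v.shape if hasattr(v, "shape") else v for v in per_depth])
```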
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio
+ git+https://github.com/huggingface/transformers
+ torch
+ spaces
+ accelerate
+ tokenizers
+ numpy
+ Pillow
+ requests
+ sentencepiece
+ flash-attn
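
Optional sanity check (not part of the Space): the snippet below verifies that each listed dependency can be imported, bearing in mind that the importable module name differs from the pip name for Pillow (`PIL`) and flash-attn (`flash_attn`).

```python
import importlib

# Importable module names corresponding to the packages in requirements.txt.
modules = ["gradio", "transformers", "torch", "spaces", "accelerate",
           "tokenizers", "numpy", "PIL", "requests", "sentencepiece", "flash_attn"]

for module in modules:
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: missing ({exc})")
```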
utils_chatbot.py ADDED
@@ -0,0 +1,63 @@
+
+ def organize_messages(message, history):
+     msg_ls = [dict(
+         role="system",
+         content="You are a helpful assistant.",
+     )]
+     for user, assistant in history:
+         msg_ls.append(dict(
+             role="user",
+             content=user,
+         ))
+         if assistant:
+             msg_ls.append(dict(
+                 role="assistant",
+                 content=assistant,
+             ))
+     msg_ls.append(dict(
+         role="user",
+         content=message,
+     ))
+     return msg_ls
+
+
+ def stream2display_text(stream_text, token_per_sec):
+     if stream_text.startswith("think>"):
+         stream_text = f"<{stream_text}"
+
+     if not stream_text.startswith("<think>"):
+         return stream_text
+
+     if "</think>" not in stream_text:
+         think_text, result_text = stream_text.replace("<think>", ""), ""
+     else:
+         think_text, result_text = stream_text.split("</think>")
+         think_text = think_text.replace("<think>", "")
+
+     result_text = result_text.replace("<|im_end|>", "")
+
+     think_block = "\n".join(f"> {line}" if line else ">" for line in think_text.rstrip().splitlines())
+     # display_text = f"{think_block}\n\n{result_text}"
+
+     display_text_ls = [think_block]
+     if result_text:
+         display_text_ls.append(f"{result_text}")
+     display_text_ls.append(f"```{token_per_sec:.2f} token/s```")
+
+     display_text = "\n\n".join(display_text_ls)
+
+     return display_text
+
+
+ def mtp_new_tokens(pred_ids, gen_tk_count, existing_tk_count, stop_token_ids):
+     output_ids = pred_ids[0][existing_tk_count:]
+
+     if stop_token_ids:
+         stop_token_ids_index = [
+             i
+             for i, tok_id in enumerate(output_ids)
+             if tok_id in stop_token_ids
+         ]
+         if len(stop_token_ids_index) > 0:
+             output_ids = output_ids[: stop_token_ids_index[0]]
+     new_tokens = output_ids[gen_tk_count:]
+
+     return new_tokens, len(output_ids)
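
Taken together, the three helpers above are the glue between the token streamer and the Gradio chat UI: `organize_messages` rebuilds the chat-template message list from (user, assistant) history pairs, `stream2display_text` renders any `<think>…</think>` reasoning as a quoted block with a tokens-per-second footer, and `mtp_new_tokens` extracts only the newly generated ids while honoring stop tokens. A small usage sketch, assuming `utils_chatbot.py` is importable from the working directory; all values below are illustrative.

```python
from utils_chatbot import organize_messages, stream2display_text, mtp_new_tokens

history = [("Hi", "Hello! How can I help?")]           # (user, assistant) pairs
messages = organize_messages("Explain speculative decoding briefly.", history)
print([m["role"] for m in messages])                   # ['system', 'user', 'assistant', 'user']

# Mid-stream: no </think> yet, so only the quoted think block plus the speed footer shows.
partial = "<think>The user wants a short explanation"
print(stream2display_text(partial, token_per_sec=123.4))

# Finished stream: the answer follows the quoted think block.
full = "<think>Short plan.</think>Draft several tokens, then verify them in one pass.<|im_end|>"
print(stream2display_text(full, token_per_sec=123.4))

# mtp_new_tokens trims the prompt, cuts at the first stop token, and returns only the delta.
pred_ids = [[101, 102, 7, 8, 9, 2, 99]]                # batch of one; 2 acts as the stop id here
new_tokens, seen = mtp_new_tokens(pred_ids, gen_tk_count=1, existing_tk_count=2, stop_token_ids=[2])
print(new_tokens, seen)                                 # [8, 9] 3
```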