yuhuili committed
Commit
8362bbb
1 Parent(s): 24fa4e6
app.py CHANGED
@@ -1,67 +1,327 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- from transformers import LlamaForCausalLM
5
- model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
6
- print(model)
7
-
8
- """
9
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
10
- """
11
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
12
-
13
-
14
- def respond(
15
- message,
16
- history: list[tuple[str, str]],
17
- system_message,
18
- max_tokens,
19
- temperature,
20
- top_p,
21
- ):
22
- messages = [{"role": "system", "content": system_message}]
23
-
24
- for val in history:
25
- if val[0]:
26
- messages.append({"role": "user", "content": val[0]})
27
- if val[1]:
28
- messages.append({"role": "assistant", "content": val[1]})
29
-
30
- messages.append({"role": "user", "content": message})
31
-
32
- response = ""
33
-
34
- for message in client.chat_completion(
35
- messages,
36
- max_tokens=max_tokens,
37
- stream=True,
38
- temperature=temperature,
39
- top_p=top_p,
40
- ):
41
- token = message.choices[0].delta.content
42
-
43
- response += token
44
- yield response
45
-
46
- """
47
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
48
- """
49
- demo = gr.ChatInterface(
50
- respond,
51
- additional_inputs=[
52
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
53
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
54
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
55
- gr.Slider(
56
- minimum=0.1,
57
- maximum=1.0,
58
- value=0.95,
59
- step=0.05,
60
- label="Top-p (nucleus sampling)",
61
- ),
62
- ],
63
  )
64
65
 
66
- if __name__ == "__main__":
67
- demo.launch()
1
+ import os
2
+ import time
3
+
4
  import gradio as gr
5
+ import argparse
6
+ try:
7
+ from ..model.ea_model import EaModel
8
+ except ImportError:
9
+ from eagle.model.ea_model import EaModel
10
+ import torch
11
+ from fastchat.model import get_conversation_template
12
+ import re
13
+
14
+
15
+ def truncate_list(lst, num):
16
+ if num not in lst:
17
+ return lst
18
+
19
+
20
+ first_index = lst.index(num)
21
+
22
+
23
+ return lst[:first_index + 1]
24
+
25
+
26
+
27
+
28
+
29
+ def find_list_markers(text):
30
+
31
+ pattern = re.compile(r'(?m)(^\d+\.\s|\n)')
32
+ matches = pattern.finditer(text)
33
+
34
+
35
+ return [(match.start(), match.end()) for match in matches]
36
+
37
+
38
+ def checkin(pointer,start,marker):
39
+ for b,e in marker:
40
+ if b<=pointer<e:
41
+ return True
42
+ if b<=start<e:
43
+ return True
44
+ return False
45
+
46
+ def highlight_text(text, text_list,color="black"):
47
+
48
+ pointer = 0
49
+ result = ""
50
+ markers=find_list_markers(text)
51
+
52
+
53
+ for sub_text in text_list:
54
+
55
+ start = text.find(sub_text, pointer)
56
+ if start==-1:
57
+ continue
58
+ end = start + len(sub_text)
59
+
60
+
61
+ if checkin(pointer,start,markers):
62
+ result += text[pointer:start]
63
+ else:
64
+ result += f"<span style='color: {color};'>{text[pointer:start]}</span>"
65
+
66
+ result += sub_text
67
+
68
+ pointer = end
69
+
70
+ if pointer < len(text):
71
+ result += f"<span style='color: {color};'>{text[pointer:]}</span>"
72
+
73
+ return result
74
+
75
+
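
A quick illustration of what `highlight_text` above produces; a minimal sketch assuming the helpers `find_list_markers`, `checkin`, and `highlight_text` defined above are in scope, with an invented input string and color:

# Spans of `text` not covered by `text_list` are wrapped in a colored <span>;
# matched substrings (and list markers / newlines) are left unstyled.
print(highlight_text("Hello world", ["world"], color="orange"))
# -> <span style='color: orange;'>Hello </span>world
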
76
+ def warmup(model):
77
+ conv = get_conversation_template(args.model_type)
78
+
79
+ if args.model_type == "llama-2-chat":
80
+ sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
81
+ conv.system_message = sys_p
82
+ elif args.model_type == "mixtral":
83
+ conv = get_conversation_template("llama-2-chat")
84
+ conv.system_message = ''
85
+ conv.sep2 = "</s>"
86
+ conv.append_message(conv.roles[0], "Hello")
87
+ conv.append_message(conv.roles[1], None)
88
+ prompt = conv.get_prompt()
89
+ if args.model_type == "llama-2-chat":
90
+ prompt += " "
91
+ input_ids = model.tokenizer([prompt]).input_ids
92
+ input_ids = torch.as_tensor(input_ids).cuda()
93
+ for output_ids in model.ea_generate(input_ids):
94
+ ol=output_ids.shape[1]
95
+
96
+ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_state,):
97
+ if not history:
98
+ return history, "0.00 tokens/s", "0.00", session_state
99
+ pure_history = session_state.get("pure_history", [])
100
+ assert args.model_type in ("llama-2-chat", "vicuna", "mixtral", "llama-3-instruct")
101
+ conv = get_conversation_template(args.model_type)
102
+
103
+ if args.model_type == "llama-2-chat":
104
+ sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
105
+ conv.system_message = sys_p
106
+ elif args.model_type == "mixtral":
107
+ conv = get_conversation_template("llama-2-chat")
108
+ conv.system_message = ''
109
+ conv.sep2 = "</s>"
110
+ elif args.model_type == "llama-3-instruct":
111
+ messages = [
112
+ {"role": "system",
113
+ "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
114
+ ]
115
+
116
+ for query, response in pure_history:
117
+ if args.model_type == "llama-3-instruct":
118
+ messages.append({
119
+ "role": "user",
120
+ "content": query
121
+ })
122
+ if response is not None:
123
+ messages.append({
124
+ "role": "assistant",
125
+ "content": response
126
+ })
127
+ else:
128
+ conv.append_message(conv.roles[0], query)
129
+ if args.model_type == "llama-2-chat" and response:
130
+ response = " " + response
131
+ conv.append_message(conv.roles[1], response)
132
+
133
+ if args.model_type == "llama-3-instruct":
134
+ prompt = model.tokenizer.apply_chat_template(
135
+ messages,
136
+ tokenize=False,
137
+ add_generation_prompt=True,
138
+ )
139
+ else:
140
+ prompt = conv.get_prompt()
141
+
142
+ if args.model_type == "llama-2-chat":
143
+ prompt += " "
144
+
145
+ input_ids = model.tokenizer([prompt]).input_ids
146
+ input_ids = torch.as_tensor(input_ids).cuda()
147
+ input_len = input_ids.shape[1]
148
+ naive_text = []
149
+ cu_len = input_len
150
+ totaltime=0
151
+ start_time=time.time()
152
+ total_ids=0
153
+ if use_EaInfer:
154
+
155
+ for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p,
156
+ max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
157
+ totaltime+=(time.time()-start_time)
158
+ total_ids+=1
159
+ decode_ids = output_ids[0, input_len:].tolist()
160
+ decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id)
161
+ if args.model_type == "llama-3-instruct":
162
+ decode_ids = truncate_list(decode_ids, model.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
163
+ text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False,
164
+ clean_up_tokenization_spaces=True, )
165
+
166
+ naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True,
167
+ spaces_between_special_tokens=False,
168
+ clean_up_tokenization_spaces=True, ))
169
+
170
+ cu_len = output_ids.shape[1]
171
+ colored_text = highlight_text(text, naive_text, "orange")
172
+ if highlight_EaInfer:
173
+ history[-1][1] = colored_text
174
+ else:
175
+ history[-1][1] = text
176
+ pure_history[-1][1] = text
177
+ session_state["pure_history"] = pure_history
178
+ new_tokens = cu_len-input_len
179
+ yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state
180
+ start_time = time.time()
181
+
182
+
183
+ else:
184
+ for output_ids in model.naive_generate(input_ids, temperature=temperature, top_p=top_p,
185
+ max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
186
+ totaltime += (time.time() - start_time)
187
+ total_ids+=1
188
+ decode_ids = output_ids[0, input_len:].tolist()
189
+ decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id)
190
+ text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False,
191
+ clean_up_tokenization_spaces=True, )
192
+ naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True,
193
+ spaces_between_special_tokens=False,
194
+ clean_up_tokenization_spaces=True, ))
195
+ cu_len = output_ids.shape[1]
196
+ colored_text = highlight_text(text, naive_text, "orange")
197
+ if highlight_EaInfer and use_EaInfer:
198
+ history[-1][1] = colored_text
199
+ else:
200
+ history[-1][1] = text
201
+ history[-1][1] = text
202
+ pure_history[-1][1] = text
203
+ new_tokens = cu_len - input_len
204
+ yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state
205
+ start_time = time.time()
206
+
207
+
208
+ def user(user_message, history,session_state):
209
+ if history is None:
210
+ history=[]
211
+ pure_history = session_state.get("pure_history", [])
212
+ pure_history += [[user_message, None]]
213
+ session_state["pure_history"] = pure_history
214
+ return "", history + [[user_message, None]],session_state
215
+
216
+
217
+ def regenerate(history,session_state):
218
+ if not history:
219
+ return history, None,"0.00 tokens/s","0.00",session_state
220
+ pure_history = session_state.get("pure_history", [])
221
+ pure_history[-1][-1] = None
222
+ session_state["pure_history"]=pure_history
223
+ if len(history) > 1: # Check if there's more than one entry in history (i.e., at least one bot response)
224
+ new_history = history[:-1] # Remove the last bot response
225
+ last_user_message = history[-1][0] # Get the last user message
226
+ return new_history + [[last_user_message, None]], None,"0.00 tokens/s","0.00",session_state
227
+ history[-1][1] = None
228
+ return history, None,"0.00 tokens/s","0.00",session_state
229
+
230
+
231
+ def clear(history,session_state):
232
+ pure_history = session_state.get("pure_history", [])
233
+ pure_history = []
234
+ session_state["pure_history"] = pure_history
235
+ return [],"0.00 tokens/s","0.00",session_state
236
+
237
+
238
+
239
+
240
+ parser = argparse.ArgumentParser()
241
+ parser.add_argument(
242
+ "--ea-model-path",
243
+ type=str,
244
+ default="yuhuili/EAGLE-LLaMA3-Instruct-8B",
245
+ help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
246
+ )
247
+ parser.add_argument("--base-model-path", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",
248
+ help="path of basemodel, huggingface project or local path")
249
+ parser.add_argument(
250
+ "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
251
+ )
252
+ parser.add_argument(
253
+ "--load-in-4bit", action="store_true", help="Use 4-bit quantization"
254
+ )
255
+ parser.add_argument("--model-type", type=str, default="llama-3-instruct",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct"])
256
+ parser.add_argument(
257
+ "--total-token",
258
+ type=int,
259
+ default=59,
260
+ help="The maximum number of new generated tokens.",
261
+ )
262
+ parser.add_argument(
263
+ "--max-new-token",
264
+ type=int,
265
+ default=512,
266
+ help="The maximum number of new generated tokens.",
267
  )
268
+ args = parser.parse_args()
269
+
270
+ model = EaModel.from_pretrained(
271
+ base_model_path=args.base_model_path,
272
+ ea_model_path=args.ea_model_path,
273
+ total_token=args.total_token,
274
+ torch_dtype=torch.float16,
275
+ low_cpu_mem_usage=True,
276
+ load_in_4bit=args.load_in_4bit,
277
+ load_in_8bit=args.load_in_8bit,
278
+ device_map="auto",
279
+ )
280
+ model.eval()
281
+ warmup(model)
282
+
283
+ custom_css = """
284
+ #speed textarea {
285
+ color: red;
286
+ font-size: 30px;
287
+ }"""
288
+
289
+ with gr.Blocks(css=custom_css) as demo:
290
+ gs = gr.State({"pure_history": []})
291
+ gr.Markdown('''## EAGLE-2 Chatbot''')
292
+ with gr.Row():
293
+ speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s")
294
+ compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00")
295
+ with gr.Row():
296
+ with gr.Column():
297
+ use_EaInfer = gr.Checkbox(label="Use EAGLE-2", value=True)
298
+ highlight_EaInfer = gr.Checkbox(label="Highlight the tokens generated by EAGLE-2", value=True)
299
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5)
300
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="top_p", value=0.9)
301
+ note=gr.Markdown(show_label=False,value='''The original LLM is LLaMA3-Instruct 8B, running on a single RTX 3090. The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2
302
+ will be displayed in orange. Note: Checking this option may cause special formatting rendering issues in a few cases, especially when generating code''')
303
+
304
+
305
+ chatbot = gr.Chatbot(height=600,show_label=False)
306
+
307
 
308
+ msg = gr.Textbox(label="Your input")
309
+ with gr.Row():
310
+ send_button = gr.Button("Send")
311
+ stop_button = gr.Button("Stop")
312
+ regenerate_button = gr.Button("Regenerate")
313
+ clear_button = gr.Button("Clear")
314
+ enter_event=msg.submit(user, [msg, chatbot,gs], [msg, chatbot,gs], queue=True).then(
315
+ bot, [chatbot, temperature, top_p, use_EaInfer, highlight_EaInfer,gs], [chatbot,speed_box,compression_box,gs]
316
+ )
317
+ clear_button.click(clear, [chatbot,gs], [chatbot,speed_box,compression_box,gs], queue=True)
318
 
319
+ send_event=send_button.click(user, [msg, chatbot,gs], [msg, chatbot,gs],queue=True).then(
320
+ bot, [chatbot, temperature, top_p, use_EaInfer, highlight_EaInfer,gs], [chatbot,speed_box,compression_box,gs]
321
+ )
322
+ regenerate_event=regenerate_button.click(regenerate, [chatbot,gs], [chatbot, msg,speed_box,compression_box,gs],queue=True).then(
323
+ bot, [chatbot, temperature, top_p, use_EaInfer, highlight_EaInfer,gs], [chatbot,speed_box,compression_box,gs]
324
+ )
325
+ stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event])
326
+ demo.queue()
327
+ demo.launch(share=True)
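
For reference, a small worked example of the two readouts that `bot()` above streams into the "Speed" and "Compression Ratio" boxes; the numbers are invented, and each yielded step of `model.ea_generate` is counted as one forward pass of the original LLM, matching the definition in the note above:

new_tokens = 320        # cu_len - input_len: tokens produced after the prompt
total_ids = 100         # iterations of ea_generate, i.e. base-model forward passes
totaltime = 4.0         # accumulated wall-clock seconds across those iterations

speed = new_tokens / totaltime          # "Speed" box -> 80.00 tokens/s
compression = new_tokens / total_ids    # "Compression Ratio" box -> 3.20
print(f"{speed:.2f} tokens/s", f"{compression:.2f}")
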
model/__init__.py ADDED
File without changes
model/cnets.py ADDED
@@ -0,0 +1,923 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import copy
22
+ import os
23
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+
33
+ from transformers.activations import ACT2FN
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
35
+ SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
38
+ from transformers.utils import (
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ logging,
42
+ replace_return_docstrings,
43
+ )
44
+
45
+ try:
46
+ from .configs import EConfig
47
+ from .choices import *
48
+ except ImportError:
49
+ from configs import EConfig
50
+ from choices import *
51
+ from utils import prepare_logits_processor
52
+
53
+ import time
54
+
55
+
56
+ class Timer:
57
+ def __init__(self, name):
58
+ self.name = name
59
+
60
+ def __enter__(self):
61
+ torch.cuda.synchronize()
62
+ self.start = time.perf_counter()
63
+
64
+ def __exit__(self, exc_type, exc_value, traceback):
65
+ torch.cuda.synchronize()
66
+ elapsed = time.perf_counter() - self.start
67
+ print(f'{self.name} took {elapsed} seconds')
68
+
69
+
70
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
71
+ def _make_causal_mask(
72
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
73
+ ):
74
+ """
75
+ Make causal mask used for bi-directional self-attention.
76
+ """
77
+ bsz, tgt_len = input_ids_shape
78
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
79
+ mask_cond = torch.arange(mask.size(-1), device=device)
80
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
81
+ mask = mask.to(dtype)
82
+
83
+ if past_key_values_length > 0:
84
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
85
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
86
+
87
+
88
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
89
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
90
+ """
91
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
92
+ """
93
+ bsz, src_len = mask.size()
94
+ tgt_len = tgt_len if tgt_len is not None else src_len
95
+
96
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
97
+
98
+ inverted_mask = 1.0 - expanded_mask
99
+
100
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
101
+
102
+
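
A small shape check for the two mask helpers above, assuming they are in scope; it only exercises the documented behaviour (additive masks with 0 at visible positions and the dtype minimum at masked ones):

import torch

# Causal mask for 2 sequences of length 4 with 3 cached past positions:
causal = _make_causal_mask((2, 4), torch.float32, device=torch.device("cpu"),
                           past_key_values_length=3)
print(causal.shape)   # torch.Size([2, 1, 4, 7]); future positions hold torch.finfo(torch.float32).min

# Padding mask [bsz, src_len] expanded to [bsz, 1, tgt_len, src_len]:
pad = torch.tensor([[1.0, 1.0, 0.0]])
print(_expand_mask(pad, torch.float32, tgt_len=2).shape)   # torch.Size([1, 1, 2, 3])
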
103
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
104
+ """
105
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
106
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
107
+ """
108
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
109
+ if n_rep == 1:
110
+ return hidden_states
111
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
112
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
113
+
114
+
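
A one-line sanity check of the equivalence claimed in the `repeat_kv` docstring above (assumes the function is in scope):

import torch

x = torch.randn(2, 4, 7, 16)   # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
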
115
+ def rotate_half(x):
116
+ """Rotates half the hidden dims of the input."""
117
+ x1 = x[..., : x.shape[-1] // 2]
118
+ x2 = x[..., x.shape[-1] // 2:]
119
+ return torch.cat((-x2, x1), dim=-1)
120
+
121
+
122
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
123
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
124
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
125
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
126
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
127
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
128
+ q_embed = (q * cos) + (rotate_half(q) * sin)
129
+ k_embed = (k * cos) + (rotate_half(k) * sin)
130
+ return q_embed, k_embed
131
+
132
+
133
+ class LlamaRotaryEmbedding(torch.nn.Module):
134
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
135
+ super().__init__()
136
+
137
+ self.dim = dim
138
+ self.max_position_embeddings = max_position_embeddings
139
+ self.base = base
140
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
141
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
142
+
143
+ # Build here to make `torch.jit.trace` work.
144
+ self._set_cos_sin_cache(
145
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
146
+ )
147
+
148
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
149
+ self.max_seq_len_cached = seq_len
150
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
151
+
152
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
153
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
154
+ emb = torch.cat((freqs, freqs), dim=-1)
155
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
156
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
157
+
158
+ def forward(self, x, seq_len=None):
159
+ # x: [bs, num_attention_heads, seq_len, head_size]
160
+ if seq_len > self.max_seq_len_cached:
161
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
162
+
163
+ return (
164
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
165
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
166
+ )
167
+
168
+
169
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
170
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
+
172
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
173
+ self.scaling_factor = scaling_factor
174
+ super().__init__(dim, max_position_embeddings, base, device)
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
179
+ t = t / self.scaling_factor
180
+
181
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
182
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
183
+ emb = torch.cat((freqs, freqs), dim=-1)
184
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
185
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
186
+
187
+
188
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
189
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
190
+
191
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
192
+ self.scaling_factor = scaling_factor
193
+ super().__init__(dim, max_position_embeddings, base, device)
194
+
195
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
196
+ self.max_seq_len_cached = seq_len
197
+
198
+ if seq_len > self.max_position_embeddings:
199
+ base = self.base * (
200
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
201
+ ) ** (self.dim / (self.dim - 2))
202
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
203
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
204
+
205
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
206
+
207
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
208
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
209
+ emb = torch.cat((freqs, freqs), dim=-1)
210
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
211
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
212
+
213
+
214
+ class LlamaAttention(nn.Module):
215
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
216
+
217
+ def __init__(self, config):
218
+ super().__init__()
219
+ self.config = config
220
+ self.hidden_size = config.hidden_size
221
+ self.num_heads = config.num_attention_heads
222
+ self.head_dim = self.hidden_size // self.num_heads
223
+ self.num_key_value_heads = config.num_key_value_heads
224
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
225
+ self.max_position_embeddings = config.max_position_embeddings
226
+
227
+ if (self.head_dim * self.num_heads) != self.hidden_size:
228
+ raise ValueError(
229
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
230
+ f" and `num_heads`: {self.num_heads})."
231
+ )
232
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
233
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
234
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
235
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
236
+ self._init_rope()
237
+
238
+ def _init_rope(self):
239
+ if self.config.rope_scaling is None:
240
+ if hasattr(self.config, "rope_theta"):
241
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
242
+ max_position_embeddings=self.max_position_embeddings,
243
+ base=self.config.rope_theta)
244
+ else:
245
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim,
246
+ max_position_embeddings=self.max_position_embeddings)
247
+ else:
248
+ scaling_type = self.config.rope_scaling["type"]
249
+ scaling_factor = self.config.rope_scaling["factor"]
250
+ if scaling_type == "linear":
251
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
252
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
253
+ )
254
+ elif scaling_type == "dynamic":
255
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
256
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
257
+ )
258
+ else:
259
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
260
+
261
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
262
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ attention_mask: Optional[torch.Tensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
270
+ output_attentions: bool = False,
271
+ use_cache: bool = False,
272
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
273
+ bsz, q_len, _ = hidden_states.size()
274
+
275
+ if self.config.pretraining_tp > 1:
276
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
277
+ query_slices = self.q_proj.weight.split(
278
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
279
+ )
280
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
281
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
282
+
283
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
284
+ query_states = torch.cat(query_states, dim=-1)
285
+
286
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
287
+ key_states = torch.cat(key_states, dim=-1)
288
+
289
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
290
+ value_states = torch.cat(value_states, dim=-1)
291
+
292
+ else:
293
+ query_states = self.q_proj(hidden_states)
294
+ key_states = self.k_proj(hidden_states)
295
+ value_states = self.v_proj(hidden_states)
296
+
297
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
298
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
299
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
300
+
301
+ kv_seq_len = key_states.shape[-2]
302
+ if past_key_value is not None:
303
+ kv_seq_len += past_key_value[0].shape[-2]
304
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
305
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
306
+
307
+ if past_key_value is not None:
308
+ # reuse k, v, self_attention
309
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
310
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
311
+
312
+ past_key_value = (key_states, value_states) if use_cache else None
313
+
314
+ # repeat k/v heads if n_kv_heads < n_heads
315
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
316
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
317
+
318
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
319
+
320
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
321
+ raise ValueError(
322
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
323
+ f" {attn_weights.size()}"
324
+ )
325
+
326
+ if attention_mask is not None:
327
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
328
+ raise ValueError(
329
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
330
+ )
331
+ attn_weights = attn_weights + attention_mask
332
+
333
+ # upcast attention to fp32
334
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
335
+ attn_output = torch.matmul(attn_weights, value_states)
336
+
337
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
338
+ raise ValueError(
339
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
340
+ f" {attn_output.size()}"
341
+ )
342
+
343
+ attn_output = attn_output.transpose(1, 2).contiguous()
344
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
345
+
346
+ if self.config.pretraining_tp > 1:
347
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
348
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
349
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
350
+ else:
351
+ attn_output = self.o_proj(attn_output)
352
+
353
+ if not output_attentions:
354
+ attn_weights = None
355
+
356
+ return attn_output, attn_weights, past_key_value
357
+
358
+
359
+ class LlamaMLP(nn.Module):
360
+ def __init__(self, config):
361
+ super().__init__()
362
+ self.config = config
363
+ self.hidden_size = config.hidden_size
364
+ self.intermediate_size = config.intermediate_size
365
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
366
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
367
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
368
+ self.act_fn = ACT2FN[config.hidden_act]
369
+
370
+ def forward(self, x):
371
+ if self.config.pretraining_tp > 1:
372
+ slice = self.intermediate_size // self.config.pretraining_tp
373
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
374
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
375
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
376
+
377
+ gate_proj = torch.cat(
378
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
379
+ )
380
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
381
+
382
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
383
+ down_proj = [
384
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
385
+ ]
386
+ down_proj = sum(down_proj)
387
+ else:
388
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
389
+
390
+ return down_proj
391
+
392
+
393
+ class LlamaRMSNorm(nn.Module):
394
+ def __init__(self, hidden_size, eps=1e-6):
395
+ """
396
+ LlamaRMSNorm is equivalent to T5LayerNorm
397
+ """
398
+ super().__init__()
399
+ self.weight = nn.Parameter(torch.ones(hidden_size))
400
+ self.variance_epsilon = eps
401
+
402
+ def forward(self, hidden_states):
403
+ input_dtype = hidden_states.dtype
404
+ hidden_states = hidden_states.to(torch.float32)
405
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
406
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
407
+ return self.weight * hidden_states.to(input_dtype)
408
+
409
+
410
+ class LlamaDecoderLayer(nn.Module):
411
+ def __init__(self, config, index):
412
+ super().__init__()
413
+ self.hidden_size = config.hidden_size
414
+ self.self_attn = LlamaAttention(config=config)
415
+ self.mlp = LlamaMLP(config)
416
+ self.index = index
417
+ if self.index != 0:
418
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
419
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: Optional[torch.Tensor] = None,
425
+ position_ids: Optional[torch.LongTensor] = None,
426
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
427
+ output_attentions: Optional[bool] = False,
428
+ use_cache: Optional[bool] = False,
429
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
430
+ """
431
+ Args:
432
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
433
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
434
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
435
+ output_attentions (`bool`, *optional*):
436
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
437
+ returned tensors for more detail.
438
+ use_cache (`bool`, *optional*):
439
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
440
+ (see `past_key_values`).
441
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
442
+ """
443
+
444
+ residual = hidden_states
445
+
446
+ if self.index != 0:
447
+ hidden_states = self.input_layernorm(hidden_states)
448
+
449
+ # Self Attention
450
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
451
+ hidden_states=hidden_states,
452
+ attention_mask=attention_mask,
453
+ position_ids=position_ids,
454
+ past_key_value=past_key_value,
455
+ output_attentions=output_attentions,
456
+ use_cache=use_cache,
457
+ )
458
+ hidden_states = residual + hidden_states
459
+
460
+ # Fully Connected
461
+ residual = hidden_states
462
+ hidden_states = self.post_attention_layernorm(hidden_states)
463
+ hidden_states = self.mlp(hidden_states)
464
+ hidden_states = residual + hidden_states
465
+
466
+ outputs = (hidden_states,)
467
+
468
+ if output_attentions:
469
+ outputs += (self_attn_weights,)
470
+
471
+ if use_cache:
472
+ outputs += (present_key_value,)
473
+
474
+ return outputs
475
+
476
+
477
+ class I(nn.Module):
478
+ def __init__(self):
479
+ super().__init__()
480
+ self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))
481
+
482
+ def forward(self, x):
483
+ return x + self.dummy - self.dummy # (also tried x+self.dummy)
484
+
485
+
486
+ def len_list(x, n):
487
+ return [i for i in x if len(i) <= n]
488
+
489
+
490
+ class Model(nn.Module):
491
+ def __init__(self, config, load_emb=False, path=None, bias=True, total_tokens=63, depth=5, top_k=8, threshold=1.0):
492
+ super().__init__()
493
+
494
+ self.gradient_checkpointing = True
495
+ self.padding_idx = config.pad_token_id
496
+ self.vocab_size = config.vocab_size
497
+
498
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
499
+ if load_emb:
500
+ from safetensors import safe_open
501
+ import json
502
+ try:
503
+ with open(os.path.join(path, "model.safetensors.index.json"), "r") as f:
504
+ index_json = json.loads(f.read())
505
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
506
+ with safe_open(os.path.join(path, emb_path),
507
+ framework="pt",
508
+ device="cpu") as f:
509
+ tensor_slice = f.get_slice("model.embed_tokens.weight")
510
+ vocab_size, hidden_dim = tensor_slice.get_shape()
511
+ tensor = tensor_slice[:, :hidden_dim].float()
512
+ except Exception:
513
+ with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f:
514
+ index_json = json.loads(f.read())
515
+ emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
516
+ weights = torch.load(os.path.join(path, emb_path))
517
+ tensor = weights["model.embed_tokens.weight"].float()
518
+ self.embed_tokens.weight.data = tensor
519
+
520
+ self.top_k = top_k
521
+ self.total_tokens = total_tokens - 1
522
+ self.depth = depth
523
+ self.threshold = math.log(threshold)
524
+ # print("total_tokens",total_tokens)
525
+ # print("depth",depth)
526
+ # print("top_k",top_k)
527
+ # print("threshold",threshold)
528
+
529
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, index) for index in range(config.num_hidden_layers)])
530
+ self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias)
531
+ self.act = ACT2FN[config.hidden_act]
532
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
533
+ for param in self.embed_tokens.parameters():
534
+ param.requires_grad = False
535
+
536
+ def init_tree(self):
537
+ self.tree_mask_init = torch.eye(self.top_k, device=self.embed_tokens.weight.device)[None, None]
538
+ self.position_ids = torch.zeros(self.top_k, device=self.embed_tokens.weight.device, dtype=torch.long)
539
+ self.tree_mask_init = self.tree_mask_init.to(self.embed_tokens.weight.device)
540
+
541
+ def reset(self):
542
+ self.tree_mask = None
543
+
544
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
545
+ # create causal mask
546
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
547
+ combined_attention_mask = None
548
+ if input_shape[-1] > 1:
549
+ combined_attention_mask = _make_causal_mask(
550
+ input_shape,
551
+ # inputs_embeds.dtype,
552
+ torch.float32, # [MODIFIED] force to cast to float32
553
+ device=inputs_embeds.device,
554
+ past_key_values_length=past_key_values_length,
555
+ )
556
+
557
+ if attention_mask is not None:
558
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
559
+ expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to(
560
+ inputs_embeds.device
561
+ )
562
+ combined_attention_mask = (
563
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
564
+ )
565
+
566
+ # [MODIFIED] add tree mask
567
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
568
+ tree_mask = self.tree_mask
569
+ _, _, tree_shape0, tree_shape1 = tree_mask.shape
570
+ combined_attention_mask[:, :, -tree_shape0:, -tree_shape1:][
571
+ tree_mask == 0
572
+ ] = torch.finfo(torch.float32).min
573
+
574
+ return combined_attention_mask
575
+
576
+ def forward(
577
+ self,
578
+ hidden_states,
579
+ input_ids,
580
+ attention_mask: Optional[torch.Tensor] = None,
581
+ position_ids: Optional[torch.LongTensor] = None,
582
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
583
+ inputs_embeds: Optional[torch.FloatTensor] = None,
584
+ use_cache: Optional[bool] = None,
585
+ output_attentions: Optional[bool] = None,
586
+ output_hidden_states: Optional[bool] = None,
587
+ return_dict: Optional[bool] = None,
588
+ std=None
589
+ ):
590
+ batch_size, seq_length, _ = hidden_states.shape
591
+ seq_length_with_past = seq_length
592
+ past_key_values_length = 0
593
+
594
+ with torch.no_grad():
595
+ inputs_embeds = self.embed_tokens(input_ids)
596
+ # inputs_embeds = inputs_embeds.detach()
597
+
598
+ # if std is not None:
599
+ # noise = torch.randn(inputs_embeds.size(),device=inputs_embeds.device) * std
600
+ # inputs_embeds=inputs_embeds+noise
601
+
602
+ if past_key_values is not None:
603
+ past_key_values_length = past_key_values[0][0].shape[2]
604
+ seq_length_with_past = seq_length_with_past + past_key_values_length
605
+ if position_ids is None:
606
+ device = hidden_states.device if hidden_states is not None else inputs_embeds.device
607
+ position_ids = torch.arange(
608
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
609
+ )
610
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
611
+ else:
612
+ position_ids = position_ids.view(-1, seq_length).long()
613
+
614
+ if attention_mask is None:
615
+ attention_mask = torch.ones(
616
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
617
+ )
618
+ attention_mask = self._prepare_decoder_attention_mask(
619
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
620
+ )
621
+
622
+ # if self.gradient_checkpointing and self.training:
623
+ # if use_cache:
624
+ # use_cache = False
625
+
626
+ # hidden_states=self.act(self.fc(torch.cat((inputs_embeds,hidden_states),dim=-1)))
627
+ inputs_embeds = inputs_embeds.to(hidden_states.dtype)
628
+ hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
629
+
630
+ all_hidden_states = () if output_hidden_states else None
631
+ next_decoder_cache = () if use_cache else None
632
+
633
+ for idx, decoder_layer in enumerate(self.layers):
634
+ if output_hidden_states:
635
+ all_hidden_states += (hidden_states,)
636
+
637
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
638
+
639
+ if self.gradient_checkpointing and self.training:
640
+
641
+ def create_custom_forward(module):
642
+ def custom_forward(*inputs):
643
+ # None for past_key_value
644
+ return module(*inputs, past_key_value, output_attentions)
645
+
646
+ return custom_forward
647
+
648
+ layer_outputs = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(decoder_layer),
650
+ hidden_states,
651
+ attention_mask,
652
+ position_ids,
653
+ )
654
+ else:
655
+ layer_outputs = decoder_layer(
656
+ hidden_states,
657
+ attention_mask=attention_mask,
658
+ position_ids=position_ids,
659
+ past_key_value=past_key_value,
660
+ output_attentions=output_attentions,
661
+ use_cache=use_cache,
662
+ )
663
+
664
+ hidden_states = layer_outputs[0]
665
+
666
+ if use_cache:
667
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
668
+
669
+ if use_cache:
670
+ return hidden_states, next_decoder_cache
671
+
672
+ return hidden_states
673
+
674
+ def reset_kv(self):
675
+ self.stable_kv = None
676
+
677
+ @torch.no_grad()
678
+ def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
679
+
680
+ input_ids = input_ids.to(hidden_states.device)
681
+ total_tokens = self.total_tokens
682
+ depth = self.depth
683
+ top_k = self.top_k
684
+
685
+ sample_token = input_ids[:, -1]
686
+
687
+ scores_list = []
688
+ parents_list = []
689
+ ss_token = []
690
+
691
+ input_ids = input_ids[:, 1:]
692
+ input_ids = input_ids.to(hidden_states.device)
693
+
694
+ len_posi = input_ids.shape[1]
695
+ self.reset()
696
+
697
+ # with Timer("draft many"):
698
+ if hasattr(self, "stable_kv") and self.stable_kv is not None:
699
+ kv_len = self.stable_kv[0][0].shape[2]
700
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids[:, kv_len:],
701
+ past_key_values=self.stable_kv, use_cache=True)
702
+ else:
703
+ out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True)
704
+ self.stable_kv = past_key_values
705
+ last_hidden = out_hidden[:, -1]
706
+
707
+ last_headout = head(last_hidden)
708
+
709
+ last_p = self.logsoftmax(last_headout)
710
+ top = torch.topk(last_p, top_k, dim=-1)
711
+ topk_index, topk_p = top.indices, top.values
712
+ scores = topk_p[0]
713
+ scores_list.append(scores[None])
714
+ parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device))
715
+ ss_token.append(topk_index)
716
+ input_ids = topk_index
717
+ input_hidden = last_hidden[None].repeat(1, top_k, 1)
718
+ tree_mask = self.tree_mask_init
719
+ topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device)
720
+
721
+ # 4
722
+ for i in range(depth):
723
+ self.tree_mask = tree_mask
724
+ position_ids = len_posi + self.position_ids
725
+ # with Timer("draft one"):
726
+ out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values,
727
+ position_ids=position_ids, use_cache=True)
728
+ len_posi += 1
729
+
730
+ # with Timer("sort1"):
731
+ bias1 = top_k if i > 0 else 0
732
+ bias2 = max(0, i - 1)
733
+ bias = 1 + top_k ** 2 * bias2 + bias1
734
+ parents = (topk_cs_index + bias)
735
+ parents_list.append(parents)
736
+
737
+ last_headout = head(out_hidden[0])
738
+ last_p = self.logsoftmax(last_headout)
739
+
740
+ top = torch.topk(last_p, top_k, dim=-1)
741
+ topk_index, topk_p = top.indices, top.values
742
+
743
+ cu_scores = topk_p + scores[:, None]
744
+
745
+ topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1)
746
+ topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
747
+ scores = topk_cs_p
748
+
749
+ out_ids = topk_cs_index // top_k
750
+ input_hidden = out_hidden[:, out_ids]
751
+ # with Timer("2index"):
752
+ # in_ids = topk_cs_index % top_k
753
+ # input_ids = topk_index[out_ids, in_ids][None]
754
+ # with Timer("1index"):
755
+ input_ids = topk_index.view(-1)[topk_cs_index][None]
756
+ # print(input_ids.equal(input_ids0))
757
+
758
+ ss_token.append(topk_index)
759
+ scores_list.append(cu_scores)
760
+ tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3)
761
+
762
+ # if self.threshold < 0 and cu_scores.max() < self.threshold:
763
+ # break
764
+
765
+ # del parents_list,scores_list,ss_token
766
+ # return draft_tokens, mask_index,tree_mask,tree_position_ids
767
+
768
+ # with Timer("post"):
769
+
770
+ scores_list = torch.cat(scores_list, dim=0).view(-1)
771
+ ss_token_list = torch.cat(ss_token, dim=0).view(-1)
772
+ top_scores = torch.topk(scores_list, total_tokens, dim=-1)
773
+ top_scores_index = top_scores.indices
774
+ top_scores_index = torch.sort(top_scores_index).values
775
+
776
+ draft_tokens = ss_token_list[top_scores_index]
777
+ draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
778
+
779
+ draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long()
780
+ mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False)
781
+ # mask_index[(top_scores_index[mask_index]!=draft_parents - 1)]=-1
782
+ mask_index[draft_parents == 0] = -1
783
+ mask_index = mask_index + 1
784
+ mask_index_list = mask_index.tolist()
785
+ # with Timer("mask"):
786
+ tree_mask = torch.eye(total_tokens + 1).bool()
787
+ tree_mask[:, 0] = True
788
+ for i in range(total_tokens):
789
+ tree_mask[i + 1].add_(tree_mask[mask_index_list[i]])
790
+
791
+ # with Timer("mask1"):
792
+ # tree_mask0 = [[False for _ in range(total_tokens + 1)] for _ in range(total_tokens + 1)]
793
+ # tree_mask0[0][0] = True
794
+ # for i in range(total_tokens):
795
+ # #tree_mask0[i + 1][0]=True
796
+ # tree_mask0[i + 1][i + 1] = True
797
+ # p=mask_index_list[i]
798
+ # tree_mask0[i + 1][p] = True
799
+ # while p:
800
+ # p=mask_index_list[p-1]
801
+ # tree_mask0[i + 1][p] = True
802
+ # tree_mask0 = torch.tensor(tree_mask0, dtype=torch.bool)
803
+ #
804
+ # print(tree_mask0.equal(tree_mask))
805
+ tree_position_ids = torch.sum(tree_mask, dim=1) - 1
806
+
807
+ tree_mask = tree_mask.float()[None, None]
808
+ draft_tokens = draft_tokens[None]
809
+
810
+ del parents_list, scores_list, ss_token, ss_token_list, draft_parents
811
+
812
+ # with Timer("retrieve"):
813
+
814
+ max_depth = torch.max(tree_position_ids) + 1
815
+ noleaf_index = torch.unique(mask_index).tolist()
816
+ noleaf_num = len(noleaf_index) - 1
817
+ leaf_num = total_tokens - noleaf_num
818
+
819
+ retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1
820
+ retrieve_indices = retrieve_indices.tolist()
821
+
822
+ rid = 0
823
+ position_ids_list = tree_position_ids.tolist()
824
+
825
+ for i in range(total_tokens + 1):
826
+ if i not in noleaf_index:
827
+ cid = i
828
+ depth = position_ids_list[i]
829
+ for j in reversed(range(depth + 1)):
830
+ retrieve_indices[rid][j] = cid
831
+ cid = mask_index_list[cid - 1]
832
+ rid += 1
833
+
834
+ if logits_processor is not None:
835
+ maxitem = total_tokens + 5
836
+
837
+ def custom_sort(lst):
838
+ # sort_keys=[len(list)]
839
+ sort_keys = []
840
+ for i in range(len(lst)):
841
+ sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
842
+ return sort_keys
843
+
844
+ retrieve_indices = sorted(retrieve_indices, key=custom_sort)
845
+
846
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
847
+ del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid
848
+ tree_position_ids = tree_position_ids.to(hidden_states.device)
849
+
850
+ return draft_tokens, retrieve_indices, tree_mask, tree_position_ids
851
+
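
To make the tree-mask construction inside `topK_genrate` above easier to follow, here is the same ancestor-accumulation idea on a toy draft tree; the parent list is invented, row 0 stands for the sampled root token, and `parents[i]` plays the role of `mask_index_list[i]` (the row index of node i+1's parent):

import torch

parents = [0, 1, 1, 2]                  # toy mask_index_list for 4 draft nodes
n = len(parents)
tree_mask = torch.eye(n + 1).bool()
tree_mask[:, 0] = True                  # every node attends to the root token
for i in range(n):
    # same effect as tree_mask[i + 1].add_(...) above: row i+1 inherits its parent's row,
    # so each row ends up marking the full root-to-node path
    tree_mask[i + 1] |= tree_mask[parents[i]]
tree_position_ids = torch.sum(tree_mask, dim=1) - 1
print(tree_position_ids)                # tensor([0, 1, 2, 2, 3]): each node's depth
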
852
+ @torch.no_grad()
853
+ def acc(self, data, head, max_length=5):
854
+ hidden_states = data["hidden_states"]
855
+ input_ids = data["input_ids"]
856
+ # attention_mask=data["attention_mask"]
857
+ loss_mask = data["loss_mask"]
858
+ sample_mask = data["sample_mask"]
859
+ target = data["target"]
860
+ total = [0 for _ in range(max_length)]
861
+ correct = [0 for _ in range(max_length)]
862
+ bs, sl = hidden_states.shape[0], hidden_states.shape[1]
863
+ target_headout = head(target)
864
+ hidden_states_headout = head(hidden_states)
865
+
866
+ for i in range(bs):
867
+ for j in range(sl):
868
+ if loss_mask[i, j] == 0:
869
+ continue
870
+ single_hidden_states = hidden_states[i, :j]
871
+ single_input_ids = input_ids[i, :j]
872
+
873
+ single_hidden_states = single_hidden_states[None, :, :]
874
+ single_input_ids = single_input_ids[None, :]
875
+ for k in range(max_length):
876
+ tmp_in_target_headout = hidden_states_headout[i, single_hidden_states.shape[1] - 1]
877
+ tmp_out_target_headout = target_headout[i, single_hidden_states.shape[1] - 1]
878
+ target_in_token = torch.argmax(tmp_in_target_headout)
879
+ target_out_token = torch.argmax(tmp_out_target_headout)
880
+ tmp_token = input_ids[i, single_hidden_states.shape[1] - 1]
881
+ tmp_sample_mask = sample_mask[i, single_hidden_states.shape[1] - 1]
882
+ if not (target_in_token == tmp_token):
883
+ break
884
+ out_hidden = self(single_hidden_states, input_ids=single_input_ids)
885
+ last_hidden = out_hidden[:, -1]
886
+ last_headout = head(last_hidden)
887
+ token = torch.argmax(last_headout)
888
+ total[k] += 1
889
+ if token == target_out_token:
890
+ correct[k] += 1
891
+ else:
892
+ for kk in range(k, max_length):
893
+ total[kk] += 1
894
+ break
895
+
896
+ single_hidden_states = torch.cat((single_hidden_states, out_hidden[:, -1:]), dim=1)
897
+ single_input_ids = torch.cat(
898
+ (single_input_ids, torch.tensor([[token]]).to(single_input_ids.device)), dim=1)
899
+
900
+ acc = [correct[i] / total[i] for i in range(len(correct))]
901
+ return acc
902
+
903
+
904
+ class Vhead(nn.Module):
905
+ def __init__(self, ins=6566, outs=32000):
906
+ super().__init__()
907
+ self.fc = nn.Linear(ins, outs, bias=False)
908
+
909
+ def forward(self, x):
910
+ return self.fc(x)
911
+
912
+
913
+ import torch
914
+
915
+
916
+ def count_parameters(model):
917
+ return sum(p.numel() for p in model.parameters())
918
+
919
+
920
+ if __name__ == "__main__":
921
+ config = EConfig.from_pretrained('config.json')
922
+ model = Model(config, load_emb=True, path="/home/lyh/weights/hf/vicuna_v13/7B/")
923
+ print(model)
model/configs.py ADDED
@@ -0,0 +1,145 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ class EConfig(PretrainedConfig):
3
+ r"""
4
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
5
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
6
+ defaults will yield a similar configuration to that of the LLaMA-7B.
7
+
8
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
9
+ documentation from [`PretrainedConfig`] for more information.
10
+
11
+
12
+ Args:
13
+ vocab_size (`int`, *optional*, defaults to 32000):
14
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
15
+ `inputs_ids` passed when calling [`LlamaModel`]
16
+ hidden_size (`int`, *optional*, defaults to 4096):
17
+ Dimension of the hidden representations.
18
+ intermediate_size (`int`, *optional*, defaults to 11008):
19
+ Dimension of the MLP representations.
20
+ num_hidden_layers (`int`, *optional*, defaults to 32):
21
+ Number of hidden layers in the Transformer encoder.
22
+ num_attention_heads (`int`, *optional*, defaults to 32):
23
+ Number of attention heads for each attention layer in the Transformer encoder.
24
+ num_key_value_heads (`int`, *optional*):
25
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
26
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
27
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
28
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
29
+ by meanpooling all the original heads within that group. For more details checkout [this
30
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
31
+ `num_attention_heads`.
32
+ pretraining_tp (`int`, *optional*, defaults to `1`):
33
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
34
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
35
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
36
+ issue](https://github.com/pytorch/pytorch/issues/76232).
37
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
38
+ The non-linear activation function (function or string) in the decoder.
39
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
40
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
41
+ just in case (e.g., 512 or 1024 or 2048).
42
+ initializer_range (`float`, *optional*, defaults to 0.02):
43
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
45
+ The epsilon used by the rms normalization layers.
46
+ use_cache (`bool`, *optional*, defaults to `True`):
47
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
48
+ relevant if `config.is_decoder=True`.
49
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
50
+ Whether to tie weight embeddings
51
+ rope_scaling (`Dict`, *optional*):
52
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
53
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
54
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
55
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
56
+ these scaling strategies behave:
57
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
58
+ experimental feature, subject to breaking API changes in future versions.
59
+
60
+ Example:
61
+
62
+ ```python
63
+ >>> from transformers import LlamaModel, LlamaConfig
64
+
65
+ >>> # Initializing a LLaMA llama-7b style configuration
66
+ >>> configuration = LlamaConfig()
67
+
68
+ >>> # Initializing a model from the llama-7b style configuration
69
+ >>> model = LlamaModel(configuration)
70
+
71
+ >>> # Accessing the model configuration
72
+ >>> configuration = model.config
73
+ ```"""
74
+ model_type = "llama"
75
+ keys_to_ignore_at_inference = ["past_key_values"]
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_size=32000,
80
+ hidden_size=4096,
81
+ intermediate_size=11008,
82
+ num_hidden_layers=32,
83
+ num_attention_heads=32,
84
+ num_key_value_heads=None,
85
+ hidden_act="silu",
86
+ max_position_embeddings=2048,
87
+ initializer_range=0.02,
88
+ rms_norm_eps=1e-6,
89
+ use_cache=True,
90
+ pad_token_id=None,
91
+ bos_token_id=1,
92
+ eos_token_id=2,
93
+ pretraining_tp=1,
94
+ tie_word_embeddings=False,
95
+ rope_scaling=None,
96
+ **kwargs,
97
+ ):
98
+ self.vocab_size = vocab_size
99
+ self.max_position_embeddings = max_position_embeddings
100
+ self.hidden_size = hidden_size
101
+ self.intermediate_size = intermediate_size
102
+ self.num_hidden_layers = num_hidden_layers
103
+ self.num_attention_heads = num_attention_heads
104
+
105
+ # for backward compatibility
106
+ if num_key_value_heads is None:
107
+ num_key_value_heads = num_attention_heads
108
+
109
+ self.num_key_value_heads = num_key_value_heads
110
+ self.hidden_act = hidden_act
111
+ self.initializer_range = initializer_range
112
+ self.rms_norm_eps = rms_norm_eps
113
+ self.pretraining_tp = pretraining_tp
114
+ self.use_cache = use_cache
115
+ self.rope_scaling = rope_scaling
116
+ self._rope_scaling_validation()
117
+
118
+ super().__init__(
119
+ pad_token_id=pad_token_id,
120
+ bos_token_id=bos_token_id,
121
+ eos_token_id=eos_token_id,
122
+ tie_word_embeddings=tie_word_embeddings,
123
+ **kwargs,
124
+ )
125
+
126
+ def _rope_scaling_validation(self):
127
+ """
128
+ Validate the `rope_scaling` configuration.
129
+ """
130
+ if self.rope_scaling is None:
131
+ return
132
+
133
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
134
+ raise ValueError(
135
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
136
+ f"got {self.rope_scaling}"
137
+ )
138
+ rope_scaling_type = self.rope_scaling.get("type", None)
139
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
140
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
141
+ raise ValueError(
142
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
143
+ )
144
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
145
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
model/ea_model.py ADDED
@@ -0,0 +1,562 @@
+ import copy
2
+ import json
3
+ import time
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import PreTrainedModel, PretrainedConfig,AutoConfig
8
+ from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
9
+ from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
10
+ from .utils import *
11
+ from .kv_cache import initialize_past_key_values
12
+ from .choices import mc_sim_7b_63
13
+ from transformers import AutoTokenizer
14
+ import os
15
+ from huggingface_hub import hf_hub_download
16
+ from .cnets import Model
17
+ from .configs import EConfig
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+
22
+
23
+ class EaModel(nn.Module):
24
+
25
+ def __init__(
26
+ self,
27
+ base_model,
28
+ base_model_name_or_path,
29
+ ea_model_path,
30
+ total_token,
31
+ depth,
32
+ top_k,
33
+ threshold,
34
+ ea_layer_state_dict
35
+ ):
36
+
37
+ super().__init__()
38
+ self.base_model = base_model
39
+ self.config = base_model.config
40
+ self.hidden_size = base_model.lm_head.weight.shape[-1]
41
+ self.vocab_size = base_model.lm_head.weight.shape[0]
42
+ self.base_model_name_or_path = base_model_name_or_path
43
+ self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path,use_fast=False)
44
+ config = EConfig.from_pretrained(ea_model_path)
45
+ with open(ea_model_path,"r") as f:
46
+ con=json.loads(f.read())
47
+ try:
48
+ bias=con["bias"]
49
+ except:
50
+ bias=True
51
+ self.ea_layer = Model(config,bias=bias,total_tokens=total_token,depth=depth,top_k=top_k,threshold=threshold)
52
+
53
+ low_memory=False
54
+
55
+ device = base_model.model.layers[-1].self_attn.q_proj.weight.device
56
+ if device!=base_model.lm_head.weight.device:
57
+ self.ea_layer.diff_device = True
58
+ if not low_memory:
59
+ # self.ea_layer.head=nn.Linear(base_model.lm_head.in_features,base_model.lm_head.out_features,bias=False)
60
+ # self.ea_layer.head.weight=copy.deepcopy(base_model.lm_head.weight)
61
+ # self.ea_layer.head.to(device)
62
+ self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device)
63
+ else:
64
+ self.ea_layer.layer_device = device
65
+
66
+ else:
67
+ self.ea_layer.diff_device = False
68
+ self.ea_layer.load_state_dict(ea_layer_state_dict, strict=True)
69
+ self.ea_layer.to(self.base_model.dtype).to(device)
70
+ self.ea_layer.init_tree()
71
+
72
+ def get_tokenizer(self):
73
+ """Get the tokenizer of the base model.
74
+
75
+ Returns:
76
+ Tokenizer: The tokenizer of the base model.
77
+ """
78
+ return self.tokenizer
79
+
80
+ @classmethod
81
+ def from_pretrained(
82
+ cls,
83
+ Type="LLaMA",
84
+ base_model_path=None,
85
+ ea_model_path=None,
86
+ total_token=59,
87
+ depth=5,
88
+ top_k=10,
89
+ threshold=1.0,
90
+ **kwargs,
91
+ ):
92
+ #assert Type=="LLaMA" or "Mixtral"
93
+ Type=AutoConfig.from_pretrained(base_model_path).architectures[0]
94
+ if Type=='LlamaForCausalLM':
95
+ base_model = KVLlamaForCausalLM.from_pretrained(
96
+ base_model_path, **kwargs
97
+ )
98
+ else:
99
+ base_model = KVMixtralForCausalLM.from_pretrained(
100
+ base_model_path, **kwargs
101
+ )
102
+
103
+ configpath=os.path.join(ea_model_path,"config.json")
104
+ if not os.path.exists(configpath):
105
+ configpath = hf_hub_download(ea_model_path, "config.json")
106
+ load_model_path=os.path.join(ea_model_path, "pytorch_model.bin")
107
+ if not os.path.exists(load_model_path):
108
+ load_model_path=hf_hub_download(ea_model_path, "pytorch_model.bin")
109
+ ea_layer_state_dict = torch.load(load_model_path,
110
+ map_location="cpu")
111
+ model = cls(
112
+ base_model,
113
+ base_model_path,
114
+ configpath,
115
+ total_token,
116
+ depth,
117
+ top_k,
118
+ threshold,
119
+ ea_layer_state_dict
120
+ )
121
+
122
+
123
+
124
+ if total_token==-1:
125
+ device = model.base_model.model.layers[0].self_attn.q_proj.weight.device
126
+ cans=[40,48,50,56,60]
127
+ x=[1,1.05,1.07,1.1,1.13]
128
+ times=[]
129
+
130
+ for i in range(len(cans)):
131
+ length = cans[i]
132
+ input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device)
133
+ torch.cuda.synchronize()
134
+ start_time = time.time()
135
+ for _ in range(20):
136
+ torch.cuda.synchronize()
137
+ with torch.no_grad():
138
+ outputs = model.base_model(input_ids)
139
+ torch.cuda.synchronize()
140
+ torch.cuda.synchronize()
141
+ end_time = time.time()
142
+ times.append((end_time - start_time) / x[i])
143
+ total_token=cans[times.index(min(times))]
144
+ model.ea_layer.total_tokens=total_token-1
145
+
146
+
147
+
148
+
149
+ return model
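When `total_token=-1`, the block above benchmarks a few candidate draft-tree sizes and keeps the candidate whose measured forward time, discounted by the factor `x`, is smallest. A toy illustration of that selection rule (the timings below are invented):

```python
# Toy illustration of the total_token selection rule; the timings are invented.
cans = [40, 48, 50, 56, 60]
x = [1, 1.05, 1.07, 1.1, 1.13]
raw_times = [0.100, 0.103, 0.105, 0.109, 0.114]   # hypothetical measurements (s)

times = [raw_times[i] / x[i] for i in range(len(cans))]
total_token = cans[times.index(min(times))]
print(total_token)  # 48: a slightly slower but larger tree wins after discounting
```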
150
+
151
+ def forward(
152
+ self,
153
+ input_ids=None,
154
+ attention_mask=None,
155
+ past_key_values=None,
156
+ output_orig=False,
157
+ position_ids=None,
158
+ ):
159
+
160
+ with torch.inference_mode():
161
+ # Pass input through the base model
162
+ outputs = self.base_model.model(
163
+ input_ids=input_ids,
164
+ attention_mask=attention_mask,
165
+ past_key_values=past_key_values,
166
+ position_ids=position_ids,
167
+ )
168
+ if output_orig:
169
+ orig = self.base_model.lm_head(outputs[0])
170
+ hidden_states = outputs[0]
171
+ # if init:
172
+ # if logits_processor is not None:
173
+ # logits = orig[:, -1]
174
+ # logits = logits_processor(None, logits)
175
+ # probabilities = torch.nn.functional.softmax(logits, dim=1)
176
+ # token = torch.multinomial(probabilities, 1)
177
+ # else:
178
+ # token = torch.argmax(orig[:, -1])
179
+ # token = token[None, None]
180
+ # input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
181
+ # # Clone the output hidden states
182
+ #
183
+ # draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
184
+ # if output_orig:
185
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token
186
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token
187
+ # else:
188
+ if output_orig:
189
+ return outputs, orig, hidden_states
190
+ else:
191
+ return outputs, hidden_states
192
+
193
+ @torch.no_grad()
194
+ def eagenerate(
195
+ self,
196
+ input_ids,
197
+ temperature=0.0,
198
+ top_p=0.0,
199
+ top_k=0.0,
200
+ max_new_tokens=512,
201
+ max_length=2048,
202
+ log=False,
203
+ is_llama3=False,
204
+
205
+ ):
206
+ if is_llama3:
207
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
208
+ max_length=max_length-self.ea_layer.total_tokens-10
209
+
210
+ if temperature > 1e-5:
211
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
212
+ else:
213
+ logits_processor = None
214
+ #assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
215
+ # Avoid modifying the input_ids in-place
216
+
217
+ padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device)
218
+ input_ids = input_ids.clone()
219
+ self.ea_layer.reset_kv()
220
+
221
+
222
+
223
+ # Initialize the past key and value states
224
+ if hasattr(self, "past_key_values"):
225
+ past_key_values = self.past_key_values
226
+ past_key_values_data = self.past_key_values_data
227
+ current_length_data = self.current_length_data
228
+ # Reset the past key and value states
229
+ current_length_data.zero_()
230
+ else:
231
+ (
232
+ past_key_values,
233
+ past_key_values_data,
234
+ current_length_data,
235
+ ) = initialize_past_key_values(self.base_model)
236
+ self.past_key_values = past_key_values
237
+ self.past_key_values_data = past_key_values_data
238
+ self.current_length_data = current_length_data
239
+
240
+ input_len = input_ids.shape[1]
241
+ reset_tree_mode(self)
242
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
243
+ input_ids, self, past_key_values, logits_processor
244
+ )
245
+ new_token = 0
246
+
247
+ for idx in range(max_length):
248
+ #with Timer("all"):
249
+ self.base_model.model.tree_mask = tree_mask
250
+
251
+ draft_tokens=draft_tokens.to(input_ids.device)
252
+ #with Timer("tree_decoding"):
253
+ logits, hidden_state_new, outputs = tree_decoding(
254
+ self,
255
+ draft_tokens,
256
+ past_key_values,
257
+ tree_position_ids,
258
+ input_ids,
259
+ retrieve_indices,
260
+ )
261
+ #retrieve_indices=tree_buffers["retrieve_indices"]
262
+ #logits = logits[0, retrieve_indices]
263
+ draft_tokens=torch.cat((draft_tokens,padding),dim=1)
264
+ candidates=draft_tokens[0,retrieve_indices]
265
+ best_candidate, accept_length, sample_p = evaluate_posterior(
266
+ logits, candidates, logits_processor
267
+ )
268
+ # print(accept_length)
269
+ #with Timer("update_inference_inputs"):
270
+ input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
271
+ input_ids,
272
+ candidates,
273
+ best_candidate,
274
+ accept_length,
275
+ retrieve_indices,
276
+ logits_processor,
277
+ new_token,
278
+ past_key_values_data,
279
+ current_length_data,
280
+ self,
281
+ hidden_state_new,
282
+ sample_p
283
+ )
284
+
285
+ if is_llama3:
286
+ if stop_token_id in input_ids[0, input_len:].tolist():
287
+ break
288
+
289
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
290
+ break
291
+ if new_token > max_new_tokens:
292
+ break
293
+ if input_ids.shape[1] > max_length:
294
+ break
295
+ if not log:
296
+ return input_ids
297
+ else:
298
+ return input_ids, new_token, idx
299
+
300
+
301
+ @torch.no_grad()
302
+ def naivegenerate(
303
+ self,
304
+ input_ids,
305
+ temperature=0.0,
306
+ top_p=0.0,
307
+ top_k=0.0,
308
+ max_new_tokens=512,
309
+ max_length=2048,
310
+ log=False,
311
+ is_llama3=False,
312
+
313
+ ):
314
+ if is_llama3:
315
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
316
+ max_length = max_length - self.ea_layer.total_tokens - 10
317
+
318
+ if temperature > 1e-5:
319
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
320
+ else:
321
+ logits_processor = None
322
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
323
+ # Avoid modifying the input_ids in-place
324
+
325
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
326
+ input_ids = input_ids.clone()
327
+ self.ea_layer.reset_kv()
328
+
329
+
330
+
331
+ # Initialize the past key and value states
332
+ if hasattr(self, "past_key_values"):
333
+ past_key_values = self.past_key_values
334
+ past_key_values_data = self.past_key_values_data
335
+ current_length_data = self.current_length_data
336
+ # Reset the past key and value states
337
+ current_length_data.zero_()
338
+ else:
339
+ (
340
+ past_key_values,
341
+ past_key_values_data,
342
+ current_length_data,
343
+ ) = initialize_past_key_values(self.base_model)
344
+ self.past_key_values = past_key_values
345
+ self.past_key_values_data = past_key_values_data
346
+ self.current_length_data = current_length_data
347
+
348
+ input_len = input_ids.shape[1]
349
+ reset_tree_mode(self)
350
+ outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
351
+ new_token = 0
352
+
353
+ for idx in range(max_length):
354
+ if logits_processor is not None:
355
+ logits = outputs.logits[:, -1]
356
+ logits = logits_processor(None, logits)
357
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
358
+ input_id = torch.multinomial(probabilities, 1)
359
+ else:
360
+ input_id = outputs.logits[:, -1:].argmax(dim=-1)
361
+ outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
362
+ input_ids = torch.cat([input_ids, input_id], dim=-1)
363
+ new_token+=1
364
+
365
+ if is_llama3:
366
+ if stop_token_id in input_ids[0, input_len:].tolist():
367
+ break
368
+
369
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
370
+ break
371
+ if new_token > max_new_tokens:
372
+ break
373
+ if input_ids.shape[1] > max_length:
374
+ break
375
+ if not log:
376
+ return input_ids
377
+ else:
378
+ return input_ids, new_token, idx
379
+
380
+ @torch.no_grad()
381
+ def ea_generate(
382
+ self,
383
+ input_ids,
384
+ temperature=0.0,
385
+ top_p=0.0,
386
+ top_k=0.0,
387
+ max_new_tokens=512,
388
+ max_length=2048,
389
+ log=False,
390
+ is_llama3=False,
391
+
392
+ ):
393
+ if is_llama3:
394
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
395
+ max_length=max_length-self.ea_layer.total_tokens-10
396
+
397
+ if temperature > 1e-5:
398
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
399
+ else:
400
+ logits_processor = None
401
+ #assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
402
+ # Avoid modifying the input_ids in-place
403
+
404
+ padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device)
405
+ input_ids = input_ids.clone()
406
+ self.ea_layer.reset_kv()
407
+
408
+
409
+
410
+ # Initialize the past key and value states
411
+ if hasattr(self, "past_key_values"):
412
+ past_key_values = self.past_key_values
413
+ past_key_values_data = self.past_key_values_data
414
+ current_length_data = self.current_length_data
415
+ # Reset the past key and value states
416
+ current_length_data.zero_()
417
+ else:
418
+ (
419
+ past_key_values,
420
+ past_key_values_data,
421
+ current_length_data,
422
+ ) = initialize_past_key_values(self.base_model)
423
+ self.past_key_values = past_key_values
424
+ self.past_key_values_data = past_key_values_data
425
+ self.current_length_data = current_length_data
426
+
427
+ input_len = input_ids.shape[1]
428
+ reset_tree_mode(self)
429
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
430
+ input_ids, self, past_key_values, logits_processor
431
+ )
432
+ new_token = 0
433
+
434
+ for idx in range(max_length):
435
+ #with Timer("all"):
436
+ self.base_model.model.tree_mask = tree_mask
437
+
438
+ draft_tokens=draft_tokens.to(input_ids.device)
439
+ #with Timer("tree_decoding"):
440
+ logits, hidden_state_new, outputs = tree_decoding(
441
+ self,
442
+ draft_tokens,
443
+ past_key_values,
444
+ tree_position_ids,
445
+ input_ids,
446
+ retrieve_indices,
447
+ )
448
+ #retrieve_indices=tree_buffers["retrieve_indices"]
449
+ #logits = logits[0, retrieve_indices]
450
+ draft_tokens=torch.cat((draft_tokens,padding),dim=1)
451
+ candidates=draft_tokens[0,retrieve_indices]
452
+ best_candidate, accept_length, sample_p = evaluate_posterior(
453
+ logits, candidates, logits_processor
454
+ )
455
+ # print(accept_length)
456
+ #with Timer("update_inference_inputs"):
457
+ input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
458
+ input_ids,
459
+ candidates,
460
+ best_candidate,
461
+ accept_length,
462
+ retrieve_indices,
463
+ logits_processor,
464
+ new_token,
465
+ past_key_values_data,
466
+ current_length_data,
467
+ self,
468
+ hidden_state_new,
469
+ sample_p
470
+ )
471
+
472
+ yield input_ids
473
+
474
+ if is_llama3:
475
+ if stop_token_id in input_ids[0, input_len:].tolist():
476
+ break
477
+
478
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
479
+ break
480
+ if new_token > max_new_tokens:
481
+ break
482
+ if input_ids.shape[1] > max_length:
483
+ break
484
+
485
+
486
+ @torch.no_grad()
487
+ def naive_generate(
488
+ self,
489
+ input_ids,
490
+ temperature=0.0,
491
+ top_p=0.0,
492
+ top_k=0.0,
493
+ max_new_tokens=512,
494
+ max_length=2048,
495
+ log=False,
496
+ is_llama3=False,
497
+
498
+ ):
499
+ if is_llama3:
500
+ stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
501
+ max_length = max_length - self.ea_layer.total_tokens - 10
502
+
503
+ if temperature > 1e-5:
504
+ logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
505
+ else:
506
+ logits_processor = None
507
+ # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
508
+ # Avoid modifying the input_ids in-place
509
+
510
+ padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
511
+ input_ids = input_ids.clone()
512
+ self.ea_layer.reset_kv()
513
+
514
+ # Initialize the past key and value states
515
+ if hasattr(self, "past_key_values"):
516
+ past_key_values = self.past_key_values
517
+ past_key_values_data = self.past_key_values_data
518
+ current_length_data = self.current_length_data
519
+ # Reset the past key and value states
520
+ current_length_data.zero_()
521
+ else:
522
+ (
523
+ past_key_values,
524
+ past_key_values_data,
525
+ current_length_data,
526
+ ) = initialize_past_key_values(self.base_model)
527
+ self.past_key_values = past_key_values
528
+ self.past_key_values_data = past_key_values_data
529
+ self.current_length_data = current_length_data
530
+
531
+ input_len = input_ids.shape[1]
532
+ reset_tree_mode(self)
533
+ outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
534
+ new_token = 0
535
+
536
+ for idx in range(max_length):
537
+ if logits_processor is not None:
538
+ logits = outputs.logits[:, -1]
539
+ logits = logits_processor(None, logits)
540
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
541
+ input_id = torch.multinomial(probabilities, 1)
542
+ else:
543
+ input_id = outputs.logits[:, -1:].argmax(dim=-1)
544
+ outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
545
+ input_ids = torch.cat([input_ids, input_id], dim=-1)
546
+ new_token += 1
547
+
548
+ yield input_ids
549
+
550
+ if is_llama3:
551
+ if stop_token_id in input_ids[0, input_len:].tolist():
552
+ break
553
+
554
+ if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
555
+ break
556
+ if new_token > max_new_tokens:
557
+ break
558
+ if input_ids.shape[1] > max_length:
559
+ break
560
+
561
+
562
+
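End-to-end usage sketch for the class above (not part of the commit). The checkpoint names are placeholders, a CUDA device is assumed, and extra keyword arguments are forwarded to the base model's `from_pretrained`:

```python
# Usage sketch for EaModel (paths are placeholders; CUDA is assumed).
import torch
from eagle.model.ea_model import EaModel

model = EaModel.from_pretrained(
    base_model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder base model
    ea_model_path="path/to/eagle-draft-weights",      # placeholder EAGLE weights
    torch_dtype=torch.float16,
    device_map="auto",
    total_token=-1,   # let from_pretrained benchmark and pick a draft-tree size
)
model.eval()

tokenizer = model.get_tokenizer()
input_ids = tokenizer("Tell me a short story.", return_tensors="pt").input_ids.cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```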
model/kv_cache.py ADDED
@@ -0,0 +1,157 @@
+ import torch
2
+
3
+
4
+ class KVCache:
5
+ """
6
+ A key-value cache for the model.
7
+
8
+ This class provides a mechanism to maintain a growing cache of keys and values,
9
+ particularly useful for models that benefit from caching previous states,
10
+ like transformers during autoregressive decoding.
11
+
12
+ Attributes:
13
+ data (torch.Tensor): The tensor storing keys and values.
14
+ current_length (int): Current length of the data being stored.
15
+ """
16
+
17
+ def __init__(self, data, current_length):
18
+ """
19
+ Initialize the KVCache.
20
+
21
+ Args:
22
+ data (torch.Tensor): Initial tensor to store the keys and values.
23
+ current_length (int): Initial length of the data.
24
+ """
25
+ self.data = data
26
+ self.current_length = current_length
27
+
28
+ @property
29
+ def shape(self):
30
+ """Return the shape of the data tensor with updated length."""
31
+ return (
32
+ self.data.shape[0],
33
+ self.data.shape[1],
34
+ self.current_length.item(),
35
+ self.data.shape[3],
36
+ )
37
+
38
+ def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
39
+ """
40
+ Copy values from the current data at specified indices to a new location.
41
+
42
+ Args:
43
+ indices (torch.Tensor): Indices of the data tensor to be copied.
44
+ prev_length (int): Previous length before adding new data.
45
+ dim (int, optional): Dimension along which copying should be performed. Default is 2.
46
+ """
47
+ tgt = self.data.index_select(dim, indices)
48
+ dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
49
+ dst.copy_(tgt, non_blocking=True)
50
+ self.current_length.fill_(prev_length + tgt.shape[dim])
51
+
52
+ def cat(self, tensor: torch.Tensor, dim: int = 2):
53
+ """
54
+ Concatenate the given tensor with the current data.
55
+
56
+ Args:
57
+ tensor (torch.Tensor): The tensor to be concatenated.
58
+ dim (int, optional): The dimension along which concatenation should be done. Default is 2.
59
+
60
+ Returns:
61
+ torch.Tensor: The data tensor after concatenation up to the current length.
62
+ """
63
+ dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
64
+ dst.copy_(tensor)
65
+ self.current_length.add_(tensor.shape[dim])
66
+ return torch.narrow(self.data, 2, 0, self.current_length)
67
+
68
+
69
+ def initialize_past_key_values(model):
70
+ """
71
+ Initialize past key and value states for a given transformer model.
72
+
73
+ This function prepares key-value cache structures for the model, allowing it to store and reuse
74
+ past key and value states during autoregressive decoding, which can improve efficiency.
75
+
76
+ Args:
77
+ model (nn.Module): The transformer model for which past key-value states need to be initialized.
78
+
79
+ Returns:
80
+ tuple:
81
+ - past_key_values (list): A list of KVCache objects for each layer in the model.
82
+ - past_key_values_data (torch.Tensor): The tensor that will store all keys and values.
83
+ - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache.
84
+ """
85
+ # Extracting configuration from the model
86
+ config = model.config
87
+ # Initialize the batch size to 1; this can be modified if different batch sizes are required
88
+ batch_size = 1
89
+ # Initializing a tensor to store past keys and values for all layers
90
+
91
+ devices=[]
92
+ for i in range(config.num_hidden_layers):
93
+ try:
94
+ device = model.model.layers[i].self_attn.q_proj.weight.device
95
+ except:
96
+ device=model.layers[i].self_attn.q_proj.weight.device
97
+ devices.append(device)
98
+ past_key_values_data_list=[]
99
+ startnum=0
100
+ startdevice=devices[0]
101
+ for id,i in enumerate(devices):
102
+ if startdevice!=i:
103
+ past_key_values_data = torch.zeros(
104
+ startnum * 2,
105
+ batch_size,
106
+ config.num_key_value_heads,
107
+ config.max_position_embeddings,
108
+ config.hidden_size // config.num_attention_heads,
109
+ device=startdevice,
110
+ dtype=model.dtype,
111
+ )
112
+ past_key_values_data_list.append(past_key_values_data)
113
+ startdevice = i
114
+ startnum=0
115
+ startnum += 1
116
+ past_key_values_data = torch.zeros(
117
+ startnum * 2,
118
+ batch_size,
119
+ config.num_key_value_heads,
120
+ config.max_position_embeddings,
121
+ config.hidden_size // config.num_attention_heads,
122
+ device=startdevice,
123
+ dtype=model.dtype,
124
+ )
125
+ past_key_values_data_list.append(past_key_values_data)
126
+ # Initialize tensor to store the current length of the cached data for all layers.
127
+ # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
128
+ current_length_data = torch.zeros(
129
+ config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
130
+ )
131
+ # Creating a KVCache for each pair of key and value in all layers
132
+ past_key_values = [] * config.num_hidden_layers
133
+
134
+ bias=0
135
+ start_data_m=devices[0].index
136
+ for i in range(config.num_hidden_layers):
137
+ data_m=devices[i].index
138
+ if data_m!=start_data_m:
139
+ bias=0
140
+ start_data_m=data_m
141
+ try:
142
+ past_key_values.append(
143
+ [
144
+ KVCache(past_key_values_data_list[data_m-devices[0].index][2*bias + j], current_length_data[i * 2 + j])
145
+ for j in range(2)
146
+ ]
147
+ )
148
+ except:
149
+ past_key_values.append(
150
+ [
151
+ KVCache(past_key_values_data_list[0][2 * bias + j],
152
+ current_length_data[i * 2 + j])
153
+ for j in range(2)
154
+ ]
155
+ )
156
+ bias+=1
157
+ return past_key_values, past_key_values_data_list, current_length_data
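A self-contained toy sketch of the `KVCache` wrapper above (not part of the commit; the import path and shapes are assumptions). It shows that `cat` writes into the preallocated buffer and only advances `current_length`:

```python
# Toy sketch of KVCache (import path assumed; shapes are illustrative).
import torch
from eagle.model.kv_cache import KVCache

data = torch.zeros(1, 2, 8, 4)        # (batch, kv_heads, max_len, head_dim)
length = torch.tensor(0)              # 0-dim length counter, as in the real cache
cache = KVCache(data, length)

new_kv = torch.randn(1, 2, 3, 4)      # keys (or values) for 3 new positions
kv_so_far = cache.cat(new_kv, dim=2)  # copies in place, returns a view up to length
print(cache.current_length.item(), kv_so_far.shape)  # 3 torch.Size([1, 2, 3, 4])

# Rolling back (e.g. after rejecting draft tokens) only resets the counter.
cache.current_length.zero_()
```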
model/modeling_llama_kv.py ADDED
@@ -0,0 +1,1398 @@
+ # Source: https://github.com/huggingface/transformers/blob/v4.31-release/src/transformers/models/llama/modeling_llama.py
2
+ # Modifications are denoted by the symbol: [MODIFIED]
3
+
4
+
5
+ """ PyTorch LLaMA model."""
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ # [MODIFIED] Import from transformer library
16
+ from transformers.activations import ACT2FN
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ SequenceClassifierOutputWithPast,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import (
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from transformers import LlamaConfig
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ _CONFIG_FOR_DOC = "LlamaConfig"
34
+
35
+
36
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
37
+ def _make_causal_mask(
38
+ input_ids_shape: torch.Size,
39
+ dtype: torch.dtype,
40
+ device: torch.device,
41
+ past_key_values_length: int = 0,
42
+ ):
43
+ """
44
+ Create a causal mask for bi-directional self-attention.
45
+
46
+ Args:
47
+ input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
48
+ dtype (torch.dtype): The data type of the mask.
49
+ device (torch.device): The device on which the mask will be placed.
50
+ past_key_values_length (int, optional): The length of past key values. Default is 0.
51
+
52
+ Returns:
53
+ torch.Tensor: The causal mask tensor.
54
+ """
55
+ bsz, tgt_len = input_ids_shape
56
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
57
+ mask_cond = torch.arange(mask.size(-1), device=device)
58
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
59
+ mask = mask.to(dtype)
60
+
61
+ if past_key_values_length > 0:
62
+ mask = torch.cat(
63
+ [
64
+ torch.zeros(
65
+ tgt_len, past_key_values_length, dtype=dtype, device=device
66
+ ),
67
+ mask,
68
+ ],
69
+ dim=-1,
70
+ )
71
+ return mask[None, None, :, :].expand(
72
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
73
+ )
74
+
75
+
76
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
77
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
78
+ """
79
+ Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
80
+
81
+ Args:
82
+ mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`.
83
+ dtype (torch.dtype): The data type of the mask.
84
+ tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length.
85
+
86
+ Returns:
87
+ torch.Tensor: The expanded mask tensor.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
+ return inverted_mask.masked_fill(
97
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
98
+ )
99
+
100
+
101
+ import torch.nn as nn
102
+ import torch
103
+
104
+
105
+ class LlamaRMSNorm(nn.Module):
106
+ """
107
+ LlamaRMSNorm is equivalent to T5LayerNorm.
108
+
109
+ Args:
110
+ hidden_size (int): The size of the hidden states.
111
+ eps (float, optional): A small value to prevent division by zero. Default is 1e-6.
112
+ """
113
+
114
+ def __init__(self, hidden_size, eps=1e-6):
115
+ super().__init__()
116
+ self.weight = nn.Parameter(torch.ones(hidden_size))
117
+ self.variance_epsilon = eps
118
+
119
+ def forward(self, hidden_states):
120
+ """
121
+ Apply LlamaRMSNorm to the input hidden states.
122
+
123
+ Args:
124
+ hidden_states (torch.Tensor): Input hidden states.
125
+
126
+ Returns:
127
+ torch.Tensor: The normalized and scaled hidden states.
128
+ """
129
+ input_dtype = hidden_states.dtype
130
+ hidden_states = hidden_states.to(torch.float32)
131
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
132
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
133
+ return self.weight * hidden_states.to(input_dtype)
134
+
135
+
136
+ class LlamaRotaryEmbedding(nn.Module):
137
+ """
138
+ Llama Rotary Positional Embedding Module.
139
+
140
+ Args:
141
+ dim (int): The dimension of the embedding.
142
+ max_position_embeddings (int, optional): The maximum position for embeddings. Default is 2048.
143
+ base (int, optional): The base value for rotational encoding. Default is 10000.
144
+ device (str, optional): The device on which the computation will be performed. Default is None.
145
+ """
146
+
147
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
148
+ super().__init__()
149
+
150
+ self.dim = dim
151
+ self.max_position_embeddings = max_position_embeddings
152
+ self.base = base
153
+ inv_freq = 1.0 / (
154
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
155
+ )
156
+ self.register_buffer("inv_freq", inv_freq)
157
+
158
+ # Build here to make `torch.jit.trace` work.
159
+ self._set_cos_sin_cache(
160
+ seq_len=max_position_embeddings,
161
+ device=self.inv_freq.device,
162
+ dtype=torch.get_default_dtype(),
163
+ )
164
+
165
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
166
+ """
167
+ Set the cosine and sine cache for positional embeddings.
168
+
169
+ Args:
170
+ seq_len (int): The sequence length.
171
+ device (str): The device on which the cache tensors will be stored.
172
+ dtype: The data type of the cache tensors.
173
+ """
174
+ self.max_seq_len_cached = seq_len
175
+ t = torch.arange(
176
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
177
+ )
178
+
179
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
180
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
181
+ emb = torch.cat((freqs, freqs), dim=-1)
182
+ self.register_buffer(
183
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
184
+ )
185
+ self.register_buffer(
186
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
187
+ )
188
+
189
+ def forward(self, x, seq_len=None):
190
+ """
191
+ Forward pass of the LlamaRotaryEmbedding module.
192
+
193
+ Args:
194
+ x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].
195
+ seq_len (int): The sequence length. If greater than the cached length, the cache will be updated.
196
+
197
+ Returns:
198
+ tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].
199
+ """
200
+ if seq_len > self.max_seq_len_cached:
201
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
202
+
203
+ return (
204
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
205
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
206
+ )
207
+
208
+
209
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
210
+ """
211
+ LlamaRotaryEmbedding extended with linear scaling.
212
+
213
+ This class adds linear scaling to LlamaRotaryEmbedding. Credits to the Reddit user /u/kaiokendev.
214
+
215
+ Args:
216
+ dim (int): The dimension of the embedding.
217
+ max_position_embeddings (int, optional): The maximum number of position embeddings. Default is 2048.
218
+ base (int, optional): The base value for the rotational embeddings. Default is 10000.
219
+ device (str or torch.device, optional): The device where the embeddings should be stored. Default is None.
220
+ scaling_factor (float, optional): The scaling factor for the embeddings. Default is 1.0.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ dim,
226
+ max_position_embeddings=2048,
227
+ base=10000,
228
+ device=None,
229
+ scaling_factor=1.0,
230
+ ):
231
+ self.scaling_factor = scaling_factor
232
+ super().__init__(dim, max_position_embeddings, base, device)
233
+
234
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
235
+ """
236
+ Set the cosine and sine cache for the rotary embeddings.
237
+
238
+ Args:
239
+ seq_len (int): The sequence length.
240
+ device (str or torch.device): The device where the cache should be stored.
241
+ dtype: The data type for the cache.
242
+ """
243
+ self.max_seq_len_cached = seq_len
244
+ t = torch.arange(
245
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
246
+ )
247
+ t = t / self.scaling_factor
248
+
249
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
250
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
251
+ emb = torch.cat((freqs, freqs), dim=-1)
252
+ self.register_buffer(
253
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
254
+ )
255
+ self.register_buffer(
256
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
257
+ )
258
+
259
+
260
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
261
+ """
262
+ LlamaRotaryEmbedding extended with Dynamic NTK scaling.
263
+
264
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ dim,
270
+ max_position_embeddings=2048,
271
+ base=10000,
272
+ device=None,
273
+ scaling_factor=1.0,
274
+ ):
275
+ """
276
+ Initialize the LlamaDynamicNTKScalingRotaryEmbedding.
277
+
278
+ Args:
279
+ dim (int): The dimensionality of the embedding.
280
+ max_position_embeddings (int, optional): Maximum number of position embeddings. Default is 2048.
281
+ base (int, optional): Base value for scaling calculations. Default is 10000.
282
+ device: The device to place tensors on. If None, uses the default device.
283
+ scaling_factor (float, optional): Scaling factor for NTK scaling. Default is 1.0.
284
+ """
285
+ self.scaling_factor = scaling_factor
286
+ super().__init__(dim, max_position_embeddings, base, device)
287
+
288
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
289
+ """
290
+ Set the cached values for cosine and sine.
291
+
292
+ Args:
293
+ seq_len (int): The sequence length.
294
+ device: The device to place tensors on.
295
+ dtype: The data type of tensors.
296
+ """
297
+ self.max_seq_len_cached = seq_len
298
+
299
+ if seq_len > self.max_position_embeddings:
300
+ base = self.base * (
301
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
302
+ - (self.scaling_factor - 1)
303
+ ) ** (self.dim / (self.dim - 2))
304
+ inv_freq = 1.0 / (
305
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
306
+ )
307
+ self.register_buffer("inv_freq", inv_freq)
308
+
309
+ t = torch.arange(
310
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
311
+ )
312
+
313
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
314
+ emb = torch.cat((freqs, freqs), dim=-1)
315
+ self.register_buffer(
316
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
317
+ )
318
+ self.register_buffer(
319
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
320
+ )
321
+
322
+
323
+ def rotate_half(x):
324
+ """
325
+ Rotates half the hidden dimensions of the input.
326
+
327
+ Args:
328
+ x (torch.Tensor): Input tensor.
329
+
330
+ Returns:
331
+ torch.Tensor: Tensor with half of its hidden dimensions rotated.
332
+ """
333
+ x1 = x[..., : x.shape[-1] // 2]
334
+ x2 = x[..., x.shape[-1] // 2:]
335
+ return torch.cat((-x2, x1), dim=-1)
336
+
337
+
338
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
339
+ """
340
+ Apply rotary position embeddings to query and key tensors.
341
+
342
+ Args:
343
+ q (torch.Tensor): Query tensor.
344
+ k (torch.Tensor): Key tensor.
345
+ cos (torch.Tensor): Cosine values.
346
+ sin (torch.Tensor): Sine values.
347
+ position_ids (torch.Tensor): Position IDs.
348
+
349
+ Returns:
350
+ torch.Tensor: Query and key tensors with rotary position embeddings applied.
351
+ """
352
+ cos = cos.squeeze(1).squeeze(0)
353
+ sin = sin.squeeze(1).squeeze(0)
354
+ cos = cos[position_ids].unsqueeze(1)
355
+ sin = sin[position_ids].unsqueeze(1)
356
+ q_embed = (q * cos) + (rotate_half(q) * sin)
357
+ k_embed = (k * cos) + (rotate_half(k) * sin)
358
+ return q_embed, k_embed
359
+
360
+
361
+ class LlamaMLP(nn.Module):
362
+ """
363
+ LlamaMLP is a multi-layer perceptron module used in the Llama model.
364
+
365
+ Args:
366
+ config: The configuration for the MLP.
367
+
368
+ Attributes:
369
+ pretraining_tp (int): The tensor parallelism rank used during pretraining.
370
+ hidden_size (int): The size of the hidden layer.
371
+ intermediate_size (int): The size of the intermediate layer.
372
+ gate_proj (nn.Linear): The linear projection for gating.
373
+ up_proj (nn.Linear): The linear projection for the up projection.
374
+ down_proj (nn.Linear): The linear projection for the down projection.
375
+ act_fn: The activation function.
376
+
377
+ """
378
+
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.pretraining_tp = config.pretraining_tp
382
+ self.hidden_size = config.hidden_size
383
+ self.intermediate_size = config.intermediate_size
384
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
385
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
386
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
387
+ self.act_fn = ACT2FN[config.hidden_act]
388
+
389
+ def forward(self, x):
390
+ """
391
+ Forward pass of the MLP.
392
+
393
+ Args:
394
+ x: Input tensor.
395
+
396
+ Returns:
397
+ torch.Tensor: Output tensor.
398
+ """
399
+ if self.pretraining_tp > 1:
400
+ slice = self.intermediate_size // self.pretraining_tp
401
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
402
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
403
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
404
+
405
+ gate_proj = torch.cat(
406
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)],
407
+ dim=-1,
408
+ )
409
+ up_proj = torch.cat(
410
+ [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)],
411
+ dim=-1,
412
+ )
413
+
414
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
415
+ down_proj = [
416
+ F.linear(intermediate_states[i], down_proj_slices[i])
417
+ for i in range(self.pretraining_tp)
418
+ ]
419
+ down_proj = sum(down_proj)
420
+ else:
421
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
422
+
423
+ return down_proj
424
+
425
+
426
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
427
+ """
428
+ Repeat key and value tensors n times along the specified dimension.
429
+
430
+ Args:
431
+ hidden_states (torch.Tensor): Input tensor with shape (batch, num_key_value_heads, seqlen, head_dim).
432
+ n_rep (int): Number of times to repeat.
433
+
434
+ Returns:
435
+ torch.Tensor: Repeated tensor with shape (batch, num_key_value_heads * n_rep, seqlen, head_dim).
436
+ """
437
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
438
+ if n_rep == 1:
439
+ return hidden_states
440
+ hidden_states = hidden_states[:, :, None, :, :].expand(
441
+ batch, num_key_value_heads, n_rep, slen, head_dim
442
+ )
443
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
444
+
445
+
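A shape sketch for the helpers above (not part of the commit; the import path and the toy dimensions are assumptions):

```python
# Shape sketch for the rotary-embedding helpers and repeat_kv (assumed import path).
import torch
from eagle.model.modeling_llama_kv import (
    LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv,
)

bsz, n_heads, n_kv_heads, seq, head_dim = 1, 8, 2, 5, 16
q = torch.randn(bsz, n_heads, seq, head_dim)
k = torch.randn(bsz, n_kv_heads, seq, head_dim)

rope = LlamaRotaryEmbedding(head_dim, max_position_embeddings=32)
cos, sin = rope(q, seq_len=seq)                   # each [1, 1, seq, head_dim]
position_ids = torch.arange(seq).unsqueeze(0)     # [1, seq]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# GQA: the 2 key/value heads are repeated 4x to line up with the 8 query heads.
k_full = repeat_kv(k_rot, n_heads // n_kv_heads)
print(q_rot.shape, k_rot.shape, k_full.shape)
# torch.Size([1, 8, 5, 16]) torch.Size([1, 2, 5, 16]) torch.Size([1, 8, 5, 16])
```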
446
+ class LlamaAttention(nn.Module):
447
+ """
448
+ LlamaAttention is a multi-headed attention module based on the 'Attention Is All You Need' paper.
449
+
450
+ Args:
451
+ config (LlamaConfig): Configuration for the attention module.
452
+
453
+ Attributes:
454
+ config (LlamaConfig): Configuration for the attention module.
455
+ hidden_size (int): The size of the hidden layer.
456
+ num_heads (int): The number of attention heads.
457
+ head_dim (int): The dimension of each attention head.
458
+ num_key_value_heads (int): The number of key-value attention heads.
459
+ num_key_value_groups (int): The number of key-value groups.
460
+ pretraining_tp (int): The tensor parallelism rank used during pretraining.
461
+ max_position_embeddings (int): The maximum position embeddings.
462
+
463
+ """
464
+
465
+ def __init__(self, config: LlamaConfig):
466
+ super().__init__()
467
+ self.config = config
468
+ self.hidden_size = config.hidden_size
469
+ self.num_heads = config.num_attention_heads
470
+ self.head_dim = self.hidden_size // self.num_heads
471
+ self.num_key_value_heads = config.num_key_value_heads
472
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
473
+ self.pretraining_tp = config.pretraining_tp
474
+ self.max_position_embeddings = config.max_position_embeddings
475
+
476
+ if (self.head_dim * self.num_heads) != self.hidden_size:
477
+ raise ValueError(
478
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
479
+ f" and `num_heads`: {self.num_heads})."
480
+ )
481
+ self.q_proj = nn.Linear(
482
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
483
+ )
484
+ self.k_proj = nn.Linear(
485
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
486
+ )
487
+ self.v_proj = nn.Linear(
488
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
489
+ )
490
+ self.o_proj = nn.Linear(
491
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
492
+ )
493
+ self._init_rope()
494
+
495
+ def _init_rope(self):
496
+ if self.config.rope_scaling is None:
497
+ self.rotary_emb = LlamaRotaryEmbedding(
498
+ self.head_dim, max_position_embeddings=self.max_position_embeddings,base=self.config.rope_theta
499
+ )
500
+ else:
501
+ scaling_type = self.config.rope_scaling["type"]
502
+ scaling_factor = self.config.rope_scaling["factor"]
503
+ if scaling_type == "linear":
504
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
505
+ self.head_dim,
506
+ max_position_embeddings=self.max_position_embeddings,
507
+ scaling_factor=scaling_factor,
508
+ )
509
+ elif scaling_type == "dynamic":
510
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
511
+ self.head_dim,
512
+ max_position_embeddings=self.max_position_embeddings,
513
+ scaling_factor=scaling_factor,
514
+ )
515
+ else:
516
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
517
+
518
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
519
+ return (
520
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
521
+ .transpose(1, 2)
522
+ .contiguous()
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ hidden_states: torch.Tensor,
528
+ attention_mask: Optional[torch.Tensor] = None,
529
+ position_ids: Optional[torch.LongTensor] = None,
530
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
531
+ output_attentions: bool = False,
532
+ use_cache: bool = False,
533
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
534
+ bsz, q_len, _ = hidden_states.size()
535
+
536
+ if self.pretraining_tp > 1:
537
+ key_value_slicing = (
538
+ self.num_key_value_heads * self.head_dim
539
+ ) // self.pretraining_tp
540
+ query_slices = self.q_proj.weight.split(
541
+ (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
542
+ )
543
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
544
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
545
+
546
+ query_states = [
547
+ F.linear(hidden_states, query_slices[i])
548
+ for i in range(self.pretraining_tp)
549
+ ]
550
+ query_states = torch.cat(query_states, dim=-1)
551
+
552
+ key_states = [
553
+ F.linear(hidden_states, key_slices[i])
554
+ for i in range(self.pretraining_tp)
555
+ ]
556
+ key_states = torch.cat(key_states, dim=-1)
557
+
558
+ value_states = [
559
+ F.linear(hidden_states, value_slices[i])
560
+ for i in range(self.pretraining_tp)
561
+ ]
562
+ value_states = torch.cat(value_states, dim=-1)
563
+
564
+ else:
565
+ query_states = self.q_proj(hidden_states)
566
+ key_states = self.k_proj(hidden_states)
567
+ value_states = self.v_proj(hidden_states)
568
+
569
+ query_states = query_states.view(
570
+ bsz, q_len, self.num_heads, self.head_dim
571
+ ).transpose(1, 2)
572
+ key_states = key_states.view(
573
+ bsz, q_len, self.num_key_value_heads, self.head_dim
574
+ ).transpose(1, 2)
575
+ value_states = value_states.view(
576
+ bsz, q_len, self.num_key_value_heads, self.head_dim
577
+ ).transpose(1, 2)
578
+
579
+ kv_seq_len = key_states.shape[-2]
580
+ if past_key_value is not None:
581
+ kv_seq_len += past_key_value[0].shape[-2]
582
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
583
+ query_states, key_states = apply_rotary_pos_emb(
584
+ query_states, key_states, cos, sin, position_ids
585
+ )
586
+
587
+ # [MODIFIED] Using KVCache mechanism for preallocated GPU memory optimization
588
+ # past_key_value is utilized to leverage previously computed key and value states.
589
+ # If past_key_value is available, reuse the states for k, v, and self_attention.
590
+ if past_key_value is not None:
591
+ key_states = past_key_value[0].cat(key_states, dim=2)
592
+ value_states = past_key_value[1].cat(value_states, dim=2)
593
+ # Reset past_key_value to avoid return past_key_value.
594
+ past_key_value = None
595
+
596
+ # repeat k/v heads if n_kv_heads < n_heads
597
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
598
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
599
+
600
+ attn_weights = torch.matmul(
601
+ query_states, key_states.transpose(2, 3)
602
+ ) / math.sqrt(self.head_dim)
603
+
604
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
605
+ raise ValueError(
606
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
607
+ f" {attn_weights.size()}"
608
+ )
609
+
610
+ if attention_mask is not None:
611
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
612
+ raise ValueError(
613
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
614
+ )
615
+ attn_weights = attn_weights + attention_mask
616
+
617
+ # upcast attention to fp32
618
+ attn_weights = nn.functional.softmax(
619
+ attn_weights, dim=-1, dtype=torch.float32
620
+ ).to(query_states.dtype)
621
+ attn_output = torch.matmul(attn_weights, value_states)
622
+
623
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
624
+ raise ValueError(
625
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
626
+ f" {attn_output.size()}"
627
+ )
628
+
629
+ attn_output = attn_output.transpose(1, 2).contiguous()
630
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
631
+
632
+ if self.pretraining_tp > 1:
633
+ attn_output = attn_output.split(
634
+ self.hidden_size // self.pretraining_tp, dim=2
635
+ )
636
+ o_proj_slices = self.o_proj.weight.split(
637
+ self.hidden_size // self.pretraining_tp, dim=1
638
+ )
639
+ attn_output = sum(
640
+ [
641
+ F.linear(attn_output[i], o_proj_slices[i])
642
+ for i in range(self.pretraining_tp)
643
+ ]
644
+ )
645
+ else:
646
+ attn_output = self.o_proj(attn_output)
647
+
648
+ if not output_attentions:
649
+ attn_weights = None
650
+
651
+ return attn_output, attn_weights, past_key_value
652
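The `[MODIFIED]` block above appends new key/value states with `past_key_value[i].cat(...)` rather than `torch.cat`, i.e. the cached objects (the `KVCache` class these files import from `kv_cache.py`) are expected to own a preallocated buffer and write into it in place. The repository's real class is defined elsewhere; the following is only a minimal sketch of that idea, with every name and size invented for illustration.

```python
import torch

class KVCacheSketch:
    """Illustrative stand-in for a preallocated KV cache (not the repo's KVCache)."""

    def __init__(self, batch, num_kv_heads, max_len, head_dim, device="cpu", dtype=torch.float32):
        # One big buffer allocated once; `current_length` tracks the filled prefix.
        self.data = torch.zeros(batch, num_kv_heads, max_len, head_dim, device=device, dtype=dtype)
        self.current_length = 0

    @property
    def shape(self):
        # Mimics a tensor's shape over only the filled part, so `cache.shape[-2]` works.
        return (self.data.shape[0], self.data.shape[1], self.current_length, self.data.shape[3])

    def cat(self, new_states: torch.Tensor, dim: int = 2) -> torch.Tensor:
        # Copy the new key/value states into the preallocated slot and return a view
        # over everything cached so far -- no reallocation, unlike torch.cat.
        assert dim == 2
        start = self.current_length
        end = start + new_states.shape[2]
        self.data[:, :, start:end] = new_states
        self.current_length = end
        return self.data[:, :, :end]

# Tiny usage example.
cache = KVCacheSketch(batch=1, num_kv_heads=2, max_len=16, head_dim=4)
step1 = cache.cat(torch.randn(1, 2, 3, 4))   # prefill 3 tokens -> (1, 2, 3, 4)
step2 = cache.cat(torch.randn(1, 2, 1, 4))   # decode 1 more    -> (1, 2, 4, 4)
print(step1.shape, step2.shape)
```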
+
653
+
654
+ class LlamaDecoderLayer(nn.Module):
655
+ """
656
+ LlamaDecoderLayer represents a single layer of the Llama decoder.
657
+
658
+ Args:
659
+ config (LlamaConfig): Configuration for the decoder layer.
660
+
661
+ Attributes:
662
+ hidden_size (int): The size of the hidden layer.
663
+ self_attn (LlamaAttention): Multi-headed self-attention module.
664
+ mlp (LlamaMLP): Multi-layer perceptron module.
665
+ input_layernorm (LlamaRMSNorm): Layer normalization for input.
666
+ post_attention_layernorm (LlamaRMSNorm): Layer normalization after self-attention.
667
+ """
668
+
669
+ def __init__(self, config: LlamaConfig):
670
+ super().__init__()
671
+ self.hidden_size = config.hidden_size
672
+ self.self_attn = LlamaAttention(config=config)
673
+ self.mlp = LlamaMLP(config)
674
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
675
+ self.post_attention_layernorm = LlamaRMSNorm(
676
+ config.hidden_size, eps=config.rms_norm_eps
677
+ )
678
+
679
+ def forward(
680
+ self,
681
+ hidden_states: torch.Tensor,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ position_ids: Optional[torch.LongTensor] = None,
684
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
685
+ output_attentions: Optional[bool] = False,
686
+ use_cache: Optional[bool] = False,
687
+ ) -> Tuple[
688
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
689
+ ]:
690
+ """
691
+ Forward pass for the LlamaDecoderLayer.
692
+
693
+ Args:
694
+ hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`.
695
+ attention_mask (torch.FloatTensor, optional): Attention mask of size
696
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
697
+ position_ids (torch.LongTensor, optional): Positional IDs tensor.
698
+ past_key_value (Tuple[torch.FloatTensor], optional): Cached past key and value projection states.
699
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
700
+ use_cache (bool, optional): If set to `True`, `past_key_values` key-value states are returned and can be
701
+ used to speed up decoding.
702
+
703
+ Returns:
704
+ Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing:
705
+ - hidden_states (torch.FloatTensor): Output tensor.
706
+ - self_attn_weights (Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]): Self-attention weights if
707
+ `output_attentions` is `True`.
708
+ - present_key_value (Optional[Tuple[torch.FloatTensor]]): Cached key and value projection states if
709
+ `use_cache` is `True`.
710
+ """
711
+
712
+ residual = hidden_states
713
+
714
+ hidden_states = self.input_layernorm(hidden_states)
715
+
716
+ # Self Attention
717
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
718
+ hidden_states=hidden_states,
719
+ attention_mask=attention_mask,
720
+ position_ids=position_ids,
721
+ past_key_value=past_key_value,
722
+ output_attentions=output_attentions,
723
+ use_cache=use_cache,
724
+ )
725
+ hidden_states = residual + hidden_states
726
+
727
+ # Fully Connected
728
+ residual = hidden_states
729
+ hidden_states = self.post_attention_layernorm(hidden_states)
730
+ hidden_states = self.mlp(hidden_states)
731
+ hidden_states = residual + hidden_states
732
+
733
+ outputs = (hidden_states,)
734
+
735
+ if output_attentions:
736
+ outputs += (self_attn_weights,)
737
+
738
+ if use_cache:
739
+ outputs += (present_key_value,)
740
+
741
+ return outputs
742
+
743
+
744
+ LLAMA_START_DOCSTRING = r"""
745
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
746
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
747
+ etc.)
748
+
749
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
750
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
751
+ and behavior.
752
+
753
+ Parameters:
754
+ config ([`LlamaConfig`]):
755
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
756
+ load the weights associated with the model, only the configuration. Check out the
757
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
758
+ """
759
+
760
+
761
+ @add_start_docstrings(
762
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
763
+ LLAMA_START_DOCSTRING,
764
+ )
765
+ class LlamaPreTrainedModel(PreTrainedModel):
766
+ config_class = LlamaConfig
767
+ base_model_prefix = "model"
768
+ supports_gradient_checkpointing = True
769
+ _no_split_modules = ["LlamaDecoderLayer"]
770
+ _skip_keys_device_placement = "past_key_values"
771
+
772
+ def _init_weights(self, module):
773
+ std = self.config.initializer_range
774
+ if isinstance(module, nn.Linear):
775
+ module.weight.data.normal_(mean=0.0, std=std)
776
+ if module.bias is not None:
777
+ module.bias.data.zero_()
778
+ elif isinstance(module, nn.Embedding):
779
+ module.weight.data.normal_(mean=0.0, std=std)
780
+ if module.padding_idx is not None:
781
+ module.weight.data[module.padding_idx].zero_()
782
+
783
+ def _set_gradient_checkpointing(self, module, value=False):
784
+ if isinstance(module, LlamaModel):
785
+ module.gradient_checkpointing = value
786
+
787
+
788
+ LLAMA_INPUTS_DOCSTRING = r"""
789
+ Args:
790
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
791
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
792
+ it.
793
+
794
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
795
+ [`PreTrainedTokenizer.__call__`] for details.
796
+
797
+ [What are input IDs?](../glossary#input-ids)
798
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
799
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
800
+
801
+ - 1 for tokens that are **not masked**,
802
+ - 0 for tokens that are **masked**.
803
+
804
+ [What are attention masks?](../glossary#attention-mask)
805
+
806
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
807
+ [`PreTrainedTokenizer.__call__`] for details.
808
+
809
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
810
+ `past_key_values`).
811
+
812
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
813
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
814
+ information on the default strategy.
815
+
816
+ - 1 indicates the head is **not masked**,
817
+ - 0 indicates the head is **masked**.
818
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
819
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
820
+ config.n_positions - 1]`.
821
+
822
+ [What are position IDs?](../glossary#position-ids)
823
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
824
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
825
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
826
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
827
+
828
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
829
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
830
+
831
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
832
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
833
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
834
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
835
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
836
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
837
+ model's internal embedding lookup matrix.
838
+ use_cache (`bool`, *optional*):
839
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
840
+ `past_key_values`).
841
+ output_attentions (`bool`, *optional*):
842
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
843
+ tensors for more detail.
844
+ output_hidden_states (`bool`, *optional*):
845
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
846
+ more detail.
847
+ return_dict (`bool`, *optional*):
848
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
849
+ """
850
+
851
+
852
+ @add_start_docstrings(
853
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
854
+ LLAMA_START_DOCSTRING,
855
+ )
856
+ class LlamaModel(LlamaPreTrainedModel):
857
+ """
858
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
859
+
860
+ Args:
861
+ config: LlamaConfig
862
+ """
863
+
864
+ def __init__(self, config: LlamaConfig):
865
+ super().__init__(config)
866
+ self.padding_idx = config.pad_token_id
867
+ self.vocab_size = config.vocab_size
868
+
869
+ self.embed_tokens = nn.Embedding(
870
+ config.vocab_size, config.hidden_size, self.padding_idx
871
+ )
872
+ self.layers = nn.ModuleList(
873
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
874
+ )
875
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
876
+
877
+ self.gradient_checkpointing = False
878
+ # Initialize weights and apply final processing
879
+ self.post_init()
880
+
881
+ def get_input_embeddings(self):
882
+ return self.embed_tokens
883
+
884
+ def set_input_embeddings(self, value):
885
+ self.embed_tokens = value
886
+
887
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
888
+ def _prepare_decoder_attention_mask(
889
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
890
+ ):
891
+ # create causal mask
892
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
893
+ combined_attention_mask = None
894
+ if input_shape[-1] > 1:
895
+ combined_attention_mask = _make_causal_mask(
896
+ input_shape,
897
+ # inputs_embeds.dtype,
898
+ torch.float32, # [MODIFIED] force to cast to float32
899
+ device=inputs_embeds.device,
900
+ past_key_values_length=past_key_values_length,
901
+ )
902
+
903
+ if attention_mask is not None:
904
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
905
+ expanded_attn_mask = _expand_mask(
906
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
907
+ ).to(inputs_embeds.device)
908
+ combined_attention_mask = (
909
+ expanded_attn_mask
910
+ if combined_attention_mask is None
911
+ else expanded_attn_mask + combined_attention_mask
912
+ )
913
+
914
+
915
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
916
+ tree_mask = self.tree_mask
917
+ tree_len = tree_mask.size(-1)
918
+ combined_attention_mask[:, :, -tree_len:, -tree_len:][
919
+ tree_mask == 0
920
+ ] = combined_attention_mask.min()
921
+
922
+ return combined_attention_mask
923
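When `self.tree_mask` is set, the code above overwrites the bottom-right `tree_len x tree_len` corner of the causal mask with the mask's minimum wherever the tree mask is zero, so each speculative draft token can only attend to its ancestors in the draft tree. A self-contained toy illustration of that masking step (the 3-token tree and the shapes are made up for the example, not the repository's tree format):

```python
import torch

neg_inf = torch.finfo(torch.float32).min

# Causal mask for 5 positions: 2 committed tokens followed by 3 draft-tree tokens.
seq_len, tree_len = 5, 3
combined = torch.full((1, 1, seq_len, seq_len), neg_inf)
combined = torch.triu(combined, diagonal=1)  # 0 on/below the diagonal, -inf above

# Hypothetical tree: draft 0 is the root, drafts 1 and 2 are two alternative children.
# tree_mask[..., i, j] == 1 means draft token i may attend to draft token j.
tree_mask = torch.tensor([[1, 0, 0],
                          [1, 1, 0],
                          [1, 0, 1]])[None, None]  # (1, 1, tree_len, tree_len)

# Same operation as in _prepare_decoder_attention_mask: block disallowed pairs.
combined[:, :, -tree_len:, -tree_len:][tree_mask == 0] = combined.min()
print(combined[0, 0])
```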
+
924
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
925
+ def forward(
926
+ self,
927
+ input_ids: torch.LongTensor = None,
928
+ attention_mask: Optional[torch.Tensor] = None,
929
+ position_ids: Optional[torch.LongTensor] = None,
930
+ past_key_values=None, # [MODIFIED] past_key_value is KVCache class
931
+ inputs_embeds: Optional[torch.FloatTensor] = None,
932
+ use_cache: Optional[bool] = None,
933
+ output_attentions: Optional[bool] = None,
934
+ output_hidden_states: Optional[bool] = None,
935
+ return_dict: Optional[bool] = None,
936
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
937
+ output_attentions = (
938
+ output_attentions
939
+ if output_attentions is not None
940
+ else self.config.output_attentions
941
+ )
942
+ output_hidden_states = (
943
+ output_hidden_states
944
+ if output_hidden_states is not None
945
+ else self.config.output_hidden_states
946
+ )
947
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
948
+
949
+ return_dict = (
950
+ return_dict if return_dict is not None else self.config.use_return_dict
951
+ )
952
+
953
+ # retrieve input_ids and inputs_embeds
954
+ if input_ids is not None and inputs_embeds is not None:
955
+ raise ValueError(
956
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
957
+ )
958
+ elif input_ids is not None:
959
+ batch_size, seq_length = input_ids.shape
960
+ elif inputs_embeds is not None:
961
+ batch_size, seq_length, _ = inputs_embeds.shape
962
+ else:
963
+ raise ValueError(
964
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
965
+ )
966
+
967
+ seq_length_with_past = seq_length
968
+ past_key_values_length = 0
969
+
970
+ if past_key_values is not None:
971
+ past_key_values_length = past_key_values[0][0].shape[2]
972
+ seq_length_with_past = seq_length_with_past + past_key_values_length
973
+
974
+ if position_ids is None:
975
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
976
+ position_ids = torch.arange(
977
+ past_key_values_length,
978
+ seq_length + past_key_values_length,
979
+ dtype=torch.long,
980
+ device=device,
981
+ )
982
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
983
+ else:
984
+ position_ids = position_ids.view(-1, seq_length).long()
985
+
986
+ if inputs_embeds is None:
987
+ inputs_embeds = self.embed_tokens(input_ids)
988
+ # embed positions
989
+ if attention_mask is None:
990
+ attention_mask = torch.ones(
991
+ (batch_size, seq_length_with_past),
992
+ dtype=torch.bool,
993
+ device=inputs_embeds.device,
994
+ )
995
+ attention_mask = self._prepare_decoder_attention_mask(
996
+ attention_mask,
997
+ (batch_size, seq_length),
998
+ inputs_embeds,
999
+ past_key_values_length,
1000
+ )
1001
+
1002
+ hidden_states = inputs_embeds
1003
+
1004
+ if self.gradient_checkpointing and self.training:
1005
+ if use_cache:
1006
+ logger.warning_once(
1007
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1008
+ )
1009
+ use_cache = False
1010
+
1011
+ # decoder layers
1012
+ all_hidden_states = () if output_hidden_states else None
1013
+ all_self_attns = () if output_attentions else None
1014
+ next_decoder_cache = () if use_cache else None
1015
+
1016
+ for idx, decoder_layer in enumerate(self.layers):
1017
+ # if idx==16:
1018
+ # print(idx)
1019
+ if output_hidden_states:
1020
+ all_hidden_states += (hidden_states,)
1021
+
1022
+ past_key_value = (
1023
+ past_key_values[idx] if past_key_values is not None else None
1024
+ )
1025
+
1026
+ if self.gradient_checkpointing and self.training:
1027
+
1028
+ def create_custom_forward(module):
1029
+ def custom_forward(*inputs):
1030
+ # None for past_key_value
1031
+ return module(*inputs, output_attentions, None)
1032
+
1033
+ return custom_forward
1034
+
1035
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1036
+ create_custom_forward(decoder_layer),
1037
+ hidden_states,
1038
+ attention_mask,
1039
+ position_ids,
1040
+ None,
1041
+ )
1042
+ else:
1043
+ layer_outputs = decoder_layer(
1044
+ hidden_states,
1045
+ attention_mask=attention_mask,
1046
+ position_ids=position_ids,
1047
+ past_key_value=past_key_value,
1048
+ output_attentions=output_attentions,
1049
+ use_cache=use_cache,
1050
+ )
1051
+
1052
+ hidden_states = layer_outputs[0]
1053
+
1054
+ if use_cache:
1055
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1056
+
1057
+ if output_attentions:
1058
+ all_self_attns += (layer_outputs[1],)
1059
+
1060
+ hidden_states = self.norm(hidden_states)
1061
+
1062
+ # add hidden states from the last decoder layer
1063
+ if output_hidden_states:
1064
+ all_hidden_states += (hidden_states,)
1065
+
1066
+ next_cache = next_decoder_cache if use_cache else None
1067
+ if not return_dict:
1068
+ return tuple(
1069
+ v
1070
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1071
+ if v is not None
1072
+ )
1073
+ return BaseModelOutputWithPast(
1074
+ last_hidden_state=hidden_states,
1075
+ past_key_values=next_cache,
1076
+ hidden_states=all_hidden_states,
1077
+ attentions=all_self_attns,
1078
+ )
1079
+
1080
+
1081
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1082
+ _tied_weights_keys = ["lm_head.weight"]
1083
+
1084
+ def __init__(self, config):
1085
+ super().__init__(config)
1086
+ self.model = LlamaModel(config)
1087
+ self.pretraining_tp = config.pretraining_tp
1088
+ self.vocab_size = config.vocab_size
1089
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1090
+
1091
+ # Initialize weights and apply final processing
1092
+ self.post_init()
1093
+
1094
+ def get_input_embeddings(self):
1095
+ return self.model.embed_tokens
1096
+
1097
+ def set_input_embeddings(self, value):
1098
+ self.model.embed_tokens = value
1099
+
1100
+ def get_output_embeddings(self):
1101
+ return self.lm_head
1102
+
1103
+ def set_output_embeddings(self, new_embeddings):
1104
+ self.lm_head = new_embeddings
1105
+
1106
+ def set_decoder(self, decoder):
1107
+ self.model = decoder
1108
+
1109
+ def get_decoder(self):
1110
+ return self.model
1111
+
1112
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1113
+ @replace_return_docstrings(
1114
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1115
+ )
1116
+ def forward(
1117
+ self,
1118
+ input_ids: torch.LongTensor = None,
1119
+ attention_mask: Optional[torch.Tensor] = None,
1120
+ position_ids: Optional[torch.LongTensor] = None,
1121
+ past_key_values=None, # [MODIFIED] past_key_value is KVCache class
1122
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1123
+ labels: Optional[torch.LongTensor] = None,
1124
+ use_cache: Optional[bool] = None,
1125
+ output_attentions: Optional[bool] = None,
1126
+ output_hidden_states: Optional[bool] = None,
1127
+ return_dict: Optional[bool] = None,
1128
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1129
+ r"""
1130
+ Args:
1131
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1132
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1133
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1134
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1135
+
1136
+ Returns:
1137
+
1138
+ Example:
1139
+
1140
+ ```python
1141
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1142
+
1143
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1144
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1145
+
1146
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1147
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1148
+
1149
+ >>> # Generate
1150
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1151
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1152
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1153
+ ```"""
1154
+
1155
+ output_attentions = (
1156
+ output_attentions
1157
+ if output_attentions is not None
1158
+ else self.config.output_attentions
1159
+ )
1160
+ output_hidden_states = (
1161
+ output_hidden_states
1162
+ if output_hidden_states is not None
1163
+ else self.config.output_hidden_states
1164
+ )
1165
+ return_dict = (
1166
+ return_dict if return_dict is not None else self.config.use_return_dict
1167
+ )
1168
+
1169
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1170
+ outputs = self.model(
1171
+ input_ids=input_ids,
1172
+ attention_mask=attention_mask,
1173
+ position_ids=position_ids,
1174
+ past_key_values=past_key_values,
1175
+ inputs_embeds=inputs_embeds,
1176
+ use_cache=use_cache,
1177
+ output_attentions=output_attentions,
1178
+ output_hidden_states=output_hidden_states,
1179
+ return_dict=return_dict,
1180
+ )
1181
+
1182
+ hidden_states = outputs[0]
1183
+ if self.pretraining_tp > 1:
1184
+ lm_head_slices = self.lm_head.weight.split(
1185
+ self.vocab_size // self.pretraining_tp, dim=0
1186
+ )
1187
+ logits = [
1188
+ F.linear(hidden_states, lm_head_slices[i])
1189
+ for i in range(self.pretraining_tp)
1190
+ ]
1191
+ logits = torch.cat(logits, dim=-1)
1192
+ else:
1193
+ logits = self.lm_head(hidden_states)
1194
+ logits = logits.float()
1195
+
1196
+ loss = None
1197
+ if labels is not None:
1198
+ # Shift so that tokens < n predict n
1199
+ shift_logits = logits[..., :-1, :].contiguous()
1200
+ shift_labels = labels[..., 1:].contiguous()
1201
+ # Flatten the tokens
1202
+ loss_fct = CrossEntropyLoss()
1203
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1204
+ shift_labels = shift_labels.view(-1)
1205
+ # Enable model parallelism
1206
+ shift_labels = shift_labels.to(shift_logits.device)
1207
+ loss = loss_fct(shift_logits, shift_labels)
1208
+
1209
+ if not return_dict:
1210
+ output = (logits,) + outputs[1:]
1211
+ return (loss,) + output if loss is not None else output
1212
+
1213
+ return CausalLMOutputWithPast(
1214
+ loss=loss,
1215
+ logits=logits,
1216
+ past_key_values=outputs.past_key_values,
1217
+ hidden_states=outputs.hidden_states,
1218
+ attentions=outputs.attentions,
1219
+ )
1220
+
1221
+ def prepare_inputs_for_generation(
1222
+ self,
1223
+ input_ids,
1224
+ past_key_values=None,
1225
+ attention_mask=None,
1226
+ inputs_embeds=None,
1227
+ **kwargs,
1228
+ ):
1229
+ if past_key_values:
1230
+ input_ids = input_ids[:, -1:]
1231
+
1232
+ position_ids = kwargs.get("position_ids", None)
1233
+ if attention_mask is not None and position_ids is None:
1234
+ # create position_ids on the fly for batch generation
1235
+ position_ids = attention_mask.long().cumsum(-1) - 1
1236
+ position_ids.masked_fill_(attention_mask == 0, 1)
1237
+ if past_key_values:
1238
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1239
+
1240
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1241
+ if inputs_embeds is not None and past_key_values is None:
1242
+ model_inputs = {"inputs_embeds": inputs_embeds}
1243
+ else:
1244
+ model_inputs = {"input_ids": input_ids}
1245
+
1246
+ model_inputs.update(
1247
+ {
1248
+ "position_ids": position_ids,
1249
+ "past_key_values": past_key_values,
1250
+ "use_cache": kwargs.get("use_cache"),
1251
+ "attention_mask": attention_mask,
1252
+ }
1253
+ )
1254
+ return model_inputs
1255
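`prepare_inputs_for_generation` rebuilds `position_ids` from the attention mask with a cumulative sum, so padded positions do not advance the position counter. A small standalone check of that recipe (the padding layout here is only an example):

```python
import torch

# Batch of 2, left-padded: first row has 2 pad tokens, second row none.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```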
+
1256
+ @staticmethod
1257
+ def _reorder_cache(past_key_values, beam_idx):
1258
+ reordered_past = ()
1259
+ for layer_past in past_key_values:
1260
+ reordered_past += (
1261
+ tuple(
1262
+ past_state.index_select(0, beam_idx.to(past_state.device))
1263
+ for past_state in layer_past
1264
+ ),
1265
+ )
1266
+ return reordered_past
1267
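`_reorder_cache` index-selects every cached key/value tensor along the batch dimension with `beam_idx`, which is how beam search reshuffles hypotheses without recomputing their caches. A toy demonstration with plain tensors standing in for one layer's cache:

```python
import torch

# One layer's cache: key and value of shape (num_beams, heads, seq, head_dim).
layer_past = (torch.arange(4.).view(4, 1, 1, 1), torch.arange(4.).view(4, 1, 1, 1) * 10)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams 2, 2, 0, 1 survive this step

reordered = tuple(p.index_select(0, beam_idx) for p in layer_past)
print(reordered[0].flatten())  # tensor([2., 2., 0., 1.])
print(reordered[1].flatten())  # tensor([20., 20., 0., 10.])
```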
+
1268
+
1269
+ @add_start_docstrings(
1270
+ """
1271
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1272
+
1273
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1274
+ (e.g. GPT-2) do.
1275
+
1276
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1277
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1278
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1279
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1280
+ each row of the batch).
1281
+ """,
1282
+ LLAMA_START_DOCSTRING,
1283
+ )
1284
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1285
+ def __init__(self, config):
1286
+ super().__init__(config)
1287
+ self.num_labels = config.num_labels
1288
+ self.model = LlamaModel(config)
1289
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1290
+
1291
+ # Initialize weights and apply final processing
1292
+ self.post_init()
1293
+
1294
+ def get_input_embeddings(self):
1295
+ return self.model.embed_tokens
1296
+
1297
+ def set_input_embeddings(self, value):
1298
+ self.model.embed_tokens = value
1299
+
1300
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1301
+ def forward(
1302
+ self,
1303
+ input_ids: torch.LongTensor = None,
1304
+ attention_mask: Optional[torch.Tensor] = None,
1305
+ position_ids: Optional[torch.LongTensor] = None,
1306
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1307
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1308
+ labels: Optional[torch.LongTensor] = None,
1309
+ use_cache: Optional[bool] = None,
1310
+ output_attentions: Optional[bool] = None,
1311
+ output_hidden_states: Optional[bool] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1314
+ r"""
1315
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1316
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1317
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1318
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1319
+ """
1320
+ return_dict = (
1321
+ return_dict if return_dict is not None else self.config.use_return_dict
1322
+ )
1323
+
1324
+ transformer_outputs = self.model(
1325
+ input_ids,
1326
+ attention_mask=attention_mask,
1327
+ position_ids=position_ids,
1328
+ past_key_values=past_key_values,
1329
+ inputs_embeds=inputs_embeds,
1330
+ use_cache=use_cache,
1331
+ output_attentions=output_attentions,
1332
+ output_hidden_states=output_hidden_states,
1333
+ return_dict=return_dict,
1334
+ )
1335
+ hidden_states = transformer_outputs[0]
1336
+ logits = self.score(hidden_states)
1337
+
1338
+ if input_ids is not None:
1339
+ batch_size = input_ids.shape[0]
1340
+ else:
1341
+ batch_size = inputs_embeds.shape[0]
1342
+
1343
+ if self.config.pad_token_id is None and batch_size != 1:
1344
+ raise ValueError(
1345
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1346
+ )
1347
+ if self.config.pad_token_id is None:
1348
+ sequence_lengths = -1
1349
+ else:
1350
+ if input_ids is not None:
1351
+ sequence_lengths = (
1352
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1353
+ ).to(logits.device)
1354
+ else:
1355
+ sequence_lengths = -1
1356
+
1357
+ pooled_logits = logits[
1358
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1359
+ ]
1360
+
1361
+ loss = None
1362
+ if labels is not None:
1363
+ labels = labels.to(logits.device)
1364
+ if self.config.problem_type is None:
1365
+ if self.num_labels == 1:
1366
+ self.config.problem_type = "regression"
1367
+ elif self.num_labels > 1 and (
1368
+ labels.dtype == torch.long or labels.dtype == torch.int
1369
+ ):
1370
+ self.config.problem_type = "single_label_classification"
1371
+ else:
1372
+ self.config.problem_type = "multi_label_classification"
1373
+
1374
+ if self.config.problem_type == "regression":
1375
+ loss_fct = MSELoss()
1376
+ if self.num_labels == 1:
1377
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1378
+ else:
1379
+ loss = loss_fct(pooled_logits, labels)
1380
+ elif self.config.problem_type == "single_label_classification":
1381
+ loss_fct = CrossEntropyLoss()
1382
+ loss = loss_fct(
1383
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1384
+ )
1385
+ elif self.config.problem_type == "multi_label_classification":
1386
+ loss_fct = BCEWithLogitsLoss()
1387
+ loss = loss_fct(pooled_logits, labels)
1388
+ if not return_dict:
1389
+ output = (pooled_logits,) + transformer_outputs[1:]
1390
+ return ((loss,) + output) if loss is not None else output
1391
+
1392
+ return SequenceClassifierOutputWithPast(
1393
+ loss=loss,
1394
+ logits=pooled_logits,
1395
+ past_key_values=transformer_outputs.past_key_values,
1396
+ hidden_states=transformer_outputs.hidden_states,
1397
+ attentions=transformer_outputs.attentions,
1398
+ )
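`LlamaForSequenceClassification` pools the logits at the last non-padding token of each row, locating it as `torch.ne(input_ids, pad_token_id).sum(-1) - 1` (which assumes right padding). A quick illustration of that index computation, with made-up ids and shapes:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # 3 real tokens -> last index 2
                          [3, 4, 6, 8, 2]])   # 5 real tokens -> last index 4
logits = torch.randn(2, 5, 3)                  # (batch, seq_len, num_labels)

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths.tolist())   # [2, 4]
print(pooled_logits.shape)         # torch.Size([2, 3])
```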
model/modeling_mixtral_kv.py ADDED
@@ -0,0 +1,1199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mixtral model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+ from .kv_cache import KVCache
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+
33
+ # [MODIFIED] Import from transformer library
34
+ from transformers.activations import ACT2FN
35
+
36
+ from transformers.modeling_outputs import (
37
+ MoeCausalLMOutputWithPast,
38
+ MoeModelOutputWithPast,
39
+ SequenceClassifierOutputWithPast,
40
+ )
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers import MixtralConfig
49
+
50
+
51
+
52
+
53
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
54
+ # It means that the function will not be traced through and simply appear as a node in the graph.
55
+
56
+
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+ _CONFIG_FOR_DOC = "MixtralConfig"
61
+
62
+
63
+ def _make_causal_mask(
64
+ input_ids_shape: torch.Size,
65
+ dtype: torch.dtype,
66
+ device: torch.device,
67
+ past_key_values_length: int = 0,
68
+ ):
69
+ """
70
+ Create a causal mask for bi-directional self-attention.
71
+
72
+ Args:
73
+ input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
74
+ dtype (torch.dtype): The data type of the mask.
75
+ device (torch.device): The device on which the mask will be placed.
76
+ past_key_values_length (int, optional): The length of past key values. Default is 0.
77
+
78
+ Returns:
79
+ torch.Tensor: The causal mask tensor.
80
+ """
81
+ bsz, tgt_len = input_ids_shape
82
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
83
+ mask_cond = torch.arange(mask.size(-1), device=device)
84
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
85
+ mask = mask.to(dtype)
86
+
87
+ if past_key_values_length > 0:
88
+ mask = torch.cat(
89
+ [
90
+ torch.zeros(
91
+ tgt_len, past_key_values_length, dtype=dtype, device=device
92
+ ),
93
+ mask,
94
+ ],
95
+ dim=-1,
96
+ )
97
+ return mask[None, None, :, :].expand(
98
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
99
+ )
100
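`_make_causal_mask` produces the usual additive causal mask (0 where attention is allowed, the dtype minimum where it is blocked) and prepends `past_key_values_length` zero columns so new tokens can attend to everything already cached. The mini reimplementation below only reproduces the same construction so the expected shape and pattern can be printed standalone:

```python
import torch

def make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
    # Same construction as above, repeated here so the snippet runs by itself.
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    cond = torch.arange(tgt_len, device=device)
    mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        zeros = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([zeros, mask], dim=-1)
    return mask[None, None].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

m = make_causal_mask((1, 3), torch.float32, torch.device("cpu"), past_key_values_length=2)
print(m.shape)               # torch.Size([1, 1, 3, 5])
print((m == 0).int()[0, 0])  # causal pattern shifted right by the 2 cached positions
```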
+
101
+
102
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
103
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
104
+ """
105
+ Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
106
+
107
+ Args:
108
+ mask (torch.Tensor): The attention mask tensor of shape `[bsz, seq_len]`.
109
+ dtype (torch.dtype): The data type of the mask.
110
+ tgt_len (Optional[int], optional): The target sequence length. If None, it defaults to the source sequence length.
111
+
112
+ Returns:
113
+ torch.Tensor: The expanded mask tensor.
114
+ """
115
+ bsz, src_len = mask.size()
116
+ tgt_len = tgt_len if tgt_len is not None else src_len
117
+
118
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
119
+
120
+ inverted_mask = 1.0 - expanded_mask
121
+
122
+ return inverted_mask.masked_fill(
123
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
124
+ )
125
+
126
+
127
+ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
128
+ r"""
129
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
130
+
131
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
132
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
133
+ experts is too unbalanced.
134
+
135
+ Args:
136
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
137
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_experts].
138
+ num_experts (`int`, *optional*):
139
+ Number of experts
140
+
141
+ Returns:
142
+ The auxiliary loss.
143
+ """
144
+ if gate_logits is None:
145
+ return 0
146
+
147
+ if isinstance(gate_logits, tuple):
148
+ # cat along the layers?
149
+ compute_device = gate_logits[0].device
150
+ gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
151
+
152
+ routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
153
+ routing_weights = routing_weights.softmax(dim=-1)
154
+
155
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
156
+ if selected_experts.dtype != torch.int64:
157
+ selected_experts = selected_experts.to(torch.int64)
158
+
159
+ if len(selected_experts.shape) == 2:
160
+ selected_experts = selected_experts.unsqueeze(2)
161
+
162
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
163
+
164
+ # For a given token, determine if it was routed to a given expert.
165
+ expert_mask = torch.max(expert_mask, axis=-2).values
166
+
167
+ # cast to float32 otherwise mean will fail
168
+ expert_mask = expert_mask.to(torch.float32)
169
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
170
+
171
+ router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
172
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2)
173
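`load_balancing_loss_func` combines, per expert, how often that expert is selected with how much routing probability the router assigns overall, and scales the average by `num_experts**2`, penalizing unbalanced routing as in the Switch Transformer objective. The snippet below unrolls the same arithmetic on random logits so it can run by itself (all shapes are illustrative):

```python
import torch

num_experts, top_k = 4, 2
# Fake router logits for 2 layers, 6 tokens each (batch * sequence_length = 6).
gate_logits = tuple(torch.randn(6, num_experts) for _ in range(2))

logits = torch.cat(gate_logits, dim=0)                      # (12, num_experts)
routing_weights, selected = torch.topk(logits, top_k, dim=-1)
routing_weights = routing_weights.softmax(dim=-1)

# One-hot mask of which experts each token selected.
expert_mask = torch.nn.functional.one_hot(selected.unsqueeze(2).to(torch.int64), num_experts)
expert_mask = torch.max(expert_mask, dim=-2).values.to(torch.float32)

tokens_per_expert = expert_mask.mean(dim=-2)                # how often each expert is hit
router_prob = routing_weights.mean(dim=-1)                  # mean routing weight per token
aux_loss = torch.mean(tokens_per_expert * router_prob.unsqueeze(-1)) * num_experts**2
print(aux_loss)
```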
+
174
+
175
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
176
+ def _get_unpad_data(attention_mask):
177
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
178
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
179
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
180
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
181
+ return (
182
+ indices,
183
+ cu_seqlens,
184
+ max_seqlen_in_batch,
185
+ )
186
+
187
+
188
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
189
+ class MixtralRMSNorm(nn.Module):
190
+ def __init__(self, hidden_size, eps=1e-6):
191
+ """
192
+ MixtralRMSNorm is equivalent to T5LayerNorm
193
+ """
194
+ super().__init__()
195
+ self.weight = nn.Parameter(torch.ones(hidden_size))
196
+ self.variance_epsilon = eps
197
+
198
+ def forward(self, hidden_states):
199
+ input_dtype = hidden_states.dtype
200
+ hidden_states = hidden_states.to(torch.float32)
201
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
202
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
203
+ return self.weight * hidden_states.to(input_dtype)
204
+
205
+
206
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
207
+ class MixtralRotaryEmbedding(nn.Module):
208
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
209
+ super().__init__()
210
+
211
+ self.dim = dim
212
+ self.max_position_embeddings = max_position_embeddings
213
+ self.base = base
214
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
215
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
216
+
217
+ # Build here to make `torch.jit.trace` work.
218
+ self._set_cos_sin_cache(
219
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
220
+ )
221
+
222
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
223
+ self.max_seq_len_cached = seq_len
224
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
225
+
226
+ freqs = torch.outer(t, self.inv_freq)
227
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
228
+ emb = torch.cat((freqs, freqs), dim=-1)
229
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
230
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
231
+
232
+ def forward(self, x, seq_len=None):
233
+ # x: [bs, num_attention_heads, seq_len, head_size]
234
+ if seq_len > self.max_seq_len_cached:
235
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
236
+
237
+ return (
238
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
239
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
240
+ )
241
+
242
+
243
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
244
+ def rotate_half(x):
245
+ """Rotates half the hidden dims of the input."""
246
+ x1 = x[..., : x.shape[-1] // 2]
247
+ x2 = x[..., x.shape[-1] // 2 :]
248
+ return torch.cat((-x2, x1), dim=-1)
249
+
250
+
251
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
252
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
253
+ """Applies Rotary Position Embedding to the query and key tensors.
254
+
255
+ Args:
256
+ q (`torch.Tensor`): The query tensor.
257
+ k (`torch.Tensor`): The key tensor.
258
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
259
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
260
+ position_ids (`torch.Tensor`):
261
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
262
+ used to pass offsetted position ids when working with a KV-cache.
263
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
264
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
265
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
266
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
267
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
268
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
269
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
270
+ Returns:
271
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
272
+ """
273
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
274
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
275
+ q_embed = (q * cos) + (rotate_half(q) * sin)
276
+ k_embed = (k * cos) + (rotate_half(k) * sin)
277
+ return q_embed, k_embed
278
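`apply_rotary_pos_emb` rotates each query/key head by position-dependent angles: it gathers `cos`/`sin` rows by `position_ids`, unsqueezes them at the heads dimension (`unsqueeze_dim=1` for a `[batch, heads, seq, head_dim]` layout), and mixes each vector with its `rotate_half` counterpart. A shape-level sanity check, with the two helpers duplicated only so the snippet runs on its own:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

bsz, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)

# cos/sin caches of shape (max_seq, head_dim), as the rotary embedding module builds them.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(16).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()

position_ids = torch.arange(seq)[None]      # (1, seq); offset these when a KV cache is in use
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_cached, sin_cached, position_ids)
print(q_rot.shape, k_rot.shape)             # shapes unchanged; each 2D pair is rotated in place
```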
+
279
+
280
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
281
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
282
+ """
283
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
284
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
285
+ """
286
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
287
+ if n_rep == 1:
288
+ return hidden_states
289
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
290
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
291
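As the docstring says, `repeat_kv` is equivalent to `torch.repeat_interleave(x, n_rep, dim=1)`, written with `expand` + `reshape` so the copy is deferred to the final reshape. A short equivalence check on arbitrary shapes:

```python
import torch

x = torch.randn(2, 3, 5, 4)          # (batch, num_kv_heads, seq, head_dim)
n_rep = 4

expanded = x[:, :, None, :, :].expand(2, 3, n_rep, 5, 4).reshape(2, 3 * n_rep, 5, 4)
reference = torch.repeat_interleave(x, repeats=n_rep, dim=1)
print(torch.equal(expanded, reference))  # True
```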
+
292
+
293
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
294
+ class MixtralAttention(nn.Module):
295
+ """
296
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
297
+ and "Generating Long Sequences with Sparse Transformers".
298
+ """
299
+
300
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
301
+ super().__init__()
302
+ self.config = config
303
+ self.layer_idx = layer_idx
304
+ if layer_idx is None:
305
+ logger.warning_once(
306
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
307
+ "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
308
+ "when creating this class."
309
+ )
310
+
311
+ self.hidden_size = config.hidden_size
312
+ self.num_heads = config.num_attention_heads
313
+ self.head_dim = self.hidden_size // self.num_heads
314
+ self.num_key_value_heads = config.num_key_value_heads
315
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
316
+ self.max_position_embeddings = config.max_position_embeddings
317
+ self.rope_theta = config.rope_theta
318
+ self.is_causal = True
319
+ self.attention_dropout = config.attention_dropout
320
+
321
+ if (self.head_dim * self.num_heads) != self.hidden_size:
322
+ raise ValueError(
323
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
324
+ f" and `num_heads`: {self.num_heads})."
325
+ )
326
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
327
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
328
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
329
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
330
+
331
+ self.rotary_emb = MixtralRotaryEmbedding(
332
+ self.head_dim,
333
+ max_position_embeddings=self.max_position_embeddings,
334
+ base=self.rope_theta,
335
+ )
336
+
337
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
338
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
339
+
340
+ def forward(
341
+ self,
342
+ hidden_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_value: Optional[Tuple[KVCache]] = None,
346
+ output_attentions: bool = False,
347
+ use_cache: bool = False,
348
+ **kwargs,
349
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
350
+ if "padding_mask" in kwargs:
351
+ warnings.warn(
352
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
353
+ )
354
+ bsz, q_len, _ = hidden_states.size()
355
+
356
+ query_states = self.q_proj(hidden_states)
357
+ key_states = self.k_proj(hidden_states)
358
+ value_states = self.v_proj(hidden_states)
359
+
360
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
361
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
362
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
363
+
364
+ kv_seq_len = key_states.shape[-2]
365
+ if past_key_value is not None:
366
+ if self.layer_idx is None:
367
+ raise ValueError(
368
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
369
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
370
+ "with a layer index."
371
+ )
372
+ kv_seq_len += past_key_value[0].shape[-2]
373
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
374
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
375
+
376
+ if past_key_value is not None:
377
+ key_states = past_key_value[0].cat(key_states, dim=2)
378
+ value_states = past_key_value[1].cat(value_states, dim=2)
379
+
380
+ # repeat k/v heads if n_kv_heads < n_heads
381
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
382
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
383
+
384
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
385
+
386
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
387
+ raise ValueError(
388
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
389
+ f" {attn_weights.size()}"
390
+ )
391
+
392
+ if attention_mask is not None:
393
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
394
+ raise ValueError(
395
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
396
+ )
397
+
398
+ attn_weights = attn_weights + attention_mask
399
+
400
+ # upcast attention to fp32
401
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
402
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
403
+ attn_output = torch.matmul(attn_weights, value_states)
404
+
405
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
406
+ raise ValueError(
407
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
408
+ f" {attn_output.size()}"
409
+ )
410
+
411
+ attn_output = attn_output.transpose(1, 2).contiguous()
412
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
413
+
414
+ attn_output = self.o_proj(attn_output)
415
+
416
+ if not output_attentions:
417
+ attn_weights = None
418
+
419
+ return attn_output, attn_weights, past_key_value
420
+
421
+
422
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
423
+
424
+
425
+
426
+ class MixtralBLockSparseTop2MLP(nn.Module):
427
+ def __init__(self, config: MixtralConfig):
428
+ super().__init__()
429
+ self.ffn_dim = config.intermediate_size
430
+ self.hidden_dim = config.hidden_size
431
+
432
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
433
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
434
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
435
+
436
+ self.act_fn = ACT2FN[config.hidden_act]
437
+
438
+ def forward(self, hidden_states, routing_weights):
439
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
440
+ current_hidden_states = self.w2(current_hidden_states)
441
+ return routing_weights * current_hidden_states
442
+
443
+
444
+ MISTRAL_ATTENTION_CLASSES = {
445
+ "eager": MixtralAttention,
446
+ }
447
+
448
+
449
+ class MixtralSparseMoeBlock(nn.Module):
450
+ """
451
+ This implementation is
452
+ strictly equivalent to standard MoE with full capacity (no
453
+ dropped tokens). It's faster since it formulates MoE operations
454
+ in terms of block-sparse operations to accommodate imbalanced
455
+ assignments of tokens to experts, whereas standard MoE either
456
+ (1) drop tokens at the cost of reduced performance or (2) set
457
+ capacity factor to number of experts and thus waste computation
458
+ and memory on padding.
459
+ """
460
+
461
+ def __init__(self, config):
462
+ super().__init__()
463
+ self.hidden_dim = config.hidden_size
464
+ self.ffn_dim = config.intermediate_size
465
+ self.num_experts = config.num_local_experts
466
+ self.top_k = config.num_experts_per_tok
467
+
468
+ # gating
469
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
470
+
471
+ self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
472
+
473
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
474
+ """ """
475
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
476
+ hidden_states = hidden_states.view(-1, hidden_dim)
477
+ # router_logits: (batch * sequence_length, n_experts)
478
+ router_logits = self.gate(hidden_states)
479
+
480
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
481
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
482
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
483
+ # we cast back to the input dtype
484
+ routing_weights = routing_weights.to(hidden_states.dtype)
485
+
486
+ final_hidden_states = torch.zeros(
487
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
488
+ )
489
+
490
+ # One hot encode the selected experts to create an expert mask
491
+ # this will be used to easily index which expert is going to be solicited
492
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
493
+
494
+ # Loop over all available experts in the model and perform the computation on each expert
495
+ for expert_idx in range(self.num_experts):
496
+ expert_layer = self.experts[expert_idx]
497
+ idx, top_x = torch.where(expert_mask[expert_idx])
498
+
499
+ if top_x.shape[0] == 0:
500
+ continue
501
+
502
+ # in torch it is faster to index using lists than torch tensors
503
+ top_x_list = top_x.tolist()
504
+ idx_list = idx.tolist()
505
+
506
+ # Index the correct hidden states and compute the expert hidden state for
507
+ # the current expert. We need to make sure to multiply the output hidden
508
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
509
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
510
+ current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
511
+
512
+ # However `index_add_` only support torch tensors for indexing so we'll use
513
+ # the `top_x` tensor here.
514
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
515
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
516
+ return final_hidden_states, router_logits
517
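The sparse MoE block routes every token to its top-2 experts: softmax over the router logits, keep the two largest probabilities, renormalize them to sum to 1, and use them to weight the selected experts' outputs, which are later scattered back with `index_add_`. A compact numeric illustration of just the routing step (expert networks omitted, logits invented):

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
router_logits = torch.tensor([[2.0, 0.1, 1.5, -1.0],    # token 0
                              [0.0, 3.0, 0.2,  2.5]])   # token 1

routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

print(selected_experts)   # token 0 -> experts 0 and 2, token 1 -> experts 1 and 3
print(routing_weights)    # two weights per token, each row summing to 1

# expert_mask[e, k, t] == 1 when token t's k-th choice is expert e; the forward pass
# loops over experts and gathers the token indices from this mask.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
print(expert_mask.shape)  # torch.Size([4, 2, 2])
```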
+
518
+
519
+ class MixtralDecoderLayer(nn.Module):
520
+ def __init__(self, config: MixtralConfig, layer_idx: int):
521
+ super().__init__()
522
+ self.hidden_size = config.hidden_size
523
+
524
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
525
+
526
+ self.block_sparse_moe = MixtralSparseMoeBlock(config)
527
+ self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
528
+ self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
529
+
530
+ def forward(
531
+ self,
532
+ hidden_states: torch.Tensor,
533
+ attention_mask: Optional[torch.Tensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_value: Optional[Tuple[KVCache]] = None,
536
+ output_attentions: Optional[bool] = False,
537
+ output_router_logits: Optional[bool] = False,
538
+ use_cache: Optional[bool] = False,
539
+ **kwargs,
540
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
541
+ if "padding_mask" in kwargs:
542
+ warnings.warn(
543
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
544
+ )
545
+ """
546
+ Args:
547
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
548
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
549
+ `(batch, sequence_length)` where padding elements are indicated by 0.
550
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
553
+ returned tensors for more detail.
554
+ output_router_logits (`bool`, *optional*):
555
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
556
+ should not be returned during inference.
557
+ use_cache (`bool`, *optional*):
558
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
559
+ (see `past_key_values`).
560
+ """
561
+
562
+ residual = hidden_states
563
+
564
+ hidden_states = self.input_layernorm(hidden_states)
565
+
566
+ # Self Attention
567
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
568
+ hidden_states=hidden_states,
569
+ attention_mask=attention_mask,
570
+ position_ids=position_ids,
571
+ past_key_value=past_key_value,
572
+ output_attentions=output_attentions,
573
+ use_cache=use_cache,
574
+ )
575
+ hidden_states = residual + hidden_states
576
+
577
+ # Fully Connected
578
+ residual = hidden_states
579
+ hidden_states = self.post_attention_layernorm(hidden_states)
580
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
581
+ hidden_states = residual + hidden_states
582
+
583
+ outputs = (hidden_states,)
584
+
585
+ if output_attentions:
586
+ outputs += (self_attn_weights,)
587
+
588
+ if use_cache:
589
+ outputs += (present_key_value,)
590
+
591
+ if output_router_logits:
592
+ outputs += (router_logits,)
593
+
594
+ return outputs
595
+
596
+
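`MixtralDecoderLayer.forward` above applies the usual pre-norm residual pattern twice: normalize, run the sub-block, add the residual. A compressed sketch of that control flow (the callables are placeholders, not the module's real API):

```python
def decoder_layer_step(x, self_attn, moe, input_ln, post_attn_ln):
    # Self-attention sub-block with residual connection (attention weights / cache omitted here)
    x = x + self_attn(input_ln(x))
    # Sparse-MoE sub-block with residual connection; router logits are dropped in this sketch
    moe_out, _router_logits = moe(post_attn_ln(x))
    return x + moe_out
```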
597
+ MIXTRAL_START_DOCSTRING = r"""
598
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
599
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
600
+ etc.)
601
+
602
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
603
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
604
+ and behavior.
605
+
606
+ Parameters:
607
+ config ([`MixtralConfig`]):
608
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
609
+ load the weights associated with the model, only the configuration. Check out the
610
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
611
+ """
612
+
613
+
614
+ @add_start_docstrings(
615
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
616
+ MIXTRAL_START_DOCSTRING,
617
+ )
618
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
619
+ class MixtralPreTrainedModel(PreTrainedModel):
620
+ config_class = MixtralConfig
621
+ base_model_prefix = "model"
622
+ supports_gradient_checkpointing = True
623
+ _no_split_modules = ["MixtralDecoderLayer"]
624
+ _skip_keys_device_placement = "past_key_values"
625
+ _supports_flash_attn_2 = True
626
+ _supports_cache_class = True
627
+
628
+ def _init_weights(self, module):
629
+ std = self.config.initializer_range
630
+ if isinstance(module, nn.Linear):
631
+ module.weight.data.normal_(mean=0.0, std=std)
632
+ if module.bias is not None:
633
+ module.bias.data.zero_()
634
+ elif isinstance(module, nn.Embedding):
635
+ module.weight.data.normal_(mean=0.0, std=std)
636
+ if module.padding_idx is not None:
637
+ module.weight.data[module.padding_idx].zero_()
638
+
639
+
640
+ MIXTRAL_INPUTS_DOCSTRING = r"""
641
+ Args:
642
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
643
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
644
+ it.
645
+
646
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
647
+ [`PreTrainedTokenizer.__call__`] for details.
648
+
649
+ [What are input IDs?](../glossary#input-ids)
650
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
651
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
652
+
653
+ - 1 for tokens that are **not masked**,
654
+ - 0 for tokens that are **masked**.
655
+
656
+ [What are attention masks?](../glossary#attention-mask)
657
+
658
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
659
+ [`PreTrainedTokenizer.__call__`] for details.
660
+
661
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
662
+ `past_key_values`).
663
+
664
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
665
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
666
+ information on the default strategy.
667
+
668
+ - 1 indicates the head is **not masked**,
669
+ - 0 indicates the head is **masked**.
670
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
671
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
672
+ config.n_positions - 1]`.
673
+
674
+ [What are position IDs?](../glossary#position-ids)
675
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
676
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
677
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
678
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
679
+
680
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
681
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
682
+
683
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
684
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
685
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
686
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
687
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
688
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
689
+ model's internal embedding lookup matrix.
690
+ use_cache (`bool`, *optional*):
691
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
692
+ `past_key_values`).
693
+ output_attentions (`bool`, *optional*):
694
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
695
+ tensors for more detail.
696
+ output_hidden_states (`bool`, *optional*):
697
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
698
+ more detail.
699
+ output_router_logits (`bool`, *optional*):
700
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
701
+ should not be returned during inference.
702
+ return_dict (`bool`, *optional*):
703
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
704
+ """
705
+
706
+
707
+ @add_start_docstrings(
708
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
709
+ MIXTRAL_START_DOCSTRING,
710
+ )
711
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
712
+ class MixtralModel(MixtralPreTrainedModel):
713
+ """
714
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
715
+
716
+ Args:
717
+ config: MixtralConfig
718
+ """
719
+
720
+ def __init__(self, config: MixtralConfig):
721
+ super().__init__(config)
722
+ self.padding_idx = config.pad_token_id
723
+ self.vocab_size = config.vocab_size
724
+
725
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
726
+ self.layers = nn.ModuleList(
727
+ [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
728
+ )
729
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
730
+ self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
731
+
732
+ self.gradient_checkpointing = False
733
+ # Initialize weights and apply final processing
734
+ self.post_init()
735
+
736
+ def get_input_embeddings(self):
737
+ return self.embed_tokens
738
+
739
+ def set_input_embeddings(self, value):
740
+ self.embed_tokens = value
741
+
742
+ def _prepare_decoder_attention_mask(
743
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
744
+ ):
745
+ # create causal mask
746
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
747
+ combined_attention_mask = None
748
+ if input_shape[-1] > 1:
749
+ combined_attention_mask = _make_causal_mask(
750
+ input_shape,
751
+ # inputs_embeds.dtype,
752
+ torch.float32, # [MODIFIED] force to cast to float32
753
+ device=inputs_embeds.device,
754
+ past_key_values_length=past_key_values_length,
755
+ )
756
+
757
+ if attention_mask is not None:
758
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
759
+ expanded_attn_mask = _expand_mask(
760
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
761
+ ).to(inputs_embeds.device)
762
+ combined_attention_mask = (
763
+ expanded_attn_mask
764
+ if combined_attention_mask is None
765
+ else expanded_attn_mask + combined_attention_mask
766
+ )
767
+
768
+
769
+ if hasattr(self, "tree_mask") and self.tree_mask is not None:
770
+ tree_mask = self.tree_mask
771
+ tree_len = tree_mask.size(-1)
772
+ combined_attention_mask[:, :, -tree_len:, -tree_len:][
773
+ tree_mask == 0
774
+ ] = combined_attention_mask.min()
775
+
776
+ return combined_attention_mask
777
+
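The `tree_mask` branch at the end of `_prepare_decoder_attention_mask` is the EAGLE-specific change: draft tokens appended for tree verification may only attend to their ancestors in the draft tree. A toy version of that masking step with a hypothetical three-node tree (the real code reuses `combined_attention_mask.min()`; this sketch uses the float32 minimum because its toy mask starts at zero):

```python
import torch

# Node 0 is the first draft; nodes 1 and 2 are its sibling/child in this toy tree.
tree_mask = torch.tensor([[1, 0, 0],
                          [1, 1, 0],
                          [1, 0, 1]]).view(1, 1, 3, 3)
combined = torch.zeros(1, 1, 3, 3)  # additive attention mask, 0 = visible
combined[:, :, -3:, -3:][tree_mask == 0] = torch.finfo(torch.float32).min
# Row i of combined[0, 0] now blocks attention to every draft that is not an ancestor of draft i.
```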
778
+ # Ignore copy
779
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
780
+ def forward(
781
+ self,
782
+ input_ids: torch.LongTensor = None,
783
+ attention_mask: Optional[torch.Tensor] = None,
784
+ position_ids: Optional[torch.LongTensor] = None,
785
+ past_key_values: Optional[List[Tuple[KVCache]]] = None,
786
+ inputs_embeds: Optional[torch.FloatTensor] = None,
787
+ use_cache: Optional[bool] = None,
788
+ output_attentions: Optional[bool] = None,
789
+ output_hidden_states: Optional[bool] = None,
790
+ output_router_logits: Optional[bool] = None,
791
+ return_dict: Optional[bool] = None,
792
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
793
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
794
+ output_router_logits = (
795
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
796
+ )
797
+ output_hidden_states = (
798
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
799
+ )
800
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
801
+
802
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
803
+
804
+ # retrieve input_ids and inputs_embeds
805
+ if input_ids is not None and inputs_embeds is not None:
806
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
807
+ elif input_ids is not None:
808
+ batch_size, seq_length = input_ids.shape
809
+ elif inputs_embeds is not None:
810
+ batch_size, seq_length, _ = inputs_embeds.shape
811
+ else:
812
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
813
+
814
+ past_key_values_length = 0
815
+
816
+ if self.gradient_checkpointing and self.training:
817
+ if use_cache:
818
+ logger.warning_once(
819
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
820
+ )
821
+ use_cache = False
822
+
823
+ if past_key_values is not None:
824
+ past_key_values_length = past_key_values[0][0].shape[2]
825
+
826
+ if position_ids is None:
827
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
828
+ position_ids = torch.arange(
829
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
830
+ )
831
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
832
+ else:
833
+ position_ids = position_ids.view(-1, seq_length).long()
834
+
835
+ if inputs_embeds is None:
836
+ inputs_embeds = self.embed_tokens(input_ids)
837
+
838
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
839
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
840
+ if is_padding_right:
841
+ raise ValueError(
842
+ "You are attempting to perform batched generation with padding_side='right'"
843
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
844
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
845
+ )
846
+
847
+ # if self._use_flash_attention_2:
848
+ # # 2d mask is passed through the layers
849
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
850
+ # else:
851
+ # 4d mask is passed through the layers
852
+ attention_mask = self._prepare_decoder_attention_mask(
853
+ attention_mask,
854
+ (batch_size, seq_length),
855
+ inputs_embeds,
856
+ past_key_values_length,
857
+ )
858
+
859
+ hidden_states = inputs_embeds
860
+
861
+ # decoder layers
862
+ all_hidden_states = () if output_hidden_states else None
863
+ all_self_attns = () if output_attentions else None
864
+ all_router_logits = () if output_router_logits else None
865
+ next_decoder_cache = None
866
+
867
+ for idx, decoder_layer in enumerate(self.layers):
868
+ if output_hidden_states:
869
+ all_hidden_states += (hidden_states,)
870
+
871
+ past_key_value = (
872
+ past_key_values[idx] if past_key_values is not None else None
873
+ )
874
+
875
+ if self.gradient_checkpointing and self.training:
876
+ layer_outputs = self._gradient_checkpointing_func(
877
+ decoder_layer.__call__,
878
+ hidden_states,
879
+ attention_mask,
880
+ position_ids,
881
+ past_key_value,
882
+ output_attentions,
883
+ output_router_logits,
884
+ use_cache,
885
+ )
886
+ else:
887
+ layer_outputs = decoder_layer(
888
+ hidden_states,
889
+ attention_mask=attention_mask,
890
+ position_ids=position_ids,
891
+ past_key_value=past_key_value,
892
+ output_attentions=output_attentions,
893
+ output_router_logits=output_router_logits,
894
+ use_cache=use_cache,
895
+ )
896
+
897
+ hidden_states = layer_outputs[0]
898
+
899
+ if use_cache:
900
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
901
+
902
+ if output_attentions:
903
+ all_self_attns += (layer_outputs[1],)
904
+
905
+ if output_router_logits:
906
+ all_router_logits += (layer_outputs[-1],)
907
+
908
+ hidden_states = self.norm(hidden_states)
909
+
910
+ # add hidden states from the last decoder layer
911
+ if output_hidden_states:
912
+ all_hidden_states += (hidden_states,)
913
+
914
+
915
+ next_cache = next_decoder_cache if use_cache else None
916
+ # if use_cache:
917
+ # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
918
+
919
+ if not return_dict:
920
+ return tuple(
921
+ v
922
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
923
+ if v is not None
924
+ )
925
+ return MoeModelOutputWithPast(
926
+ last_hidden_state=hidden_states,
927
+ past_key_values=next_cache,
928
+ hidden_states=all_hidden_states,
929
+ attentions=all_self_attns,
930
+ router_logits=all_router_logits,
931
+ )
932
+
933
+
934
+ class MixtralForCausalLM(MixtralPreTrainedModel):
935
+ _tied_weights_keys = ["lm_head.weight"]
936
+
937
+ def __init__(self, config):
938
+ super().__init__(config)
939
+ self.model = MixtralModel(config)
940
+ self.vocab_size = config.vocab_size
941
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
942
+ self.router_aux_loss_coef = config.router_aux_loss_coef
943
+ self.num_experts = config.num_local_experts
944
+ self.num_experts_per_tok = config.num_experts_per_tok
945
+ # Initialize weights and apply final processing
946
+ self.post_init()
947
+
948
+ def get_input_embeddings(self):
949
+ return self.model.embed_tokens
950
+
951
+ def set_input_embeddings(self, value):
952
+ self.model.embed_tokens = value
953
+
954
+ def get_output_embeddings(self):
955
+ return self.lm_head
956
+
957
+ def set_output_embeddings(self, new_embeddings):
958
+ self.lm_head = new_embeddings
959
+
960
+ def set_decoder(self, decoder):
961
+ self.model = decoder
962
+
963
+ def get_decoder(self):
964
+ return self.model
965
+
966
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
967
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
968
+ # Ignore copy
969
+ def forward(
970
+ self,
971
+ input_ids: torch.LongTensor = None,
972
+ attention_mask: Optional[torch.Tensor] = None,
973
+ position_ids: Optional[torch.LongTensor] = None,
974
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
975
+ inputs_embeds: Optional[torch.FloatTensor] = None,
976
+ labels: Optional[torch.LongTensor] = None,
977
+ use_cache: Optional[bool] = None,
978
+ output_attentions: Optional[bool] = None,
979
+ output_hidden_states: Optional[bool] = None,
980
+ output_router_logits: Optional[bool] = None,
981
+ return_dict: Optional[bool] = None,
982
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
983
+ r"""
984
+ Args:
985
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
986
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
987
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
988
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
989
+
990
+ Returns:
991
+
992
+ Example:
993
+
994
+ ```python
995
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
996
+
997
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
998
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
999
+
1000
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1001
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1002
+
1003
+ >>> # Generate
1004
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1005
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1006
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1007
+ ```"""
1008
+
1009
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1010
+ output_router_logits = (
1011
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1012
+ )
1013
+
1014
+ output_hidden_states = (
1015
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1016
+ )
1017
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1018
+
1019
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1020
+ outputs = self.model(
1021
+ input_ids=input_ids,
1022
+ attention_mask=attention_mask,
1023
+ position_ids=position_ids,
1024
+ past_key_values=past_key_values,
1025
+ inputs_embeds=inputs_embeds,
1026
+ use_cache=use_cache,
1027
+ output_attentions=output_attentions,
1028
+ output_hidden_states=output_hidden_states,
1029
+ output_router_logits=output_router_logits,
1030
+ return_dict=return_dict,
1031
+ )
1032
+
1033
+ hidden_states = outputs[0]
1034
+ logits = self.lm_head(hidden_states)
1035
+ logits = logits.float()
1036
+
1037
+ loss = None
1038
+ if labels is not None:
1039
+ # Shift so that tokens < n predict n
1040
+ shift_logits = logits[..., :-1, :].contiguous()
1041
+ shift_labels = labels[..., 1:].contiguous()
1042
+ # Flatten the tokens
1043
+ loss_fct = CrossEntropyLoss()
1044
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1045
+ shift_labels = shift_labels.view(-1)
1046
+ # Enable model parallelism
1047
+ shift_labels = shift_labels.to(shift_logits.device)
1048
+ loss = loss_fct(shift_logits, shift_labels)
1049
+
1050
+ aux_loss = None
1051
+ if output_router_logits:
1052
+ aux_loss = load_balancing_loss_func(
1053
+ outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
1054
+ )
1055
+ if labels is not None:
1056
+ loss += self.router_aux_loss_coef * aux_loss
1057
+
1058
+ if not return_dict:
1059
+ output = (logits,) + outputs[1:]
1060
+ if output_router_logits:
1061
+ output = (aux_loss,) + output
1062
+ return (loss,) + output if loss is not None else output
1063
+
1064
+ return MoeCausalLMOutputWithPast(
1065
+ loss=loss,
1066
+ aux_loss=aux_loss,
1067
+ logits=logits,
1068
+ past_key_values=outputs.past_key_values,
1069
+ hidden_states=outputs.hidden_states,
1070
+ attentions=outputs.attentions,
1071
+ router_logits=outputs.router_logits,
1072
+ )
1073
+
1074
+
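`output_router_logits` above feeds `load_balancing_loss_func`, which penalizes routing collapse onto a few experts. The sketch below is a simplified Switch-Transformers-style auxiliary loss for intuition only, not the exact helper imported by this file:

```python
import torch
import torch.nn.functional as F

def simple_load_balancing_loss(router_logits, num_experts, top_k):
    # router_logits: (num_tokens, num_experts), e.g. concatenated over all layers
    probs = torch.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)          # (num_tokens, top_k)
    expert_mask = F.one_hot(selected, num_experts).float()  # (num_tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))        # fraction of routing slots per expert
    router_prob_per_expert = probs.mean(dim=0)              # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```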
1075
+
1076
+
1077
+
1078
+
1079
+ @add_start_docstrings(
1080
+ """
1081
+ The Mixtral Model transformer with a sequence classification head on top (linear layer).
1082
+
1083
+ [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1084
+ (e.g. GPT-2) do.
1085
+
1086
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1087
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1088
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1089
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1090
+ each row of the batch).
1091
+ """,
1092
+ MIXTRAL_START_DOCSTRING,
1093
+ )
1094
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
1095
+ class MixtralForSequenceClassification(MixtralPreTrainedModel):
1096
+ def __init__(self, config):
1097
+ super().__init__(config)
1098
+ self.num_labels = config.num_labels
1099
+ self.model = MixtralModel(config)
1100
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1101
+
1102
+ # Initialize weights and apply final processing
1103
+ self.post_init()
1104
+
1105
+ def get_input_embeddings(self):
1106
+ return self.model.embed_tokens
1107
+
1108
+ def set_input_embeddings(self, value):
1109
+ self.model.embed_tokens = value
1110
+
1111
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1112
+ def forward(
1113
+ self,
1114
+ input_ids: torch.LongTensor = None,
1115
+ attention_mask: Optional[torch.Tensor] = None,
1116
+ position_ids: Optional[torch.LongTensor] = None,
1117
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1118
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1119
+ labels: Optional[torch.LongTensor] = None,
1120
+ use_cache: Optional[bool] = None,
1121
+ output_attentions: Optional[bool] = None,
1122
+ output_hidden_states: Optional[bool] = None,
1123
+ return_dict: Optional[bool] = None,
1124
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1125
+ r"""
1126
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1127
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1128
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1129
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1130
+ """
1131
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1132
+
1133
+ transformer_outputs = self.model(
1134
+ input_ids,
1135
+ attention_mask=attention_mask,
1136
+ position_ids=position_ids,
1137
+ past_key_values=past_key_values,
1138
+ inputs_embeds=inputs_embeds,
1139
+ use_cache=use_cache,
1140
+ output_attentions=output_attentions,
1141
+ output_hidden_states=output_hidden_states,
1142
+ return_dict=return_dict,
1143
+ )
1144
+ hidden_states = transformer_outputs[0]
1145
+ logits = self.score(hidden_states)
1146
+
1147
+ if input_ids is not None:
1148
+ batch_size = input_ids.shape[0]
1149
+ else:
1150
+ batch_size = inputs_embeds.shape[0]
1151
+
1152
+ if self.config.pad_token_id is None and batch_size != 1:
1153
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1154
+ if self.config.pad_token_id is None:
1155
+ sequence_lengths = -1
1156
+ else:
1157
+ if input_ids is not None:
1158
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1159
+ logits.device
1160
+ )
1161
+ else:
1162
+ sequence_lengths = -1
1163
+
1164
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1165
+
1166
+ loss = None
1167
+ if labels is not None:
1168
+ labels = labels.to(logits.device)
1169
+ if self.config.problem_type is None:
1170
+ if self.num_labels == 1:
1171
+ self.config.problem_type = "regression"
1172
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1173
+ self.config.problem_type = "single_label_classification"
1174
+ else:
1175
+ self.config.problem_type = "multi_label_classification"
1176
+
1177
+ if self.config.problem_type == "regression":
1178
+ loss_fct = MSELoss()
1179
+ if self.num_labels == 1:
1180
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1181
+ else:
1182
+ loss = loss_fct(pooled_logits, labels)
1183
+ elif self.config.problem_type == "single_label_classification":
1184
+ loss_fct = CrossEntropyLoss()
1185
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1186
+ elif self.config.problem_type == "multi_label_classification":
1187
+ loss_fct = BCEWithLogitsLoss()
1188
+ loss = loss_fct(pooled_logits, labels)
1189
+ if not return_dict:
1190
+ output = (pooled_logits,) + transformer_outputs[1:]
1191
+ return ((loss,) + output) if loss is not None else output
1192
+
1193
+ return SequenceClassifierOutputWithPast(
1194
+ loss=loss,
1195
+ logits=pooled_logits,
1196
+ past_key_values=transformer_outputs.past_key_values,
1197
+ hidden_states=transformer_outputs.hidden_states,
1198
+ attentions=transformer_outputs.attentions,
1199
+ )
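As the class docstring explains, pooling picks the last non-padding token per row when `pad_token_id` is configured. A tiny numeric illustration of the index arithmetic used above (hypothetical token ids):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])
# Index of the first pad token minus one = last real token per row.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1   # tensor([2, 1])
logits = torch.randn(2, 5, 3)                                               # (batch, seq, num_labels)
pooled_logits = logits[torch.arange(2), sequence_lengths]                    # (2, 3)
```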
model/utils.py ADDED
@@ -0,0 +1,469 @@
1
+ import copy
2
+ import random
3
+
4
+ # typing
5
+ from typing import List, Tuple
6
+ import time
7
+ import torch
8
+
9
+ # TODO
10
+ # from transformers import LlamaTokenizer
11
+ # tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/")
12
+
13
+ TOPK = 10 # topk for sparse tree
14
+
15
+ from transformers.generation.logits_process import (
16
+ LogitsProcessorList,
17
+ RepetitionPenaltyLogitsProcessor,
18
+ TemperatureLogitsWarper,
19
+ TopKLogitsWarper,
20
+ TopPLogitsWarper,
21
+ )
22
+
23
+
24
+ class Timer:
25
+ def __init__(self,name):
26
+ self.name = name
27
+ def __enter__(self):
28
+ torch.cuda.synchronize()
29
+ self.start = time.perf_counter()
30
+
31
+
32
+ def __exit__(self, exc_type, exc_value, traceback):
33
+ torch.cuda.synchronize()
34
+ elapsed = time.perf_counter() - self.start
35
+ print(f'{self.name} took {elapsed} seconds')
36
+
37
+
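A quick usage note for the `Timer` context manager above: it brackets the block with `torch.cuda.synchronize()`, so it measures the wall time of the GPU work and therefore needs a CUDA device (sketch assumes the class above is in scope):

```python
import torch

with Timer("matmul"):
    a = torch.randn(1024, 1024, device="cuda")
    _ = a @ a   # prints "matmul took ... seconds" when the block exits
```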
38
+ def prepare_logits_processor(
39
+ temperature: float = 0.0,
40
+ repetition_penalty: float = 0.0,
41
+ top_p: float = 0.0,
42
+ top_k: int = 0
43
+ ) -> LogitsProcessorList:
44
+ processor_list = LogitsProcessorList()
45
+ if temperature > 1e-5:
46
+ if temperature >= 1e-5 and temperature != 1.0:
47
+ processor_list.append(TemperatureLogitsWarper(temperature))
48
+ if repetition_penalty > 1.0:
49
+ processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
50
+ if 1e-8 <= top_p < 1.0:
51
+ processor_list.append(TopPLogitsWarper(top_p))
52
+ if top_k > 0:
53
+ processor_list.append(TopKLogitsWarper(top_k))
54
+ return processor_list
55
+
56
+
57
+ # test_processor = prepare_logits_processor(
58
+ # 0.0, 0.0, -1, 1
59
+ # )
60
+
61
+
62
+ def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]:
63
+ """
64
+ Pad the given path list with a specific value up to a specified length.
65
+
66
+ Parameters:
67
+ - path (list): The original list that needs padding.
68
+ - length (int): The desired length of the padded list.
69
+ - pad_value (optional, default=-2): The value to use for padding.
70
+
71
+ Returns:
72
+ - list: A new list based on the original path but padded to the desired length.
73
+
74
+ Example:
75
+ >>> pad_path([1,2,3], 5)
76
+ [1, 2, 3, -2, -2]
77
+
78
+ Note:
79
+ If the given path is already longer than the specified length,
80
+ then no padding occurs, and the original path is returned.
81
+ """
82
+
83
+ # Calculate the number of padding values needed by subtracting the length
84
+ # of the path from the desired length.
85
+ # Append the padding values to the original path and return the new list.
86
+ return path + [pad_value] * (length - len(path))
87
+
88
+
89
+ def generate_tree_buffers(tree_choices, device="cuda"):
90
+ def custom_sort(lst):
91
+ # sort_keys=[len(list)]
92
+ sort_keys = []
93
+ for i in range(len(lst)):
94
+ sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
95
+ return sort_keys
96
+ with Timer("sort"):
97
+
98
+ sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
99
+ tree_len = len(sorted_tree_choices) + 1
100
+
101
+ # Initialize depth_counts to keep track of how many choices have a particular depth
102
+ depth_counts = []
103
+ prev_depth = 0
104
+ for path in sorted_tree_choices:
105
+ depth = len(path)
106
+ if depth != prev_depth:
107
+ depth_counts.append(0)
108
+ depth_counts[depth - 1] += 1
109
+ prev_depth = depth
110
+
111
+ tree_attn_mask = torch.eye(tree_len, tree_len)
112
+ tree_attn_mask[:, 0] = 1
113
+ start = 0
114
+ for i in range(len(depth_counts)):
115
+ for j in range(depth_counts[i]):
116
+ cur_tree_choice = sorted_tree_choices[start + j]
117
+ # retrieve ancestor position
118
+ if len(cur_tree_choice) == 1:
119
+ continue
120
+ ancestor_idx = []
121
+ for c in range(len(cur_tree_choice) - 1):
122
+ ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
123
+ tree_attn_mask[j + start + 1, ancestor_idx] = 1
124
+ start += depth_counts[i]
125
+
126
+ tree_indices = torch.zeros(tree_len, dtype=torch.long)
127
+ p_indices = [0 for _ in range(tree_len - 1)]
128
+ b_indices = [[] for _ in range(tree_len - 1)]
129
+ tree_indices[0] = 0
130
+ start = 0
131
+ bias = 0
132
+ for i in range(len(depth_counts)):
133
+ inlayer_bias = 0
134
+ b = []
135
+ for j in range(depth_counts[i]):
136
+ cur_tree_choice = sorted_tree_choices[start + j]
137
+ cur_parent = cur_tree_choice[:-1]
138
+ if j != 0:
139
+ if cur_parent != parent:
140
+ bias += 1
141
+ inlayer_bias += 1
142
+ parent = cur_parent
143
+ b = []
144
+ else:
145
+ parent = cur_parent
146
+ tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1
147
+ p_indices[start + j] = inlayer_bias
148
+ if len(b) > 0:
149
+ b_indices[start + j] = copy.deepcopy(b)
150
+ else:
151
+ b_indices[start + j] = []
152
+ b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1)
153
+ start += depth_counts[i]
154
+
155
+ p_indices = [-1] + p_indices
156
+ tree_position_ids = torch.zeros(tree_len, dtype=torch.long)
157
+ start = 0
158
+ for i in range(len(depth_counts)):
159
+ tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
160
+ start += depth_counts[i]
161
+
162
+ retrieve_indices_nest = []
163
+ retrieve_paths = []
164
+ for i in range(len(sorted_tree_choices)):
165
+ cur_tree_choice = sorted_tree_choices[-i - 1]
166
+ retrieve_indice = []
167
+ if cur_tree_choice in retrieve_paths:
168
+ continue
169
+ else:
170
+ for c in range(len(cur_tree_choice)):
171
+ retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]))
172
+ retrieve_paths.append(cur_tree_choice[:c + 1])
173
+ retrieve_indices_nest.append(retrieve_indice)
174
+ max_length = max([len(x) for x in retrieve_indices_nest])
175
+ retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
176
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
177
+ retrieve_indices = retrieve_indices + 1
178
+ retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices],
179
+ dim=1)
180
+
181
+ maxitem = retrieve_indices.max().item() + 5
182
+
183
+
184
+
185
+ retrieve_indices = retrieve_indices.tolist()
186
+ retrieve_indices = sorted(retrieve_indices, key=custom_sort)
187
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
188
+
189
+
190
+
191
+ # Aggregate the generated buffers into a dictionary
192
+ tree_buffers = {
193
+ "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0),
194
+ "tree_indices": tree_indices,
195
+ "tree_position_ids": tree_position_ids,
196
+ "retrieve_indices": retrieve_indices,
197
+ }
198
+
199
+ # Move the tensors in the dictionary to the specified device
200
+ tree_buffers = {
201
+ k: v.clone().to(device)
202
+ if isinstance(v, torch.Tensor)
203
+ else torch.tensor(v, device=device)
204
+ for k, v in tree_buffers.items()
205
+ }
206
+
207
+ return tree_buffers
208
+
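To make the buffer layout concrete, here is a hypothetical three-node draft tree run through `generate_tree_buffers` (choices are paths of child indices from the root; expected values worked out by hand, and the embedded `Timer` means this wants a CUDA device):

```python
tree_choices = [(0,), (1,), (0, 0)]        # two first-level drafts, one child of the first
buffers = generate_tree_buffers(tree_choices, device="cuda")
# buffers["tree_position_ids"] -> tensor([0, 1, 1, 2]): depth of [root, (0,), (1,), (0, 0)]
# buffers["tree_attn_mask"]    -> shape (1, 1, 4, 4); each node attends to itself and its ancestors
# buffers["retrieve_indices"]  -> one row per root-to-leaf path, shorter paths padded at the end
```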
209
+
210
+ def initialize_tree0(input_ids, model, past_key_values, logits_processor):
211
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, logits, hidden_state, sample_token = model(
212
+ input_ids, past_key_values=past_key_values, output_orig=True, logits_processor=logits_processor
213
+ )
214
+
215
+ # if logits_processor is not None:
216
+ # logits = orig[:, -1]
217
+ # logits = logits_processor(None, logits)
218
+ # probabilities = torch.nn.functional.softmax(logits, dim=1)
219
+ # token = torch.multinomial(probabilities, 1)
220
+ # else:
221
+ # token = torch.argmax(orig[:, -1])
222
+ # token = token[None, None]
223
+ # input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
224
+ # # Clone the output hidden states
225
+ #
226
+ # draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
227
+ # if output_orig:
228
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token
229
+ # return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token
230
+ return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token
231
+
232
+ def initialize_tree(input_ids, model, past_key_values, logits_processor):
233
+ outputs, orig, hidden_states = model(
234
+ input_ids, past_key_values=past_key_values, output_orig=True
235
+ )
236
+
237
+ if logits_processor is not None:
238
+ logits = orig[:, -1]
239
+ logits = logits_processor(None, logits)
240
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
241
+ token = torch.multinomial(probabilities, 1)
242
+ else:
243
+ token = torch.argmax(orig[:, -1])
244
+ token = token[None, None]
245
+ input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
246
+ # Clone the output hidden states
247
+
248
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
249
+ return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token
250
+
251
+
252
+ def reset_tree_mode(
253
+ model,
254
+ ):
255
+ model.base_model.model.tree_mask = None
256
+ model.base_model.model.tree_mode = None
257
+
258
+
259
+ def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]:
260
+ """
261
+ Resets the current lengths in the passed key-values to zero.
262
+
263
+ This function is designed to be used during the evaluation of a baseline model.
264
+ It iterates through each layer's key-values and sets their current lengths to zero,
265
+ effectively resetting their state.
266
+
267
+ Args:
268
+ - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
269
+
270
+ Returns:
271
+ - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
272
+ """
273
+ for i in range(len(passed_key_values)):
274
+ for j in range(2):
275
+ passed_key_values[i][j].current_length.fill_(0)
276
+ return passed_key_values
277
+
278
+
279
+ def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor):
280
+ sample_token = sample_token.to(tree_indices.device)
281
+
282
+ candidates_logit = sample_token[0]
283
+
284
+ candidates_tree_logits = tree_logits
285
+
286
+ candidates = torch.cat([candidates_logit, candidates_tree_logits.view(-1)], dim=-1)
287
+
288
+ tree_candidates = candidates[tree_indices]
289
+
290
+ tree_candidates_ext = torch.cat(
291
+ [tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device) - 1], dim=0)
292
+
293
+ cart_candidates = tree_candidates_ext[retrieve_indices]
294
+
295
+
296
+ # Unsqueeze the tree candidates for dimension consistency.
297
+ tree_candidates = tree_candidates.unsqueeze(0)
298
+ return cart_candidates, tree_candidates
299
+
300
+
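The `-1` appended to `tree_candidates` above is a sentinel: padded entries of `retrieve_indices` end up as `-1` after the `+1` shift, so gathering with them picks the dummy token instead of indexing out of bounds. A toy illustration with made-up token ids:

```python
import torch

tree_candidates = torch.tensor([11, 22, 33])
ext = torch.cat([tree_candidates, torch.tensor([-1])])    # sentinel lives at the last index
retrieve_indices = torch.tensor([[0, 1, 2],
                                 [0, 2, -1]])             # second path is shorter, padded with -1
cart = ext[retrieve_indices]                              # tensor([[11, 22, 33], [11, 33, -1]])
```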
301
+ def tree_decoding(
302
+ model,
303
+ tree_candidates,
304
+ past_key_values,
305
+ tree_position_ids,
306
+ input_ids,
307
+ retrieve_indices,
308
+ ):
309
+ position_ids = tree_position_ids + input_ids.shape[1]
310
+
311
+ outputs, tree_logits, hidden_state = model(
312
+ tree_candidates,
313
+ output_orig=True,
314
+ past_key_values=past_key_values,
315
+ position_ids=position_ids,
316
+ )
317
+
318
+
319
+ logits = tree_logits[0, retrieve_indices]
320
+ return logits, hidden_state, outputs
321
+
322
+
323
+
324
+
325
+
326
+ def evaluate_posterior(
327
+ logits: torch.Tensor,
328
+ candidates: torch.Tensor,
329
+ logits_processor,
330
+ ):
331
+ """
332
+ Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
333
+
334
+ When `logits_processor` is None, the function uses greedy (argmax) verification; otherwise it
+ performs a rejection-sampling style acceptance over the candidate tokens.
336
+
337
+ Args:
338
+ - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
339
+ - candidates (torch.Tensor): Candidate token sequences.
340
+ - logits_processor (LogitsProcessorList or None): Logit warpers used for sampling-based verification; `None` selects greedy (argmax) verification.
343
+
344
+ Returns:
345
+ - best_candidate (torch.Tensor): Index of the chosen best candidate.
346
+ - accept_length (int): Length of the accepted candidate sequence.
+ - sample_p (torch.Tensor): Logits (greedy path) or probabilities (sampling path) used to draw the token that follows the accepted prefix.
347
+ """
348
+ # Greedy decoding based on temperature value
349
+ if logits_processor is None:
350
+ # Find the tokens that match the maximum logits for each position in the sequence
351
+ posterior_mask = (
352
+ candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1)
353
+ ).int()
354
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
355
+ accept_length = candidates_accept_length.max()
356
+ # Choose the best candidate
357
+ if accept_length == 0:
358
+ # Default to the first candidate if none are accepted
359
+ best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
360
+ else:
361
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
362
+ return best_candidate, accept_length, logits[best_candidate, accept_length]
363
+
364
+ else:
365
+ accept_length = 1
366
+ accept_cand = candidates[0][:1]
367
+ best_candidate = 0
368
+ for i in range(1, candidates.shape[1]):
369
+ if i != accept_length:
370
+ break
371
+ adjustflag = False
372
+ is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
373
+ fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
374
+ gt_logits = logits[fi, i - 1][None]
375
+ gt_logits = logits_processor(None, gt_logits)[0]
376
+ gtp = torch.softmax(gt_logits, dim=0)
377
+ candidates_set = []
378
+ for j in range(candidates.shape[0]):
379
+ if is_eq[j]:
380
+ x = candidates[j, i]
381
+ xi = x.item()
382
+ if xi in candidates_set or xi == -1:
383
+ continue
384
+ candidates_set.append(xi)
385
+ r = random.random()
386
+ px = gtp[xi]
387
+ qx = 1.0
388
+ acp = px / qx
389
+ if r <= acp:
390
+ accept_cand = torch.cat((accept_cand, x[None]), dim=0)
391
+ accept_length += 1
392
+ best_candidate = j
393
+ break
394
+ else:
395
+ gtp[xi] = 0
396
+ gtp = gtp / gtp.sum()
397
+ adjustflag = True
398
+ if adjustflag and accept_length != candidates.shape[1]:
399
+ sample_p = gtp
400
+ else:
401
+ gt_logits = logits[best_candidate, accept_length - 1]
402
+ sample_p = torch.softmax(gt_logits, dim=0)
403
+ return torch.tensor(best_candidate), accept_length - 1, sample_p
404
+
405
+
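The greedy branch of `evaluate_posterior` accepts a draft prefix only while every position matches the base model's argmax; the `cumprod` turns a per-position match vector into a prefix length. In miniature, with a toy agreement pattern:

```python
import torch

# 1 where the draft token equals the verifier's argmax at that position.
matches = torch.tensor([[1, 1, 0, 1]])
accept_len = torch.cumprod(matches, dim=1).sum(dim=1)   # tensor([2]); the match after the miss is ignored
```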
406
+ @torch.no_grad()
407
+ def update_inference_inputs(
408
+ input_ids,
409
+ candidates,
410
+ best_candidate,
411
+ accept_length,
412
+ retrieve_indices,
413
+ logits_processor,
414
+ new_token,
415
+ past_key_values_data_list,
416
+ current_length_data,
417
+ model,
418
+ hidden_state_new,
419
+ sample_p
420
+ ):
421
+ prev_input_len = input_ids.shape[1]
422
+ # Map the best candidate indices to the original indices in the sequence
423
+ select_indices = (
424
+ retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
425
+ )
426
+ # Append the tokens from the best candidate to the input sequence
427
+ input_ids = torch.cat(
428
+ [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
429
+ )
430
+ # Update the past key values based on the selected tokens
431
+ # Source tensor that contains relevant past information based on the selected candidate
432
+ for past_key_values_data in past_key_values_data_list:
433
+ tgt = past_key_values_data[..., select_indices.to(past_key_values_data.device), :]
434
+ # Destination tensor where the relevant past information will be stored
435
+ dst = past_key_values_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :]
436
+ # Copy relevant past information from the source to the destination
437
+ dst.copy_(tgt, non_blocking=True)
438
+
439
+ # Update the current length tensor (currently only batch size 1 is supported)
440
+ current_length_data.fill_(prev_input_len + tgt.shape[-2])
441
+
442
+ retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
443
+ accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
444
+ # token=model.base_model.lm_head(accept_hidden_state_new[:,-1]).argmax()
445
+ # token=token[None,None]
446
+ prob = sample_p
447
+ if logits_processor is not None:
448
+ token = torch.multinomial(prob, 1)
449
+ token = token[None]
450
+ else:
451
+ token = torch.argmax(prob)
452
+ token = token[None, None]
453
+ # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
454
+ draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
455
+ input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
456
+ head=model.base_model.lm_head,logits_processor=logits_processor)
457
+
458
+
459
+ new_token += accept_length + 1
460
+
461
+ return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token
462
+
463
+
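The cache update in `update_inference_inputs` copies the accepted draft path's key/value rows into the contiguous slots right after the previous input length, then resets the recorded length. A shape-only sketch of that move with a toy cache (one head, head dim 1):

```python
import torch

cache = torch.arange(12, dtype=torch.float32).view(1, 1, 12, 1)   # (batch, heads, seq, head_dim)
prev_input_len = 4
select_indices = torch.tensor([5, 7, 8])   # accepted tree positions, already offset by prev_input_len
tgt = cache[..., select_indices, :]        # fancy indexing copies, so no aliasing with the destination
cache[..., prev_input_len:prev_input_len + tgt.shape[-2], :].copy_(tgt)
# Positions 4..6 now hold the accepted rows; current_length would be reset to 4 + 3.
```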
464
+ if __name__ == "__main__":
465
+ logits = torch.randn(1, 5)
466
+ tp = prepare_logits_processor(0.9, 0, 0.9, 0)
467
+ l = tp(None, logits)
468
+ if tp is None:
469
+ print(tp)