Isaachh committed on
Commit 114dd13
1 Parent(s): 162e139
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50):
  1. app.py +74 -0
  2. bunny/constants.py +7 -0
  3. bunny/conversation.py +239 -0
  4. bunny/eval/m4c_evaluator.py +334 -0
  5. bunny/eval/model_vqa.py +111 -0
  6. bunny/eval/model_vqa_cmmmu.py +234 -0
  7. bunny/eval/model_vqa_loader.py +143 -0
  8. bunny/eval/model_vqa_mmbench.py +167 -0
  9. bunny/eval/model_vqa_mmmu.py +326 -0
  10. bunny/eval/model_vqa_science.py +119 -0
  11. bunny/model/__init__.py +6 -0
  12. bunny/model/builder.py +197 -0
  13. bunny/model/bunny_arch.py +230 -0
  14. bunny/model/language_model/bunny_llama.py +102 -0
  15. bunny/model/language_model/bunny_minicpm.py +103 -0
  16. bunny/model/language_model/bunny_phi.py +100 -0
  17. bunny/model/language_model/bunny_phi3.py +100 -0
  18. bunny/model/language_model/bunny_qwen.py +100 -0
  19. bunny/model/language_model/bunny_stablelm.py +100 -0
  20. bunny/model/language_model/llama/__init__.py +114 -0
  21. bunny/model/language_model/llama/configuration_llama.py +191 -0
  22. bunny/model/language_model/llama/modeling_llama.py +1844 -0
  23. bunny/model/language_model/llama/tokenization_llama.py +471 -0
  24. bunny/model/language_model/llama/tokenization_llama_fast.py +281 -0
  25. bunny/model/language_model/minicpm/configuration_minicpm.py +202 -0
  26. bunny/model/language_model/minicpm/modeling_minicpm.py +1456 -0
  27. bunny/model/language_model/phi/__init__.py +69 -0
  28. bunny/model/language_model/phi/configuration_phi.py +195 -0
  29. bunny/model/language_model/phi/modeling_phi.py +1374 -0
  30. bunny/model/language_model/phi3/__init__.py +69 -0
  31. bunny/model/language_model/phi3/configuration_phi3.py +213 -0
  32. bunny/model/language_model/phi3/modeling_phi3.py +1597 -0
  33. bunny/model/language_model/qwen2/__init__.py +80 -0
  34. bunny/model/language_model/qwen2/configuration_qwen2.py +144 -0
  35. bunny/model/language_model/qwen2/modeling_qwen2.py +1403 -0
  36. bunny/model/language_model/qwen2/tokenization_qwen2.py +345 -0
  37. bunny/model/language_model/qwen2/tokenization_qwen2_fast.py +143 -0
  38. bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py +113 -0
  39. bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py +917 -0
  40. bunny/model/multimodal_encoder/builder.py +29 -0
  41. bunny/model/multimodal_encoder/clip/clip_encoder.py +76 -0
  42. bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +63 -0
  43. bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py +68 -0
  44. bunny/model/multimodal_encoder/eva_clip/eva_vit.py +851 -0
  45. bunny/model/multimodal_encoder/siglip/siglip_encoder.py +129 -0
  46. bunny/model/multimodal_projector/builder.py +183 -0
  47. bunny/serve/cli.py +118 -0
  48. bunny/serve/controller.py +277 -0
  49. bunny/serve/examples/example_1.png +0 -0
  50. bunny/serve/examples/example_2.png +0 -0
app.py ADDED
@@ -0,0 +1,74 @@
+ import sys
+ import os
+ import time
+ import argparse
+ import subprocess
+
+ import bunny.serve.gradio_web_server as gws
+
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-e', '.'])
+
+
+ def start_controller():
+     controller_command = [
+         sys.executable, '-m', 'bunny.serve.controller',
+         '--host', '0.0.0.0',
+         '--port', '10000'
+     ]
+     return subprocess.Popen(controller_command)
+
+
+ def start_worker(port: int, model_path: str, model_type: str):
+     worker_command = [
+         sys.executable, '-m', 'bunny.serve.model_worker',
+         '--host', '0.0.0.0',
+         '--controller', 'http://localhost:10000',
+         '--port', f'{port}',
+         '--worker', f'http://localhost:{port}',
+         '--model-path', model_path,
+         '--model-type', model_type
+     ]
+     return subprocess.Popen(worker_command)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="0.0.0.0")
+     parser.add_argument("--port", type=int)
+     parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
+     parser.add_argument("--concurrency-count", type=int, default=5)
+     parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--moderate", action="store_true")
+     parser.add_argument("--embed", action="store_true")
+     gws.args = parser.parse_args()
+     gws.models = []
+
+     controller_proc = start_controller()
+
+     worker_procs = []
+
+     worker_procs.append(start_worker(port=40000, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V', model_type='llama3-8b'))
+     worker_procs.append(start_worker(port=40001, model_path='BAAI/Bunny-v1_1-4B', model_type='phi-3'))
+     worker_procs.append(start_worker(port=40002, model_path='BAAI/Bunny-v1_0-3B', model_type='phi-2'))
+
+     time.sleep(60)
+
+     exit_status = 0
+     try:
+         demo = gws.build_demo(embed_mode=gws.args.embed)
+         demo.launch(
+             server_name=gws.args.host,
+             server_port=gws.args.port,
+             share=gws.args.share,
+             debug=True,
+             max_threads=10
+         )
+     except Exception as e:
+         print(e)
+         exit_status = 1
+     finally:
+         for worker_proc in worker_procs:
+             worker_proc.kill()
+         controller_proc.kill()
+         sys.exit(exit_status)
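A minimal reuse sketch of the helpers above, assuming the repository and its bunny.serve dependencies are installed; the single worker, port, and model choice are illustrative and simply mirror the defaults in app.py:

    # Illustrative only: one controller plus one worker, then wait for registration.
    # Note that importing app re-runs the editable pip install at the top of the file.
    import time
    from app import start_controller, start_worker

    controller_proc = start_controller()
    worker_proc = start_worker(port=40000, model_path='BAAI/Bunny-v1_0-3B', model_type='phi-2')
    time.sleep(60)  # workers need time to load weights and register with the controller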
bunny/constants.py ADDED
@@ -0,0 +1,7 @@
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ LOGDIR = "gradio-logs"
+ WORKER_HEART_BEAT_INTERVAL = 15
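For context, a short sketch of how these constants are combined when building a multimodal prompt, mirroring the pattern used by the eval scripts later in this diff; the question text is illustrative:

    # Sketch: prepend the image placeholder to a text question, as model_vqa.py does.
    from bunny.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

    question = "What is shown in the image?"  # illustrative question
    prompt_text = DEFAULT_IMAGE_TOKEN + '\n' + question
    # IMAGE_TOKEN_INDEX (-200) is later used by tokenizer_image_token to mark where
    # image features are spliced into the token id sequence.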
bunny/conversation.py ADDED
@@ -0,0 +1,239 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     TWO = auto()
+     PLAIN = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+
+                         image = expand2square(image)
+                     elif image_process_mode in ["Default", "Crop"]:
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((336, 336))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="PNG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_bunny = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="bunny",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_phi3 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="phi3",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_minicpm = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="minicpm",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="llama",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|end_of_text|>",
+ )
+
+ conv_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ default_conversation = conv_bunny
+ conv_templates = {
+     "default": conv_bunny,
+     "bunny": conv_bunny,
+     "phi3": conv_phi3,
+     "plain": conv_plain,
+     'minicpm': conv_minicpm,
+     'llama': conv_llama
+ }
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
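A short usage sketch of the templates defined above, assuming the package is installed; the user message is illustrative and follows the same pattern the eval scripts in this diff use:

    # Sketch: build a two-role prompt from one of the registered templates.
    from bunny.conversation import conv_templates

    conv = conv_templates["bunny"].copy()
    conv.append_message(conv.roles[0], "<image>\nDescribe this picture.")  # USER turn
    conv.append_message(conv.roles[1], None)  # leave the ASSISTANT turn open for generation
    print(conv.get_prompt())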
bunny/eval/m4c_evaluator.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import re
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ class EvalAIAnswerProcessor:
8
+ """
9
+ Processes an answer similar to Eval AI
10
+ copied from
11
+ https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
12
+ """
13
+
14
+ CONTRACTIONS = {
15
+ "aint": "ain't",
16
+ "arent": "aren't",
17
+ "cant": "can't",
18
+ "couldve": "could've",
19
+ "couldnt": "couldn't",
20
+ "couldn'tve": "couldn't've",
21
+ "couldnt've": "couldn't've",
22
+ "didnt": "didn't",
23
+ "doesnt": "doesn't",
24
+ "dont": "don't",
25
+ "hadnt": "hadn't",
26
+ "hadnt've": "hadn't've",
27
+ "hadn'tve": "hadn't've",
28
+ "hasnt": "hasn't",
29
+ "havent": "haven't",
30
+ "hed": "he'd",
31
+ "hed've": "he'd've",
32
+ "he'dve": "he'd've",
33
+ "hes": "he's",
34
+ "howd": "how'd",
35
+ "howll": "how'll",
36
+ "hows": "how's",
37
+ "Id've": "I'd've",
38
+ "I'dve": "I'd've",
39
+ "Im": "I'm",
40
+ "Ive": "I've",
41
+ "isnt": "isn't",
42
+ "itd": "it'd",
43
+ "itd've": "it'd've",
44
+ "it'dve": "it'd've",
45
+ "itll": "it'll",
46
+ "let's": "let's",
47
+ "maam": "ma'am",
48
+ "mightnt": "mightn't",
49
+ "mightnt've": "mightn't've",
50
+ "mightn'tve": "mightn't've",
51
+ "mightve": "might've",
52
+ "mustnt": "mustn't",
53
+ "mustve": "must've",
54
+ "neednt": "needn't",
55
+ "notve": "not've",
56
+ "oclock": "o'clock",
57
+ "oughtnt": "oughtn't",
58
+ "ow's'at": "'ow's'at",
59
+ "'ows'at": "'ow's'at",
60
+ "'ow'sat": "'ow's'at",
61
+ "shant": "shan't",
62
+ "shed've": "she'd've",
63
+ "she'dve": "she'd've",
64
+ "she's": "she's",
65
+ "shouldve": "should've",
66
+ "shouldnt": "shouldn't",
67
+ "shouldnt've": "shouldn't've",
68
+ "shouldn'tve": "shouldn't've",
69
+ "somebody'd": "somebodyd",
70
+ "somebodyd've": "somebody'd've",
71
+ "somebody'dve": "somebody'd've",
72
+ "somebodyll": "somebody'll",
73
+ "somebodys": "somebody's",
74
+ "someoned": "someone'd",
75
+ "someoned've": "someone'd've",
76
+ "someone'dve": "someone'd've",
77
+ "someonell": "someone'll",
78
+ "someones": "someone's",
79
+ "somethingd": "something'd",
80
+ "somethingd've": "something'd've",
81
+ "something'dve": "something'd've",
82
+ "somethingll": "something'll",
83
+ "thats": "that's",
84
+ "thered": "there'd",
85
+ "thered've": "there'd've",
86
+ "there'dve": "there'd've",
87
+ "therere": "there're",
88
+ "theres": "there's",
89
+ "theyd": "they'd",
90
+ "theyd've": "they'd've",
91
+ "they'dve": "they'd've",
92
+ "theyll": "they'll",
93
+ "theyre": "they're",
94
+ "theyve": "they've",
95
+ "twas": "'twas",
96
+ "wasnt": "wasn't",
97
+ "wed've": "we'd've",
98
+ "we'dve": "we'd've",
99
+ "weve": "we've",
100
+ "werent": "weren't",
101
+ "whatll": "what'll",
102
+ "whatre": "what're",
103
+ "whats": "what's",
104
+ "whatve": "what've",
105
+ "whens": "when's",
106
+ "whered": "where'd",
107
+ "wheres": "where's",
108
+ "whereve": "where've",
109
+ "whod": "who'd",
110
+ "whod've": "who'd've",
111
+ "who'dve": "who'd've",
112
+ "wholl": "who'll",
113
+ "whos": "who's",
114
+ "whove": "who've",
115
+ "whyll": "why'll",
116
+ "whyre": "why're",
117
+ "whys": "why's",
118
+ "wont": "won't",
119
+ "wouldve": "would've",
120
+ "wouldnt": "wouldn't",
121
+ "wouldnt've": "wouldn't've",
122
+ "wouldn'tve": "wouldn't've",
123
+ "yall": "y'all",
124
+ "yall'll": "y'all'll",
125
+ "y'allll": "y'all'll",
126
+ "yall'd've": "y'all'd've",
127
+ "y'alld've": "y'all'd've",
128
+ "y'all'dve": "y'all'd've",
129
+ "youd": "you'd",
130
+ "youd've": "you'd've",
131
+ "you'dve": "you'd've",
132
+ "youll": "you'll",
133
+ "youre": "you're",
134
+ "youve": "you've",
135
+ }
136
+
137
+ NUMBER_MAP = {
138
+ "none": "0",
139
+ "zero": "0",
140
+ "one": "1",
141
+ "two": "2",
142
+ "three": "3",
143
+ "four": "4",
144
+ "five": "5",
145
+ "six": "6",
146
+ "seven": "7",
147
+ "eight": "8",
148
+ "nine": "9",
149
+ "ten": "10",
150
+ }
151
+ ARTICLES = ["a", "an", "the"]
152
+ PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
153
+ COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
154
+ PUNCTUATIONS = [
155
+ ";",
156
+ r"/",
157
+ "[",
158
+ "]",
159
+ '"',
160
+ "{",
161
+ "}",
162
+ "(",
163
+ ")",
164
+ "=",
165
+ "+",
166
+ "\\",
167
+ "_",
168
+ "-",
169
+ ">",
170
+ "<",
171
+ "@",
172
+ "`",
173
+ ",",
174
+ "?",
175
+ "!",
176
+ ]
177
+
178
+ def __init__(self, *args, **kwargs):
179
+ pass
180
+
181
+ def word_tokenize(self, word):
182
+ word = word.lower()
183
+ word = word.replace(",", "").replace("?", "").replace("'s", " 's")
184
+ return word.strip()
185
+
186
+ def process_punctuation(self, in_text):
187
+ out_text = in_text
188
+ for p in self.PUNCTUATIONS:
189
+ if (p + " " in in_text or " " + p in in_text) or (
190
+ re.search(self.COMMA_STRIP, in_text) is not None
191
+ ):
192
+ out_text = out_text.replace(p, "")
193
+ else:
194
+ out_text = out_text.replace(p, " ")
195
+ out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
196
+ return out_text
197
+
198
+ def process_digit_article(self, in_text):
199
+ out_text = []
200
+ temp_text = in_text.lower().split()
201
+ for word in temp_text:
202
+ word = self.NUMBER_MAP.setdefault(word, word)
203
+ if word not in self.ARTICLES:
204
+ out_text.append(word)
205
+ else:
206
+ pass
207
+ for word_id, word in enumerate(out_text):
208
+ if word in self.CONTRACTIONS:
209
+ out_text[word_id] = self.CONTRACTIONS[word]
210
+ out_text = " ".join(out_text)
211
+ return out_text
212
+
213
+ def __call__(self, item):
214
+ item = self.word_tokenize(item)
215
+ item = item.replace("\n", " ").replace("\t", " ").strip()
216
+ item = self.process_punctuation(item)
217
+ item = self.process_digit_article(item)
218
+ return item
219
+
220
+
221
+ class TextVQAAccuracyEvaluator:
222
+ def __init__(self):
223
+ self.answer_processor = EvalAIAnswerProcessor()
224
+
225
+ def _compute_answer_scores(self, raw_answers):
226
+ """
227
+ compute the accuracy (soft score) of human answers
228
+ """
229
+ answers = [self.answer_processor(a) for a in raw_answers]
230
+ assert len(answers) == 10
231
+ gt_answers = list(enumerate(answers))
232
+ unique_answers = set(answers)
233
+ unique_answer_scores = {}
234
+
235
+ for unique_answer in unique_answers:
236
+ accs = []
237
+ for gt_answer in gt_answers:
238
+ other_answers = [item for item in gt_answers if item != gt_answer]
239
+ matching_answers = [
240
+ item for item in other_answers if item[1] == unique_answer
241
+ ]
242
+ acc = min(1, float(len(matching_answers)) / 3)
243
+ accs.append(acc)
244
+ unique_answer_scores[unique_answer] = sum(accs) / len(accs)
245
+
246
+ return unique_answer_scores
247
+
248
+ def eval_pred_list(self, pred_list):
249
+ pred_scores = []
250
+ for entry in tqdm(pred_list):
251
+ pred_answer = self.answer_processor(entry["pred_answer"])
252
+ unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
253
+ score = unique_answer_scores.get(pred_answer, 0.0)
254
+ pred_scores.append(score)
255
+
256
+ accuracy = sum(pred_scores) / len(pred_scores)
257
+ return accuracy
258
+
259
+
260
+ class STVQAAccuracyEvaluator:
261
+ def __init__(self):
262
+ self.answer_processor = EvalAIAnswerProcessor()
263
+
264
+ def eval_pred_list(self, pred_list):
265
+ pred_scores = []
266
+ for entry in pred_list:
267
+ pred_answer = self.answer_processor(entry["pred_answer"])
268
+ gts = [self.answer_processor(a) for a in entry["gt_answers"]]
269
+ score = 1.0 if pred_answer in gts else 0.0
270
+ pred_scores.append(score)
271
+
272
+ accuracy = sum(pred_scores) / len(pred_scores)
273
+ return accuracy
274
+
275
+
276
+ class STVQAANLSEvaluator:
277
+ def __init__(self):
278
+ import editdistance # install with `pip install editdistance`
279
+
280
+ self.get_edit_distance = editdistance.eval
281
+
282
+ def get_anls(self, s1, s2):
283
+ s1 = s1.lower().strip()
284
+ s2 = s2.lower().strip()
285
+ iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
286
+ anls = iou if iou >= 0.5 else 0.0
287
+ return anls
288
+
289
+ def eval_pred_list(self, pred_list):
290
+ pred_scores = []
291
+ for entry in pred_list:
292
+ anls = max(
293
+ self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
294
+ )
295
+ pred_scores.append(anls)
296
+
297
+ accuracy = sum(pred_scores) / len(pred_scores)
298
+ return accuracy
299
+
300
+
301
+ class TextCapsBleu4Evaluator:
302
+ def __init__(self):
303
+ # The following script requires Java 1.8.0 and pycocotools installed.
304
+ # The pycocoevalcap can be installed with pip as
305
+ # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
306
+ # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
307
+ # but has no python3 support yet.
308
+ try:
309
+ from pycocoevalcap.bleu.bleu import Bleu
310
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
311
+ except ModuleNotFoundError:
312
+ print(
313
+ "Please install pycocoevalcap module using "
314
+ "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
315
+ )
316
+ raise
317
+
318
+ self.tokenizer = PTBTokenizer()
319
+ self.scorer = Bleu(4)
320
+
321
+ def eval_pred_list(self, pred_list):
322
+ # Create reference and hypotheses captions.
323
+ gts = {}
324
+ res = {}
325
+ for idx, entry in enumerate(pred_list):
326
+ gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
327
+ res[idx] = [{"caption": entry["pred_answer"]}]
328
+
329
+ gts = self.tokenizer.tokenize(gts)
330
+ res = self.tokenizer.tokenize(res)
331
+ score, _ = self.scorer.compute_score(gts, res)
332
+
333
+ bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
334
+ return bleu4
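A usage sketch for the evaluators above, assuming the package is installed; the prediction entry is made up, with the 10 ground-truth answers that _compute_answer_scores asserts on:

    # Sketch: score one TextVQA-style prediction against 10 annotator answers.
    from bunny.eval.m4c_evaluator import TextVQAAccuracyEvaluator

    evaluator = TextVQAAccuracyEvaluator()
    pred_list = [{
        "pred_answer": "two",
        "gt_answers": ["two", "2", "two", "two", "three", "2", "two", "two", "2", "two"],
    }]
    print(evaluator.eval_pred_list(pred_list))  # soft VQA accuracy in [0, 1]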
bunny/eval/model_vqa.py ADDED
@@ -0,0 +1,111 @@
+ import argparse
+ import torch
+ import os
+ import json
+ from tqdm import tqdm
+ import shortuuid
+
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from bunny.conversation import conv_templates, SeparatorStyle
+ from bunny.model.builder import load_pretrained_model
+ from bunny.util.utils import disable_torch_init
+ from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
+
+ from PIL import Image
+ import math
+
+
+ def split_list(lst, n):
+     """Split a list into n (roughly) equal-sized chunks"""
+     chunk_size = math.ceil(len(lst) / n)  # ceiling division
+     return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+ def get_chunk(lst, n, k):
+     chunks = split_list(lst, n)
+     return chunks[k]
+
+
+ def eval_model(args):
+     # Model
+     disable_torch_init()
+     model_path = os.path.expanduser(args.model_path)
+     model_name = get_model_name_from_path(model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                            args.model_type)
+
+     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+     answers_file = os.path.expanduser(args.answers_file)
+     os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+     ans_file = open(answers_file, "w")
+     for line in tqdm(questions):
+         idx = line["question_id"]
+         image_file = line["image"]
+         qs = line["text"]
+         cur_prompt = qs
+
+         qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+         conv = conv_templates[args.conv_mode].copy()
+         conv.append_message(conv.roles[0], qs)
+         conv.append_message(conv.roles[1], None)
+         prompt = conv.get_prompt()
+
+         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+         image = Image.open(os.path.join(args.image_folder, image_file))
+         image_tensor = process_images([image], image_processor, model.config)[0]
+
+         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+
+         with torch.inference_mode():
+             output_ids = model.generate(
+                 input_ids,
+                 images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
+                 do_sample=True if args.temperature > 0 else False,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+                 num_beams=args.num_beams,
+                 # no_repeat_ngram_size=3,
+                 max_new_tokens=1024,
+                 use_cache=True)
+
+         input_token_len = input_ids.shape[1]
+         n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+         if n_diff_input_output > 0:
+             print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+         outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+         outputs = outputs.strip()
+         if outputs.endswith(stop_str):
+             outputs = outputs[:-len(stop_str)]
+         outputs = outputs.strip()
+
+         ans_id = shortuuid.uuid()
+         ans_file.write(json.dumps({"question_id": idx,
+                                    "prompt": cur_prompt,
+                                    "text": outputs,
+                                    "answer_id": ans_id,
+                                    "model_id": model_name,
+                                    "metadata": {}}) + "\n")
+         ans_file.flush()
+     ans_file.close()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default=None)
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--model-type", type=str, default=None)
+     parser.add_argument("--image-folder", type=str, default=None)
+     parser.add_argument("--question-file", type=str, default=None)
+     parser.add_argument("--answers-file", type=str, default=None)
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--num-chunks", type=int, default=1)
+     parser.add_argument("--chunk-idx", type=int, default=0)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--top_p", type=float, default=None)
+     parser.add_argument("--num_beams", type=int, default=1)
+     args = parser.parse_args()
+
+     eval_model(args)
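The chunking helpers above are what allow the answers file to be produced in parallel shards via --num-chunks and --chunk-idx; a small sketch of their behaviour on toy data (importing the module assumes torch, shortuuid, and the bunny package are installed):

    # Sketch: split 10 question records into 3 shards and pick shard 1.
    from bunny.eval.model_vqa import split_list, get_chunk

    questions = list(range(10))        # stand-in for parsed question records
    print(split_list(questions, 3))    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(get_chunk(questions, 3, 1))  # [4, 5, 6, 7]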
bunny/eval/model_vqa_cmmmu.py ADDED
@@ -0,0 +1,234 @@
1
+ import random
2
+ import numpy as np
3
+ import os
4
+ import json
5
+ import yaml
6
+ import torch
7
+
8
+ from tqdm import tqdm
9
+ from datasets import load_dataset, concatenate_datasets
10
+ from argparse import ArgumentParser
11
+
12
+ from bunny.model.builder import load_pretrained_model
13
+ from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
14
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
15
+ from bunny.conversation import conv_templates
16
+
17
+ CAT_CN2EN = {'艺术与设计': 'art_and_design',
18
+ '商业': 'business',
19
+ '健康与医学': 'health_and_medicine',
20
+ '人文社会科学': 'humanities_and_social_sciences',
21
+ '科学': 'science',
22
+ '技术与工程': 'technology_and_engineering'}
23
+
24
+
25
+ def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
26
+ def deal_with_prompt(input_text):
27
+ qs = input_text
28
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
29
+ return qs
30
+
31
+ prompt = sample['final_input_prompt']
32
+ prompt = deal_with_prompt(prompt)
33
+
34
+ conv = conv_templates[args.conv_mode].copy()
35
+ conv.append_message(conv.roles[0], prompt)
36
+ conv.append_message(conv.roles[1], None)
37
+ prompt = conv.get_prompt()
38
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
39
+
40
+ image = sample['image_1']
41
+ if sample['image_2'] is not None: # multiple images actually
42
+ if sample['type'] == '选择':
43
+ all_choices = sample['all_choices']
44
+ response = random.choice(all_choices)
45
+ else:
46
+ response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
47
+ elif image is not None:
48
+ output_ids = model.generate(
49
+ input_ids,
50
+ images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
51
+ do_sample=False,
52
+ temperature=0,
53
+ top_p=None,
54
+ # num_beams=5,
55
+ max_new_tokens=128,
56
+ use_cache=True)
57
+
58
+ input_token_len = input_ids.shape[1]
59
+ # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
60
+ # if n_diff_input_output > 0:
61
+ # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
62
+ response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
63
+
64
+ return response
65
+
66
+
67
+ def load_yaml(file_path):
68
+ with open(file_path, 'r') as stream:
69
+ try:
70
+ yaml_dict = yaml.safe_load(stream)
71
+ except yaml.YAMLError as exc:
72
+ print(exc)
73
+
74
+ return yaml_dict
75
+
76
+
77
+ # DATA PROCESSING
78
+ def construct_prompt(sample, config):
79
+ question = sample['question']
80
+ options = []
81
+ for i in range(1, 5):
82
+ if sample[f'option{i}'] is None:
83
+ break
84
+ options.append(sample[f'option{i}'])
85
+
86
+ example = ""
87
+ if sample['type'] == '选择':
88
+ start_chr = 'A'
89
+ prediction_range = []
90
+ for option in options:
91
+ prediction_range.append(start_chr)
92
+ example += f"({start_chr}) {option}\n"
93
+ start_chr = chr(ord(start_chr) + 1)
94
+ empty_prompt_sample_structure = config['multi_choice_example_format']
95
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
96
+ res_dict = {}
97
+ res_dict['correct_choice'] = sample['answer']
98
+ res_dict['all_choices'] = prediction_range
99
+ res_dict['empty_prompt'] = empty_prompt
100
+ if config['task_instructions']:
101
+ res_dict['final_input_prompt'] = config['task_instructions'][0].strip() + '\n\n' + empty_prompt
102
+ else:
103
+ res_dict['final_input_prompt'] = empty_prompt
104
+
105
+ res_dict['gt_content'] = sample['answer']
106
+ elif sample['type'] == '判断':
107
+ empty_prompt_sample_structure = config['T/F_example_format']
108
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
109
+ res_dict = {}
110
+ res_dict['empty_prompt'] = empty_prompt
111
+ if config['task_instructions']:
112
+ res_dict['final_input_prompt'] = config['task_instructions'][1].strip() + '\n\n' + empty_prompt
113
+ else:
114
+ res_dict['final_input_prompt'] = empty_prompt
115
+ res_dict['gt_content'] = sample['answer']
116
+ else:
117
+ empty_prompt_sample_structure = config['short_ans_example_format']
118
+ empty_prompt = empty_prompt_sample_structure.format(question)
119
+ res_dict = {}
120
+ res_dict['empty_prompt'] = empty_prompt
121
+ if config['task_instructions']:
122
+ res_dict['final_input_prompt'] = config['task_instructions'][2].strip() + '\n\n' + empty_prompt
123
+ else:
124
+ res_dict['final_input_prompt'] = empty_prompt
125
+ res_dict['gt_content'] = sample['answer']
126
+
127
+ res_dict.update(sample)
128
+ return res_dict
129
+
130
+
131
+ def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
132
+ out_samples = []
133
+ with torch.no_grad():
134
+ for sample in tqdm(samples):
135
+ if args.small_gpu_usage:
136
+ sample['image_1'] = sample['image_1'].cuda()
137
+ response = call_model_engine_fn(args, sample, model, tokenizer, processor)
138
+ if args.small_gpu_usage:
139
+ sample['image_1'] = sample['image_1'].cpu()
140
+
141
+ out_sample = dict()
142
+ out_sample['id'] = sample['id']
143
+ out_sample['type'] = sample['type']
144
+ out_sample['response'] = response
145
+ out_samples.append(out_sample)
146
+ return out_samples
147
+
148
+
149
+ def set_seed(seed_value):
150
+ """
151
+ Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
152
+
153
+ :param seed_value: An integer value to be used as the seed.
154
+ """
155
+ torch.manual_seed(seed_value)
156
+ if torch.cuda.is_available():
157
+ torch.cuda.manual_seed(seed_value)
158
+ torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
159
+ random.seed(seed_value)
160
+ np.random.seed(seed_value)
161
+ torch.backends.cudnn.deterministic = True
162
+ torch.backends.cudnn.benchmark = False
163
+
164
+
165
+ def main():
166
+ parser = ArgumentParser()
167
+ parser.add_argument('--model-path', type=str, default=None)
168
+ parser.add_argument('--model-base', type=str, default=None)
169
+ parser.add_argument("--model-type", type=str, default=None)
170
+ parser.add_argument("--conv-mode", type=str, default=None)
171
+ parser.add_argument('--data-path', type=str, default=None)
172
+ parser.add_argument('--config-path', type=str, default=None)
173
+ parser.add_argument('--output-path', type=str, default=None)
174
+ parser.add_argument('--split', type=str, default='validation')
175
+ parser.add_argument('--seed', type=int, default=42)
176
+ parser.add_argument("--small-gpu-usage", action="store_true")
177
+
178
+ args = parser.parse_args()
179
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
180
+ set_seed(args.seed)
181
+
182
+ print('bunny_initializing...')
183
+ processor = None
184
+ call_model_engine = call_bunny_engine_df
185
+
186
+ # load config and process to one value
187
+ args.config = load_yaml(args.config_path)
188
+ for key, value in args.config.items():
189
+ if key == 'task_instructions':
190
+ args.config[key] = value
191
+ elif key != 'eval_params' and type(value) == list:
192
+ assert len(value) == 1, 'key {} has more than one value'.format(key)
193
+ args.config[key] = value[0]
194
+
195
+ # run for each subject
196
+ sub_dataset_list = []
197
+ for subject in CAT_CN2EN.values():
198
+ sub_dataset = load_dataset(args.data_path, subject, split=args.split)
199
+ sub_dataset_list.append(sub_dataset)
200
+
201
+ # merge all dataset
202
+ dataset = concatenate_datasets(sub_dataset_list)
203
+
204
+ # load model
205
+ model_path = os.path.expanduser(args.model_path)
206
+ model_name = get_model_name_from_path(model_path)
207
+ tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
208
+ args.model_type)
209
+
210
+ samples = []
211
+ print('Processing CMMMU dataset...')
212
+ for sample in tqdm(dataset):
213
+
214
+ sample = construct_prompt(sample, args.config)
215
+ if sample['image_1']:
216
+ sample['image_1'] = process_images([sample['image_1'].convert('RGB')], vis_processors, model.config)[0]
217
+ if not args.small_gpu_usage:
218
+ sample['image_1'] = sample['image_1'].to(device)
219
+
220
+ samples.append(sample)
221
+
222
+ print('Start to evaluate...')
223
+ # run ex
224
+ out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
225
+
226
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
227
+
228
+ with open(args.output_path, 'w') as f:
229
+ for out_sample in out_samples:
230
+ f.write(json.dumps(out_sample) + '\n')
231
+
232
+
233
+ if __name__ == '__main__':
234
+ main()
bunny/eval/model_vqa_loader.py ADDED
@@ -0,0 +1,143 @@
+ import argparse
+ import torch
+ import os
+ import json
+ from tqdm import tqdm
+ import shortuuid
+
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from bunny.conversation import conv_templates
+ from bunny.model.builder import load_pretrained_model
+ from bunny.util.utils import disable_torch_init
+ from bunny.util.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+ from torch.utils.data import Dataset, DataLoader
+
+ from PIL import Image
+ import math
+
+
+ def split_list(lst, n):
+     """Split a list into n (roughly) equal-sized chunks"""
+     chunk_size = math.ceil(len(lst) / n)  # ceiling division
+     return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+ def get_chunk(lst, n, k):
+     chunks = split_list(lst, n)
+     return chunks[k]
+
+
+ # Custom dataset class
+ class CustomDataset(Dataset):
+     def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
+         self.questions = questions
+         self.image_folder = image_folder
+         self.tokenizer = tokenizer
+         self.image_processor = image_processor
+         self.model_config = model_config
+
+     def __getitem__(self, index):
+         line = self.questions[index]
+         image_file = line["image"]
+         qs = line["text"]
+
+         qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+         conv = conv_templates[args.conv_mode].copy()
+         conv.append_message(conv.roles[0], qs)
+         conv.append_message(conv.roles[1], None)
+         prompt = conv.get_prompt()
+
+         image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
+         image_tensor = process_images([image], self.image_processor, self.model_config)[0]
+
+         input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
+
+         return input_ids, image_tensor
+
+     def __len__(self):
+         return len(self.questions)
+
+
+ # DataLoader
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
+     assert batch_size == 1, "batch_size must be 1"
+     dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
+     data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
+     return data_loader
+
+
+ def eval_model(args):
+     # Model
+     disable_torch_init()
+     model_path = os.path.expanduser(args.model_path)
+     model_name = get_model_name_from_path(model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
+                                                                            args.model_type)
+
+     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+     answers_file = os.path.expanduser(args.answers_file)
+     os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+     ans_file = open(answers_file, "w")
+
+     if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
+         args.conv_mode = args.conv_mode + '_mmtag'
+         print(
+             f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
+
+     data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
+
+     for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
+         idx = line["question_id"]
+         cur_prompt = line["text"]
+
+         input_ids = input_ids.to(device='cuda', non_blocking=True)
+
+         with torch.inference_mode():
+             output_ids = model.generate(
+                 input_ids,
+                 images=image_tensor.to(dtype=model.dtype, device='cuda', non_blocking=True),
+                 do_sample=True if args.temperature > 0 else False,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+                 num_beams=args.num_beams,
+                 max_new_tokens=args.max_new_tokens,
+                 use_cache=True)
+
+         input_token_len = input_ids.shape[1]
+         n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+         if n_diff_input_output > 0:
+             print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+         outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+         outputs = outputs.strip()
+
+         ans_id = shortuuid.uuid()
+         ans_file.write(json.dumps({"question_id": idx,
+                                    "prompt": cur_prompt,
+                                    "text": outputs,
+                                    "answer_id": ans_id,
+                                    "model_id": model_name,
+                                    "metadata": {}}) + "\n")
+         # ans_file.flush()
+     ans_file.close()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default=None)
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--model-type", type=str, default=None)
+     parser.add_argument("--image-folder", type=str, default=None)
+     parser.add_argument("--question-file", type=str, default=None)
+     parser.add_argument("--answers-file", type=str, default=None)
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--num-chunks", type=int, default=1)
+     parser.add_argument("--chunk-idx", type=int, default=0)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--top_p", type=float, default=None)
+     parser.add_argument("--num_beams", type=int, default=1)
+     parser.add_argument("--max_new_tokens", type=int, default=128)
+     args = parser.parse_args()
+
+     eval_model(args)
bunny/eval/model_vqa_mmbench.py ADDED
@@ -0,0 +1,167 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ import shortuuid
8
+
9
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
10
+ from bunny.conversation import conv_templates, SeparatorStyle
11
+ from bunny.model.builder import load_pretrained_model
12
+ from bunny.util.utils import disable_torch_init
13
+ from bunny.util.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, \
14
+ get_model_name_from_path
15
+
16
+ import math
17
+
18
+ all_options = ['A', 'B', 'C', 'D']
19
+
20
+
21
+ def split_list(lst, n):
22
+ """Split a list into n (roughly) equal-sized chunks"""
23
+ chunk_size = math.ceil(len(lst) / n) # integer division
24
+ return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
25
+
26
+
27
+ def get_chunk(lst, n, k):
28
+ chunks = split_list(lst, n)
29
+ return chunks[k]
30
+
31
+
32
+ def is_none(value):
33
+ if value is None:
34
+ return True
35
+ if type(value) is float and math.isnan(value):
36
+ return True
37
+ if type(value) is str and value.lower() == 'nan':
38
+ return True
39
+ if type(value) is str and value.lower() == 'none':
40
+ return True
41
+ return False
42
+
43
+
44
+ def get_options(row, options):
45
+ parsed_options = []
46
+ for option in options:
47
+ option_value = row[option]
48
+ if is_none(option_value):
49
+ break
50
+ parsed_options.append(option_value)
51
+ return parsed_options
52
+
53
+
54
+ def eval_model(args):
55
+ # Model
56
+ disable_torch_init()
57
+ model_path = os.path.expanduser(args.model_path)
58
+ model_name = get_model_name_from_path(model_path)
59
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
60
+ args.model_type)
61
+
62
+ questions = pd.read_table(os.path.expanduser(args.question_file))
63
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
64
+ answers_file = os.path.expanduser(args.answers_file)
65
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
66
+ ans_file = open(answers_file, "w")
67
+
68
+ for index, row in tqdm(questions.iterrows(), total=len(questions)):
69
+ options = get_options(row, all_options)
70
+ cur_option_char = all_options[:len(options)]
71
+
72
+ if args.all_rounds:
73
+ num_rounds = len(options)
74
+ else:
75
+ num_rounds = 1
76
+
77
+ for round_idx in range(num_rounds):
78
+ idx = row['index']
79
+ question = row['question']
80
+ hint = row['hint']
81
+ image = load_image_from_base64(row['image'])
82
+ if not is_none(hint):
83
+ question = hint + '\n' + question
84
+ for option_char, option in zip(all_options[:len(options)], options):
85
+ question = question + '\n' + option_char + '. ' + option
86
+ qs = cur_prompt = question
87
+
88
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
89
+
90
+ if args.single_pred_prompt:
91
+ if args.lang == 'cn':
92
+ qs = qs + '\n' + "请直接回答选项字母。"
93
+ else:
94
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
95
+
96
+ conv = conv_templates[args.conv_mode].copy()
97
+ conv.append_message(conv.roles[0], qs)
98
+ conv.append_message(conv.roles[1], None)
99
+ prompt = conv.get_prompt()
100
+
101
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
102
+ 0).cuda()
103
+
104
+ image_tensor = process_images([image], image_processor, model.config)[0]
105
+
106
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
107
+
108
+ with torch.inference_mode():
109
+ output_ids = model.generate(
110
+ input_ids,
111
+ images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
112
+ do_sample=True if args.temperature > 0 else False,
113
+ temperature=args.temperature,
114
+ top_p=args.top_p,
115
+ num_beams=args.num_beams,
116
+ # no_repeat_ngram_size=3,
117
+ max_new_tokens=128,
118
+ use_cache=True)
119
+
120
+ input_token_len = input_ids.shape[1]
121
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
122
+ if n_diff_input_output > 0:
123
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
124
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
125
+ outputs = outputs.strip()
126
+ if outputs.endswith(stop_str):
127
+ outputs = outputs[:-len(stop_str)]
128
+ outputs = outputs.strip()
129
+
130
+ ans_id = shortuuid.uuid()
131
+ ans_file.write(json.dumps({"question_id": idx,
132
+ "round_id": round_idx,
133
+ "prompt": cur_prompt,
134
+ "text": outputs,
135
+ "options": options,
136
+ "option_char": cur_option_char,
137
+ "answer_id": ans_id,
138
+ "model_id": model_name,
139
+ "metadata": {}}) + "\n")
140
+ ans_file.flush()
141
+
142
+ # rotate options
143
+ options = options[1:] + options[:1]
144
+ cur_option_char = cur_option_char[1:] + cur_option_char[:1]
145
+ ans_file.close()
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ parser.add_argument("--model-path", type=str, default=None)
151
+ parser.add_argument("--model-base", type=str, default=None)
152
+ parser.add_argument("--model-type", type=str, default=None)
153
+ parser.add_argument("--image-folder", type=str, default=None)
154
+ parser.add_argument("--question-file", type=str, default=None)
155
+ parser.add_argument("--answers-file", type=str, default=None)
156
+ parser.add_argument("--conv-mode", type=str, default=None)
157
+ parser.add_argument("--num-chunks", type=int, default=1)
158
+ parser.add_argument("--chunk-idx", type=int, default=0)
159
+ parser.add_argument("--temperature", type=float, default=0.2)
160
+ parser.add_argument("--top_p", type=float, default=None)
161
+ parser.add_argument("--num_beams", type=int, default=1)
162
+ parser.add_argument("--all-rounds", action="store_true")
163
+ parser.add_argument("--single-pred-prompt", action="store_true")
164
+ parser.add_argument("--lang", type=str, default="en")
165
+ args = parser.parse_args()
166
+
167
+ eval_model(args)
bunny/eval/model_vqa_mmmu.py ADDED
@@ -0,0 +1,326 @@
1
+ import re
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ import yaml
7
+ import torch
8
+
9
+ from tqdm import tqdm
10
+ from datasets import load_dataset, concatenate_datasets
11
+ from argparse import ArgumentParser
12
+
13
+ from bunny.model.builder import load_pretrained_model
14
+ from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
15
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
16
+ from bunny.conversation import conv_templates
17
+
18
+ CAT_SHORT2LONG = {
19
+ 'acc': 'Accounting',
20
+ 'agri': 'Agriculture',
21
+ 'arch': 'Architecture_and_Engineering',
22
+ 'art': 'Art',
23
+ 'art_theory': 'Art_Theory',
24
+ 'bas_med': 'Basic_Medical_Science',
25
+ 'bio': 'Biology',
26
+ 'chem': 'Chemistry',
27
+ 'cli_med': 'Clinical_Medicine',
28
+ 'cs': 'Computer_Science',
29
+ 'design': 'Design',
30
+ 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
31
+ 'econ': 'Economics',
32
+ 'elec': 'Electronics',
33
+ 'ep': 'Energy_and_Power',
34
+ 'fin': 'Finance',
35
+ 'geo': 'Geography',
36
+ 'his': 'History',
37
+ 'liter': 'Literature',
38
+ 'manage': 'Manage',
39
+ 'mark': 'Marketing',
40
+ 'mate': 'Materials',
41
+ 'math': 'Math',
42
+ 'mech': 'Mechanical_Engineering',
43
+ 'music': 'Music',
44
+ 'phar': 'Pharmacy',
45
+ 'phys': 'Physics',
46
+ 'psy': 'Psychology',
47
+ 'pub_health': 'Public_Health',
48
+ 'socio': 'Sociology'
49
+ }
50
+
51
+
52
+ # ----------- Process Multi-choice -------------
53
+ def parse_multi_choice_response(response, all_choices, index2ans):
54
+ """
55
+ Parse the prediction from the generated response.
56
+ Return the predicted index e.g., A, B, C, D.
57
+ """
58
+ for char in [',', '.', '!', '?', ';', ':', "'"]:
59
+ response = response.strip(char)
60
+ response = " " + response + " " # add space to avoid partial match
61
+
62
+ index_ans = True
63
+ ans_with_brack = False
64
+ candidates = []
65
+ for choice in all_choices: # e.g., (A) (B) (C) (D)
66
+ if f'({choice})' in response:
67
+ candidates.append(choice)
68
+ ans_with_brack = True
69
+
70
+ if len(candidates) == 0:
71
+ for choice in all_choices: # e.g., A B C D
72
+ if f' {choice} ' in response:
73
+ candidates.append(choice)
74
+
75
+ # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
76
+ if len(candidates) == 0 and len(response.split()) > 5:
77
+ for index, ans in index2ans.items():
78
+ if ans.lower() in response.lower():
79
+ candidates.append(index)
80
+ index_ans = False # it's content ans.
81
+
82
+ if len(candidates) == 0: # still not get answer, randomly choose one.
83
+ pred_index = random.choice(all_choices)
84
+ elif len(candidates) > 1:
85
+ start_indexes = []
86
+ if index_ans:
87
+ if ans_with_brack:
88
+ for can in candidates:
89
+ index = response.rfind(f'({can})')
90
+ start_indexes.append(index) # -1 will be ignored anyway
91
+ # start_indexes = [generated_response.index(f'({can})') for can in candidates]
92
+ else:
93
+ for can in candidates:
94
+ index = response.rfind(f" {can} ")
95
+ start_indexes.append(index)
96
+ else:
97
+ for can in candidates:
98
+ index = response.lower().rfind(index2ans[can].lower())
99
+ start_indexes.append(index)
100
+ # get the last one
101
+ pred_index = candidates[np.argmax(start_indexes)]
102
+ else: # if only one candidate, use it.
103
+ pred_index = candidates[0]
104
+
105
+ return pred_index
106
+
107
+
108
+ def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
109
+ def deal_with_prompt(input_text):
110
+ qs = input_text
111
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
112
+ return qs
113
+
114
+ prompt = sample['final_input_prompt']
115
+ prompt = deal_with_prompt(prompt)
116
+
117
+ conv = conv_templates[args.conv_mode].copy()
118
+ conv.append_message(conv.roles[0], prompt)
119
+ conv.append_message(conv.roles[1], None)
120
+ prompt = conv.get_prompt()
121
+
122
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
123
+
124
+ image = sample['image']
125
+ if image is not None:
126
+ output_ids = model.generate(
127
+ input_ids,
128
+ images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
129
+ do_sample=False,
130
+ temperature=0,
131
+ top_p=None,
132
+ # num_beams=5,
133
+ max_new_tokens=128,
134
+ use_cache=True)
135
+
136
+ input_token_len = input_ids.shape[1]
137
+ # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
138
+ # if n_diff_input_output > 0:
139
+ # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
140
+ response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
141
+ else: # multiple images actually
142
+ if sample['question_type'] == 'multiple-choice':
143
+ all_choices = sample['all_choices']
144
+ response = random.choice(all_choices)
145
+ else:
146
+ response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
147
+
148
+ return response
149
+
150
+
151
+ def load_yaml(file_path):
152
+ with open(file_path, 'r') as stream:
153
+ try:
154
+ yaml_dict = yaml.safe_load(stream)
155
+ except yaml.YAMLError as exc:
156
+ print(exc)
157
+
158
+ return yaml_dict
159
+
160
+
161
+ def parse_img_path(text):
162
+ matches = re.findall("<img='(.*?)'>", text)
163
+ return matches
164
+
165
+
166
+ def process_single_sample(data):
167
+ question = data['question']
168
+ o_imgs_paths = []
169
+ for option in data['options']:
170
+ current_o_imgs_paths = parse_img_path(option)
171
+ for img_path in current_o_imgs_paths:
172
+ o_imgs_paths.append(img_path)
173
+
174
+ if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
175
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
176
+ 'image': None, 'question_type': data['question_type']}
177
+ else:
178
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
179
+ 'image': data['image_1'], 'question_type': data['question_type']}
180
+
181
+
182
+ # DATA PROCESSING
183
+ def construct_prompt(sample, config):
184
+ question = sample['question']
185
+ options = eval(sample['options'])
186
+ example = ""
187
+ if sample['question_type'] == 'multiple-choice':
188
+ start_chr = 'A'
189
+ prediction_range = []
190
+ index2ans = {}
191
+ for option in options:
192
+ prediction_range.append(start_chr)
193
+ example += f"({start_chr}) {option}\n"
194
+ index2ans[start_chr] = option
195
+ start_chr = chr(ord(start_chr) + 1)
196
+ empty_prompt_sample_structure = config['multi_choice_example_format']
197
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
198
+ res_dict = {}
199
+ res_dict['index2ans'] = index2ans
200
+ res_dict['correct_choice'] = sample['answer']
201
+ res_dict['all_choices'] = prediction_range
202
+ res_dict['empty_prompt'] = empty_prompt
203
+ if config['task_instructions']:
204
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
205
+ else:
206
+ res_dict['final_input_prompt'] = empty_prompt
207
+
208
+ res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
209
+ else:
210
+ empty_prompt_sample_structure = config['short_ans_example_format']
211
+ empty_prompt = empty_prompt_sample_structure.format(question)
212
+ res_dict = {}
213
+ res_dict['empty_prompt'] = empty_prompt
214
+ if config['task_instructions']:
215
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
216
+ else:
217
+ res_dict['final_input_prompt'] = empty_prompt
218
+ res_dict['gt_content'] = sample['answer']
219
+
220
+ res_dict.update(sample)
221
+ return res_dict
222
+
223
+
224
+ def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
225
+ out_samples = dict()
226
+ with torch.no_grad():
227
+ for sample in tqdm(samples):
228
+ if args.small_gpu_usage:
229
+ sample['image'] = sample['image'].cuda()
230
+ response = call_model_engine_fn(args, sample, model, tokenizer, processor)
231
+ if args.small_gpu_usage:
232
+ sample['image'] = sample['image'].cpu()
233
+
234
+ if sample['question_type'] == 'multiple-choice':
235
+ pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
236
+ else: # open question
237
+ pred_ans = response
238
+ out_samples[sample['id']] = pred_ans
239
+ return out_samples
240
+
241
+
242
+ def set_seed(seed_value):
243
+ """
244
+ Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
245
+
246
+ :param seed_value: An integer value to be used as the seed.
247
+ """
248
+ torch.manual_seed(seed_value)
249
+ if torch.cuda.is_available():
250
+ torch.cuda.manual_seed(seed_value)
251
+ torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
252
+ random.seed(seed_value)
253
+ np.random.seed(seed_value)
254
+ torch.backends.cudnn.deterministic = True
255
+ torch.backends.cudnn.benchmark = False
256
+
257
+
258
+ def main():
259
+ parser = ArgumentParser()
260
+ parser.add_argument('--model-path', type=str, default=None)
261
+ parser.add_argument('--model-base', type=str, default=None)
262
+ parser.add_argument("--model-type", type=str, default=None)
263
+ parser.add_argument("--conv-mode", type=str, default=None)
264
+ parser.add_argument('--data-path', type=str, default=None)
265
+ parser.add_argument('--config-path', type=str, default=None)
266
+ parser.add_argument('--output-path', type=str, default=None)
267
+ parser.add_argument('--split', type=str, default='validation')
268
+ parser.add_argument('--seed', type=int, default=42)
269
+ parser.add_argument("--small-gpu-usage", action="store_true")
270
+
271
+ args = parser.parse_args()
272
+ device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
273
+ set_seed(args.seed)
274
+
275
+ print('bunny_initializing...')
276
+ processor = None
277
+ call_model_engine = call_bunny_engine_df
278
+
279
+ # load config and collapse single-element list values to their value
280
+ args.config = load_yaml(args.config_path)
281
+ for key, value in args.config.items():
282
+ if key != 'eval_params' and type(value) == list:
283
+ assert len(value) == 1, 'key {} has more than one value'.format(key)
284
+ args.config[key] = value[0]
285
+
286
+ # run for each subject
287
+ sub_dataset_list = []
288
+ for subject in CAT_SHORT2LONG.values():
289
+ sub_dataset = load_dataset(args.data_path, subject, split=args.split)
290
+ sub_dataset_list.append(sub_dataset)
291
+
292
+ # merge all datasets
293
+ dataset = concatenate_datasets(sub_dataset_list)
294
+
295
+ # load model
296
+ model_path = os.path.expanduser(args.model_path)
297
+ model_name = get_model_name_from_path(model_path)
298
+ tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
299
+ args.model_type)
300
+
301
+ samples = []
302
+ print('Processing MMMU dataset...')
303
+ for sample in tqdm(dataset):
304
+ sample = process_single_sample(sample)
305
+
306
+ sample = construct_prompt(sample, args.config)
307
+ if sample['image']:
308
+ sample['image'] = process_images([sample['image'].convert('RGB')], vis_processors, model.config)[0]
309
+
310
+ if not args.small_gpu_usage:
311
+ sample['image'] = sample['image'].to(device)
312
+
313
+ samples.append(sample)
314
+
315
+ print('Start to evaluate...')
316
+ # run evaluation
317
+ out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
318
+
319
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
320
+
321
+ with open(args.output_path, 'w') as f:
322
+ json.dump(out_samples, f, indent=4)
323
+
324
+
325
+ if __name__ == '__main__':
326
+ main()
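
A quick illustration of the fallback logic in parse_multi_choice_response above — a minimal sketch, not part of this commit, with made-up option texts: when no option letter is detected and the response is longer than five words, the option texts themselves are matched against the response, and only if that also fails is a letter drawn at random.

# Illustrative only; option texts are invented for this sketch.
from bunny.eval.model_vqa_mmmu import parse_multi_choice_response

index2ans = {'A': 'a convex lens', 'B': 'a concave lens', 'C': 'a flat mirror'}
all_choices = ['A', 'B', 'C']

# content match: 'a flat mirror' occurs in the response, so 'C' is returned
print(parse_multi_choice_response('It must be a flat mirror in this setup.', all_choices, index2ans))

# no letter and no option text in a long response: a random letter from all_choices is returned
print(parse_multi_choice_response('I am really not sure about this one at all.', all_choices, index2ans))
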
bunny/eval/model_vqa_science.py ADDED
@@ -0,0 +1,119 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
9
+ from bunny.conversation import conv_templates, SeparatorStyle
10
+ from bunny.model.builder import load_pretrained_model
11
+ from bunny.util.utils import disable_torch_init
12
+ from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path
13
+
14
+ from PIL import Image
15
+ import math
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
21
+ return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def eval_model(args):
30
+ # Model
31
+ disable_torch_init()
32
+ model_path = os.path.expanduser(args.model_path)
33
+ model_name = get_model_name_from_path(model_path)
34
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
35
+ args.model_type)
36
+
37
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
38
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
+ answers_file = os.path.expanduser(args.answers_file)
40
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
+ ans_file = open(answers_file, "w")
42
+ for i, line in enumerate(tqdm(questions)):
43
+ idx = line["id"]
44
+ question = line['conversations'][0]
45
+ qs = question['value'].replace('<image>', '').strip()
46
+ cur_prompt = qs
47
+
48
+ if 'image' in line:
49
+ image_file = line["image"]
50
+ image = Image.open(os.path.join(args.image_folder, image_file))
51
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
52
+ images = image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True)
53
+
54
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
55
+ cur_prompt = '<image>' + '\n' + cur_prompt
56
+ else:
57
+ images = None
58
+
59
+ if args.single_pred_prompt:
60
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
61
+ cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
62
+
63
+ conv = conv_templates[args.conv_mode].copy()
64
+ conv.append_message(conv.roles[0], qs)
65
+ conv.append_message(conv.roles[1], None)
66
+ prompt = conv.get_prompt()
67
+
68
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
69
+
70
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
71
+
72
+ with torch.inference_mode():
73
+ output_ids = model.generate(
74
+ input_ids,
75
+ images=images,
76
+ do_sample=True if args.temperature > 0 else False,
77
+ temperature=args.temperature,
78
+ max_new_tokens=1024,
79
+ use_cache=True
80
+ )
81
+
82
+ input_token_len = input_ids.shape[1]
83
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
84
+ if n_diff_input_output > 0:
85
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
86
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
87
+ outputs = outputs.strip()
88
+ if outputs.endswith(stop_str):
89
+ outputs = outputs[:-len(stop_str)]
90
+ outputs = outputs.strip()
91
+
92
+ ans_id = shortuuid.uuid()
93
+ ans_file.write(json.dumps({"question_id": idx,
94
+ "prompt": cur_prompt,
95
+ "text": outputs,
96
+ "answer_id": ans_id,
97
+ "model_id": model_name,
98
+ "metadata": {}}) + "\n")
99
+ ans_file.flush()
100
+ ans_file.close()
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument("--model-path", type=str, default=None)
106
+ parser.add_argument("--model-base", type=str, default=None)
107
+ parser.add_argument("--model-type", type=str, default=None)
108
+ parser.add_argument("--image-folder", type=str, default=None)
109
+ parser.add_argument("--question-file", type=str, default=None)
110
+ parser.add_argument("--answers-file", type=str, default=None)
111
+ parser.add_argument("--conv-mode", type=str, default=None)
112
+ parser.add_argument("--num-chunks", type=int, default=1)
113
+ parser.add_argument("--chunk-idx", type=int, default=0)
114
+ parser.add_argument("--temperature", type=float, default=0.2)
115
+ parser.add_argument("--single-pred-prompt", action="store_true")
116
+
117
+ args = parser.parse_args()
118
+
119
+ eval_model(args)
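
A hypothetical way to launch the ScienceQA evaluation script above, in the same subprocess style used elsewhere in this repo; every path, checkpoint id and conversation mode below is a placeholder — only the flag names come from the argparse definition above.

import subprocess, sys

subprocess.check_call([
    sys.executable, '-m', 'bunny.eval.model_vqa_science',
    '--model-path', 'checkpoints/bunny-phi-2-siglip',    # placeholder checkpoint
    '--model-type', 'phi-2',
    '--conv-mode', 'bunny',                              # placeholder conversation template
    '--question-file', 'eval/scienceqa/questions.json',  # placeholder
    '--image-folder', 'eval/scienceqa/images',           # placeholder
    '--answers-file', 'eval/scienceqa/answers.jsonl',    # placeholder
    '--single-pred-prompt',
])
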
bunny/model/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .language_model.bunny_phi import BunnyPhiForCausalLM, BunnyPhiConfig
2
+ from .language_model.bunny_stablelm import BunnyStableLMForCausalLM, BunnyStableLMConfig
3
+ from .language_model.bunny_qwen import BunnyQwen2ForCausalLM, BunnyQwen2Config
4
+ from .language_model.bunny_minicpm import BunnyMiniCPMForCausalLM, BunnyMiniCPMConfig
5
+ from .language_model.bunny_llama import BunnyLlamaForCausalLM, BunnyLlamaConfig
6
+ from .language_model.bunny_phi3 import BunnyPhi3ForCausalLM, BunnyPhi3Config
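
These re-exports make the wrapper classes importable from the package root (this is what builder.py's `from bunny.model import *` relies on). A small hedged sketch — the checkpoint path is a placeholder:

from bunny.model import BunnyPhiForCausalLM, BunnyPhiConfig

config = BunnyPhiConfig.from_pretrained('path/to/bunny-phi-checkpoint')       # placeholder
model = BunnyPhiForCausalLM.from_pretrained('path/to/bunny-phi-checkpoint',   # placeholder
                                            config=config)
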
bunny/model/builder.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ import warnings
3
+ import torch
4
+
5
+ from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig, logging
6
+
7
+ logging.set_verbosity_error()
8
+ warnings.filterwarnings('ignore')
9
+
10
+ from bunny.model import *
11
+
12
+
13
+ def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
14
+ device_map="auto", device="cuda", **kwargs):
15
+ if model_type not in {'phi-1.5', 'phi-2', 'phi-3', 'stablelm-2', 'qwen1.5-1.8b', 'minicpm', 'llama3-8b'}:
16
+ raise ValueError(f"Unknown Model Type {model_type}")
17
+
18
+ kwargs = {"device_map": device_map, **kwargs}
19
+
20
+ if device != "cuda":
21
+ kwargs['device_map'] = {"": device}
22
+
23
+ if load_8bit:
24
+ kwargs['load_in_8bit'] = True
25
+ elif load_4bit:
26
+ kwargs['load_in_4bit'] = True
27
+ kwargs['quantization_config'] = BitsAndBytesConfig(
28
+ load_in_4bit=True,
29
+ bnb_4bit_compute_dtype=torch.float16,
30
+ bnb_4bit_use_double_quant=True,
31
+ bnb_4bit_quant_type='nf4'
32
+ )
33
+ else:
34
+ kwargs['torch_dtype'] = torch.float16
35
+
36
+ # Load Bunny model
37
+ if 'lora' in model_name.lower() and model_base is None:
38
+ warnings.warn(
39
+ 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
40
+ if 'lora' in model_name.lower() and model_base is not None:
41
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
42
+
43
+ print('Loading Bunny from base model...')
44
+ if model_type == 'phi-1.5' or model_type == 'phi-2':
45
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
46
+ model = BunnyPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
47
+ config=lora_cfg_pretrained, **kwargs)
48
+ elif model_type == 'phi-3':
49
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
50
+ model = BunnyPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
51
+ config=lora_cfg_pretrained, **kwargs)
52
+ elif model_type == 'stablelm-2':
53
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
54
+ model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
55
+ config=lora_cfg_pretrained, **kwargs)
56
+ elif model_type == 'qwen1.5-1.8b':
57
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
58
+ model = BunnyQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained,
59
+ **kwargs)
60
+ elif model_type == 'minicpm':
61
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
62
+ model = BunnyMiniCPMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
63
+ config=lora_cfg_pretrained,
64
+ **kwargs)
65
+ elif model_type == 'llama3-8b':
66
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
67
+ model = BunnyLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
68
+ config=lora_cfg_pretrained,
69
+ **kwargs)
70
+
71
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
72
+ if model.lm_head.weight.shape[0] != token_num:
73
+ model.lm_head.weight = torch.nn.Parameter(
74
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
75
+ model.model.embed_tokens.weight = torch.nn.Parameter(
76
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
77
+
78
+ print('Loading additional Bunny weights...')
79
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
80
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
81
+ else:
82
+ # this is probably from HF Hub
83
+ from huggingface_hub import hf_hub_download
84
+ def load_from_hf(repo_id, filename, subfolder=None):
85
+ cache_file = hf_hub_download(
86
+ repo_id=repo_id,
87
+ filename=filename,
88
+ subfolder=subfolder)
89
+ return torch.load(cache_file, map_location='cpu')
90
+
91
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
92
+
93
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
94
+ non_lora_trainables.items()}
95
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
96
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
97
+ non_lora_trainables.items()}
98
+ model.load_state_dict(non_lora_trainables, strict=False)
99
+
100
+ from peft import PeftModel
101
+ print('Loading LoRA weights...')
102
+ model = PeftModel.from_pretrained(model, model_path)
103
+ print('Merging LoRA weights...')
104
+ model = model.merge_and_unload()
105
+ print('Model is loaded...')
106
+ elif model_base is not None:
107
+ # this may be mm projector only
108
+ print('Loading Bunny from base model...')
109
+
110
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
111
+ if model_type == 'phi-1.5' or model_type == 'phi-2':
112
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
113
+ model = BunnyPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
114
+ config=cfg_pretrained, **kwargs)
115
+ elif model_type == 'phi-3':
116
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
117
+ model = BunnyPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
118
+ config=cfg_pretrained, **kwargs)
119
+ elif model_type == 'stablelm-2':
120
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
121
+ model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
122
+ config=cfg_pretrained, **kwargs)
123
+ elif model_type == 'qwen1.5-1.8b':
124
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
125
+ model = BunnyQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
126
+ **kwargs)
127
+ elif model_type == 'minicpm':
128
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
129
+ model = BunnyMiniCPMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
130
+ **kwargs)
131
+ elif model_type == 'llama3-8b':
132
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
133
+ model = BunnyLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
134
+ **kwargs)
135
+
136
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
137
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
138
+ model.load_state_dict(mm_projector_weights, strict=False)
139
+ else:
140
+ if model_type == 'phi-1.5' or model_type == 'phi-2':
141
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
142
+ model = BunnyPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
143
+ elif model_type == 'phi-3':
144
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
145
+ model = BunnyPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
146
+ elif model_type == 'stablelm-2':
147
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
148
+ model = BunnyStableLMForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
149
+ elif model_type == 'qwen1.5-1.8b':
150
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
151
+ model = BunnyQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
152
+ elif model_type == 'minicpm':
153
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
154
+ model = BunnyMiniCPMForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
155
+ elif model_type == 'llama3-8b':
156
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
157
+ model = BunnyLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
158
+
159
+ model.resize_token_embeddings(len(tokenizer))
160
+
161
+ vision_tower = model.get_vision_tower()
162
+ if not vision_tower.is_loaded:
163
+ vision_tower.load_model()
164
+
165
+ # if getattr(model.config, "unfreeze_vision_tower", False):
166
+ # if 'lora' in model_name.lower():
167
+ # assert model_base is not None
168
+ # vision_non_lora_trainables = {k[19:]: v for k, v in non_lora_trainables.items() if
169
+ # k.startswith('model.vision_tower.')}
170
+ # vision_tower.load_state_dict(vision_non_lora_trainables, strict=False)
171
+ # else:
172
+ # assert model_base is None
173
+ # from safetensors.torch import load_file
174
+ # vision_weights = {}
175
+ # for file_name in os.listdir(model_path):
176
+ # if file_name.endswith('safetensors'):
177
+ # vision_weights.update(
178
+ # {k[19:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
179
+ # k.startswith('model.vision_tower.')})
180
+ # vision_tower.load_state_dict(vision_weights, strict=True)
181
+
182
+ vision_tower.to(device=device, dtype=torch.float16)
183
+ image_processor = vision_tower.image_processor
184
+
185
+ if hasattr(model.config, "max_sequence_length"):
186
+ context_len = model.config.max_sequence_length
187
+ else:
188
+ context_len = 2048
189
+
190
+ if model_type == 'llama3-8b':
191
+ tokenizer.eos_token_id = 128001
192
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
193
+
194
+ if model.generation_config.pad_token_id is None:
195
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
196
+
197
+ return tokenizer, model, image_processor, context_len
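
A hedged example of calling load_pretrained_model for a full (already merged, non-LoRA) checkpoint; the checkpoint path is a placeholder, and model_type must be one of the values accepted at the top of the function.

from bunny.model.builder import load_pretrained_model
from bunny.util.mm_utils import get_model_name_from_path

model_path = 'path/to/bunny-phi-2-checkpoint'            # placeholder
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,                                     # merged weights, so no base model needed
    model_name=get_model_name_from_path(model_path),
    model_type='phi-2',
    load_4bit=True,                                      # takes the NF4 quantization branch above
)
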
bunny/model/bunny_arch.py ADDED
@@ -0,0 +1,230 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+
5
+ from .multimodal_encoder.builder import build_vision_tower
6
+ from .multimodal_projector.builder import build_vision_projector
7
+
8
+ from bunny.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
9
+
10
+
11
+ class BunnyMetaModel:
12
+
13
+ def __init__(self, config):
14
+ super(BunnyMetaModel, self).__init__(config)
15
+
16
+ if hasattr(config, "mm_vision_tower"):
17
+ self.vision_tower = build_vision_tower(config, delay_load=False)
18
+ # self.vision_tower = build_vision_tower(config, delay_load=not getattr(config, 'continuous_training', False))
19
+ if getattr(config, 'continuous_training', False):
20
+ config.continuous_training = False
21
+ self.mm_projector = build_vision_projector(config)
22
+
23
+ def get_vision_tower(self):
24
+ vision_tower = getattr(self, 'vision_tower', None)
25
+ if type(vision_tower) is list:
26
+ vision_tower = vision_tower[0]
27
+ return vision_tower
28
+
29
+ def initialize_vision_modules(self, model_args):
30
+ vision_tower = model_args.vision_tower
31
+
32
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
33
+
34
+ self.config.mm_vision_tower = vision_tower
35
+
36
+ if self.get_vision_tower() is None:
37
+ vision_tower = build_vision_tower(model_args)
38
+ self.vision_tower = vision_tower
39
+ else:
40
+ vision_tower = self.vision_tower
41
+ vision_tower.load_model()
42
+
43
+ self.config.use_mm_proj = True
44
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type')
45
+ self.config.mm_hidden_size = vision_tower.hidden_size
46
+
47
+ if getattr(self, 'mm_projector', None) is None:
48
+ self.mm_projector = build_vision_projector(self.config)
49
+ else:
50
+ # In case it is frozen by LoRA
51
+ for p in self.mm_projector.parameters():
52
+ p.requires_grad = True
53
+
54
+ if pretrain_mm_mlp_adapter is not None:
55
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
56
+
57
+ def get_w(weights, keyword):
58
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
59
+
60
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
61
+
62
+
63
+ class BunnyMetaForCausalLM(ABC):
64
+
65
+ @abstractmethod
66
+ def get_model(self):
67
+ pass
68
+
69
+ def get_vision_tower(self):
70
+ return self.get_model().get_vision_tower()
71
+
72
+ def encode_images(self, images):
73
+ image_features = self.get_model().get_vision_tower()(images)
74
+ image_features = self.get_model().mm_projector(image_features)
75
+ return image_features
76
+
77
+ def prepare_inputs_labels_for_multimodal(
78
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images
79
+ ):
80
+ vision_tower = self.get_vision_tower()
81
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
82
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
83
+ 1] == 1:
84
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
85
+ attention_mask = torch.cat((attention_mask, torch.ones(
86
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
87
+ dtype=attention_mask.dtype,
88
+ device=attention_mask.device
89
+ )), dim=1)
90
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
91
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
92
+
93
+ if type(images) is list or images.ndim == 5:
94
+ concat_images = torch.cat([image for image in images], dim=0)
95
+ image_features = self.encode_images(concat_images)
96
+ split_sizes = [image.shape[0] for image in images]
97
+ image_features = torch.split(image_features, split_sizes, dim=0)
98
+ image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
99
+ else:
100
+ image_features = self.encode_images(images).to(self.device)
101
+
102
+ # Let's just add dummy tensors if they do not exist,
103
+ # it is a headache to deal with None all the time.
104
+ # But it is not ideal, and if you have a better idea,
105
+ # please open an issue / submit a PR, thanks.
106
+ _labels = labels
107
+ _position_ids = position_ids
108
+ _attention_mask = attention_mask
109
+ if attention_mask is None:
110
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
111
+ else:
112
+ attention_mask = attention_mask.bool()
113
+ if position_ids is None:
114
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
115
+ if labels is None:
116
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
117
+
118
+ input_ids_temp = input_ids # points to the actual input_ids tensor
119
+
120
+ # remove the padding using attention_mask -- TODO: double check
121
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
122
+ zip(input_ids, attention_mask)]
123
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
124
+
125
+ # -- TODO: better implementation?
126
+ # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
127
+ input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
128
+
129
+ new_input_embeds = []
130
+ new_labels = []
131
+ cur_image_idx = 0
132
+ for batch_idx, cur_input_ids in enumerate(input_ids):
133
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
134
+ if num_images == 0:
135
+ cur_image_features = image_features[cur_image_idx]
136
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
137
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
138
+ new_input_embeds.append(cur_input_embeds)
139
+ new_labels.append(labels[batch_idx])
140
+ cur_image_idx += 1
141
+ continue
142
+
143
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
144
+ cur_input_ids.shape[0]]
145
+ cur_input_ids_noim = []
146
+ cur_labels = labels[batch_idx]
147
+ cur_labels_noim = []
148
+ for i in range(len(image_token_indices) - 1):
149
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
150
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
151
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
152
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
153
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
154
+ cur_new_input_embeds = []
155
+ cur_new_labels = []
156
+
157
+ for i in range(num_images + 1):
158
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
159
+ cur_new_labels.append(cur_labels_noim[i])
160
+ if i < num_images:
161
+ cur_image_features = image_features[cur_image_idx]
162
+ cur_image_idx += 1
163
+ cur_new_input_embeds.append(cur_image_features)
164
+ cur_new_labels.append(
165
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
166
+ dtype=cur_labels.dtype))
167
+
168
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
169
+ cur_new_labels = torch.cat(cur_new_labels)
170
+
171
+ new_input_embeds.append(cur_new_input_embeds)
172
+ new_labels.append(cur_new_labels)
173
+
174
+ # Truncate sequences to max length as image embeddings can make the sequence longer
175
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
176
+ if tokenizer_model_max_length is not None:
177
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
178
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
179
+
180
+ # Combine them
181
+ max_len = max(x.shape[0] for x in new_input_embeds)
182
+ batch_size = len(new_input_embeds)
183
+
184
+ new_input_embeds_padded = []
185
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
186
+ device=new_labels[0].device)
187
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
188
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
189
+
190
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
191
+ cur_len = cur_new_embed.shape[0]
192
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
193
+ new_input_embeds_padded.append(torch.cat((
194
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
195
+ device=cur_new_embed.device),
196
+ cur_new_embed
197
+ ), dim=0))
198
+ if cur_len > 0:
199
+ new_labels_padded[i, -cur_len:] = cur_new_labels
200
+ attention_mask[i, -cur_len:] = True
201
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
202
+ device=position_ids.device)
203
+ else:
204
+ new_input_embeds_padded.append(torch.cat((
205
+ cur_new_embed,
206
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
207
+ device=cur_new_embed.device)
208
+ ), dim=0))
209
+ if cur_len > 0:
210
+ new_labels_padded[i, :cur_len] = cur_new_labels
211
+ attention_mask[i, :cur_len] = True
212
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
213
+ device=position_ids.device)
214
+
215
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
216
+
217
+ if _labels is None:
218
+ new_labels = None
219
+ else:
220
+ new_labels = new_labels_padded
221
+
222
+ if _attention_mask is None:
223
+ attention_mask = None
224
+ else:
225
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
226
+
227
+ if _position_ids is None:
228
+ position_ids = None
229
+
230
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
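
A toy, self-contained illustration (not repo code) of the splicing idea implemented in prepare_inputs_labels_for_multimodal above: the image placeholder token is cut out of the prompt, the projected image features are inserted at its position, and the labels for those positions are filled with IGNORE_INDEX. All tensor sizes below are made up.

import torch

IMAGE_TOKEN_INDEX, IGNORE_INDEX = -200, -100                  # same constants as bunny.constants
input_ids = torch.tensor([1, 5, IMAGE_TOKEN_INDEX, 7, 2])     # toy prompt with one image slot
image_features = torch.randn(3, 4)                            # 3 projected image tokens, dim 4
embed_tokens = torch.nn.Embedding(10, 4)                      # toy embedding table

pos = int(torch.where(input_ids == IMAGE_TOKEN_INDEX)[0])
text_embeds = embed_tokens(torch.cat([input_ids[:pos], input_ids[pos + 1:]]))
inputs_embeds = torch.cat([text_embeds[:pos], image_features, text_embeds[pos:]], dim=0)
labels = torch.cat([input_ids[:pos],
                    torch.full((image_features.shape[0],), IGNORE_INDEX),
                    input_ids[pos + 1:]])
print(inputs_embeds.shape, labels.shape)                      # torch.Size([7, 4]) torch.Size([7])
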
bunny/model/language_model/bunny_llama.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .llama import LlamaModel, LlamaConfig, LlamaForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
+ class BunnyLlamaConfig(LlamaConfig):
15
+ model_type = "bunny-llama"
16
+
17
+
18
+ class BunnyLlamaModel(BunnyMetaModel, LlamaModel):
19
+ config_class = BunnyLlamaConfig
20
+
21
+ def __init__(self, config: LlamaConfig):
22
+ super(BunnyLlamaModel, self).__init__(config)
23
+
24
+
25
+ class BunnyLlamaForCausalLM(LlamaForCausalLM, BunnyMetaForCausalLM):
26
+ config_class = BunnyLlamaConfig
27
+
28
+ def __init__(self, config):
29
+ super(LlamaForCausalLM, self).__init__(config)
30
+ self.model = BunnyLlamaModel(config)
31
+ self.vocab_size = config.vocab_size
32
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_model(self):
38
+ return self.model
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: torch.LongTensor = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ position_ids: Optional[torch.LongTensor] = None,
45
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
46
+ inputs_embeds: Optional[torch.FloatTensor] = None,
47
+ labels: Optional[torch.LongTensor] = None,
48
+ use_cache: Optional[bool] = None,
49
+ output_attentions: Optional[bool] = None,
50
+ output_hidden_states: Optional[bool] = None,
51
+ images: Optional[torch.FloatTensor] = None,
52
+ return_dict: Optional[bool] = None,
53
+ cache_position: Optional[torch.LongTensor] = None,
54
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
55
+ if inputs_embeds is None:
56
+ (
57
+ input_ids,
58
+ position_ids,
59
+ attention_mask,
60
+ past_key_values,
61
+ inputs_embeds,
62
+ labels
63
+ ) = self.prepare_inputs_labels_for_multimodal(
64
+ input_ids,
65
+ position_ids,
66
+ attention_mask,
67
+ past_key_values,
68
+ labels,
69
+ images
70
+ )
71
+
72
+ return super().forward(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ position_ids=position_ids,
76
+ past_key_values=past_key_values,
77
+ inputs_embeds=inputs_embeds,
78
+ labels=labels,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict,
83
+ cache_position=None
84
+ )
85
+
86
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
87
+ **kwargs):
88
+ images = kwargs.pop("images", None)
89
+
90
+ _inputs = super().prepare_inputs_for_generation(
91
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
92
+ **kwargs
93
+ )
94
+
95
+ if images is not None:
96
+ _inputs['images'] = images
97
+
98
+ return _inputs
99
+
100
+
101
+ AutoConfig.register("bunny-llama", BunnyLlamaConfig)
102
+ AutoModelForCausalLM.register(BunnyLlamaConfig, BunnyLlamaForCausalLM)
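
The two register calls at the bottom hook the new "bunny-llama" model type into the Auto classes, so a saved checkpoint whose config declares that model type can be loaded generically, and prepare_inputs_for_generation re-attaches `images` at every decoding step of generate(). A hedged sketch — the checkpoint path and tensors are placeholders:

from transformers import AutoModelForCausalLM
import bunny.model   # importing the package runs the register(...) calls above

model = AutoModelForCausalLM.from_pretrained('path/to/bunny-llama-checkpoint')   # placeholder
# output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=64)
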
bunny/model/language_model/bunny_minicpm.py ADDED
@@ -0,0 +1,103 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from bunny.model.language_model.minicpm.modeling_minicpm import MiniCPMModel, MiniCPMForCausalLM
8
+ from bunny.model.language_model.minicpm.configuration_minicpm import MiniCPMConfig
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
+
14
+
15
+ class BunnyMiniCPMConfig(MiniCPMConfig):
16
+ model_type = "bunny-minicpm"
17
+
18
+
19
+ class BunnyMiniCPMModel(BunnyMetaModel, MiniCPMModel):
20
+ config_class = BunnyMiniCPMConfig
21
+
22
+ def __init__(self, config: MiniCPMConfig):
23
+ super(BunnyMiniCPMModel, self).__init__(config)
24
+
25
+
26
+ class BunnyMiniCPMForCausalLM(MiniCPMForCausalLM, BunnyMetaForCausalLM):
27
+ config_class = BunnyMiniCPMConfig
28
+
29
+ def __init__(self, config):
30
+ super(MiniCPMForCausalLM, self).__init__(config)
31
+ self.model = BunnyMiniCPMModel(config)
32
+ self.vocab_size = config.vocab_size
33
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
+
35
+ # Initialize weights and apply final processing
36
+ self.post_init()
37
+
38
+ def get_model(self):
39
+ return self.model
40
+
41
+ def forward(
42
+ self,
43
+ input_ids: torch.LongTensor = None,
44
+ attention_mask: Optional[torch.Tensor] = None,
45
+ position_ids: Optional[torch.LongTensor] = None,
46
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
47
+ inputs_embeds: Optional[torch.FloatTensor] = None,
48
+ labels: Optional[torch.LongTensor] = None,
49
+ use_cache: Optional[bool] = None,
50
+ output_attentions: Optional[bool] = None,
51
+ output_hidden_states: Optional[bool] = None,
52
+ images: Optional[torch.FloatTensor] = None,
53
+ return_dict: Optional[bool] = None,
54
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
55
+
56
+ if inputs_embeds is None:
57
+ (
58
+ input_ids,
59
+ position_ids,
60
+ attention_mask,
61
+ past_key_values,
62
+ inputs_embeds,
63
+ labels
64
+ ) = self.prepare_inputs_labels_for_multimodal(
65
+ input_ids,
66
+ position_ids,
67
+ attention_mask,
68
+ past_key_values,
69
+ labels,
70
+ images
71
+ )
72
+ if inputs_embeds is not None:
73
+ inputs_embeds *= self.get_model().config.scale_emb
74
+
75
+ return super().forward(
76
+ input_ids=input_ids,
77
+ attention_mask=attention_mask,
78
+ position_ids=position_ids,
79
+ past_key_values=past_key_values,
80
+ inputs_embeds=inputs_embeds,
81
+ labels=labels,
82
+ use_cache=use_cache,
83
+ output_attentions=output_attentions,
84
+ output_hidden_states=output_hidden_states,
85
+ return_dict=return_dict
86
+ )
87
+
88
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
89
+ **kwargs):
90
+ images = kwargs.pop("images", None)
91
+
92
+ _inputs = super().prepare_inputs_for_generation(
93
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
94
+ **kwargs
95
+ )
96
+
97
+ if images is not None:
98
+ _inputs['images'] = images
99
+ return _inputs
100
+
101
+
102
+ AutoConfig.register("bunny-minicpm", BunnyMiniCPMConfig)
103
+ AutoModelForCausalLM.register(BunnyMiniCPMConfig, BunnyMiniCPMForCausalLM)
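
The only substantive difference from the other wrappers is the scale_emb multiplication in forward(): MiniCPM scales its token embeddings internally, so the externally spliced multimodal embeddings have to be scaled the same way before being passed in. A toy illustration — the scale value below is made up; the real one comes from the MiniCPM config:

import torch

scale_emb = 12.0                             # illustrative; read from config.scale_emb in practice
inputs_embeds = torch.randn(1, 7, 2304)      # toy spliced text+image embeddings
inputs_embeds = inputs_embeds * scale_emb    # mirrors the scaling applied in forward() above
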
bunny/model/language_model/bunny_phi.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .phi import PhiModel, PhiConfig, PhiForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
+ class BunnyPhiConfig(PhiConfig):
15
+ model_type = "bunny-phi"
16
+
17
+
18
+ class BunnyPhiModel(BunnyMetaModel, PhiModel):
19
+ config_class = BunnyPhiConfig
20
+
21
+ def __init__(self, config: PhiConfig):
22
+ super(BunnyPhiModel, self).__init__(config)
23
+
24
+
25
+ class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
26
+ config_class = BunnyPhiConfig
27
+
28
+ def __init__(self, config):
29
+ super(PhiForCausalLM, self).__init__(config)
30
+ self.model = BunnyPhiModel(config)
31
+ self.vocab_size = config.vocab_size
32
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_model(self):
38
+ return self.model
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: torch.LongTensor = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ position_ids: Optional[torch.LongTensor] = None,
45
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
46
+ inputs_embeds: Optional[torch.FloatTensor] = None,
47
+ labels: Optional[torch.LongTensor] = None,
48
+ use_cache: Optional[bool] = None,
49
+ output_attentions: Optional[bool] = None,
50
+ output_hidden_states: Optional[bool] = None,
51
+ images: Optional[torch.FloatTensor] = None,
52
+ return_dict: Optional[bool] = None,
53
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
54
+
55
+ if inputs_embeds is None:
56
+ (
57
+ input_ids,
58
+ position_ids,
59
+ attention_mask,
60
+ past_key_values,
61
+ inputs_embeds,
62
+ labels
63
+ ) = self.prepare_inputs_labels_for_multimodal(
64
+ input_ids,
65
+ position_ids,
66
+ attention_mask,
67
+ past_key_values,
68
+ labels,
69
+ images
70
+ )
71
+
72
+ return super().forward(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ position_ids=position_ids,
76
+ past_key_values=past_key_values,
77
+ inputs_embeds=inputs_embeds,
78
+ labels=labels,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict
83
+ )
84
+
85
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
+ **kwargs):
87
+ images = kwargs.pop("images", None)
88
+
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
+ **kwargs
92
+ )
93
+
94
+ if images is not None:
95
+ _inputs['images'] = images
96
+ return _inputs
97
+
98
+
99
+ AutoConfig.register("bunny-phi", BunnyPhiConfig)
100
+ AutoModelForCausalLM.register(BunnyPhiConfig, BunnyPhiForCausalLM)
bunny/model/language_model/bunny_phi3.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .phi3 import Phi3Model, Phi3Config, Phi3ForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
+ class BunnyPhi3Config(Phi3Config):
15
+ model_type = "bunny-phi3"
16
+
17
+
18
+ class BunnyPhi3Model(BunnyMetaModel, Phi3Model):
19
+ config_class = BunnyPhi3Config
20
+
21
+ def __init__(self, config: Phi3Config):
22
+ super(BunnyPhi3Model, self).__init__(config)
23
+
24
+
25
+ class BunnyPhi3ForCausalLM(Phi3ForCausalLM, BunnyMetaForCausalLM):
26
+ config_class = BunnyPhi3Config
27
+
28
+ def __init__(self, config):
29
+ super(Phi3ForCausalLM, self).__init__(config)
30
+ self.model = BunnyPhi3Model(config)
31
+ self.vocab_size = config.vocab_size
32
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_model(self):
38
+ return self.model
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: torch.LongTensor = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ position_ids: Optional[torch.LongTensor] = None,
45
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
46
+ inputs_embeds: Optional[torch.FloatTensor] = None,
47
+ labels: Optional[torch.LongTensor] = None,
48
+ use_cache: Optional[bool] = None,
49
+ output_attentions: Optional[bool] = None,
50
+ output_hidden_states: Optional[bool] = None,
51
+ images: Optional[torch.FloatTensor] = None,
52
+ return_dict: Optional[bool] = None,
53
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
54
+
55
+ if inputs_embeds is None:
56
+ (
57
+ input_ids,
58
+ position_ids,
59
+ attention_mask,
60
+ past_key_values,
61
+ inputs_embeds,
62
+ labels
63
+ ) = self.prepare_inputs_labels_for_multimodal(
64
+ input_ids,
65
+ position_ids,
66
+ attention_mask,
67
+ past_key_values,
68
+ labels,
69
+ images
70
+ )
71
+
72
+ return super().forward(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ position_ids=position_ids,
76
+ past_key_values=past_key_values,
77
+ inputs_embeds=inputs_embeds,
78
+ labels=labels,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict
83
+ )
84
+
85
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
+ **kwargs):
87
+ images = kwargs.pop("images", None)
88
+
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
+ **kwargs
92
+ )
93
+
94
+ if images is not None:
95
+ _inputs['images'] = images
96
+ return _inputs
97
+
98
+
99
+ AutoConfig.register("bunny-phi3", BunnyPhi3Config)
100
+ AutoModelForCausalLM.register(BunnyPhi3Config, BunnyPhi3ForCausalLM)
bunny/model/language_model/bunny_qwen.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from .qwen2 import Qwen2Model, Qwen2Config, Qwen2ForCausalLM
8
+
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
+
13
+
14
+ class BunnyQwen2Config(Qwen2Config):
15
+ model_type = "bunny-qwen2"
16
+
17
+
18
+ class BunnyQwen2Model(BunnyMetaModel, Qwen2Model):
19
+ config_class = BunnyQwen2Config
20
+
21
+ def __init__(self, config: Qwen2Config):
22
+ super(BunnyQwen2Model, self).__init__(config)
23
+
24
+
25
+ class BunnyQwen2ForCausalLM(Qwen2ForCausalLM, BunnyMetaForCausalLM):
26
+ config_class = BunnyQwen2Config
27
+
28
+ def __init__(self, config):
29
+ super(Qwen2ForCausalLM, self).__init__(config)
30
+ self.model = BunnyQwen2Model(config)
31
+ self.vocab_size = config.vocab_size
32
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_model(self):
38
+ return self.model
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: torch.LongTensor = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ position_ids: Optional[torch.LongTensor] = None,
45
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
46
+ inputs_embeds: Optional[torch.FloatTensor] = None,
47
+ labels: Optional[torch.LongTensor] = None,
48
+ use_cache: Optional[bool] = None,
49
+ output_attentions: Optional[bool] = None,
50
+ output_hidden_states: Optional[bool] = None,
51
+ images: Optional[torch.FloatTensor] = None,
52
+ return_dict: Optional[bool] = None,
53
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
54
+
55
+ if inputs_embeds is None:
56
+ (
57
+ input_ids,
58
+ position_ids,
59
+ attention_mask,
60
+ past_key_values,
61
+ inputs_embeds,
62
+ labels
63
+ ) = self.prepare_inputs_labels_for_multimodal(
64
+ input_ids,
65
+ position_ids,
66
+ attention_mask,
67
+ past_key_values,
68
+ labels,
69
+ images
70
+ )
71
+
72
+ return super().forward(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ position_ids=position_ids,
76
+ past_key_values=past_key_values,
77
+ inputs_embeds=inputs_embeds,
78
+ labels=labels,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict
83
+ )
84
+
85
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
+ **kwargs):
87
+ images = kwargs.pop("images", None)
88
+
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
+ **kwargs
92
+ )
93
+
94
+ if images is not None:
95
+ _inputs['images'] = images
96
+ return _inputs
97
+
98
+
99
+ AutoConfig.register("bunny-qwen2", BunnyQwen2Config)
100
+ AutoModelForCausalLM.register(BunnyQwen2Config, BunnyQwen2ForCausalLM)
bunny/model/language_model/bunny_stablelm.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+
7
+ from bunny.model.language_model.stable_lm.modeling_stablelm_epoch import StableLMEpochModel, StableLMEpochConfig, \
8
+ StableLMEpochForCausalLM
9
+
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from bunny.model.bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
+
14
+
15
+ class BunnyStableLMConfig(StableLMEpochConfig):
16
+ model_type = "bunny-stablelm"
17
+
18
+
19
+ class BunnyStableLMModel(BunnyMetaModel, StableLMEpochModel):
20
+ config_class = BunnyStableLMConfig
21
+
22
+ def __init__(self, config: StableLMEpochConfig):
23
+ super(BunnyStableLMModel, self).__init__(config)
24
+
25
+
26
+ class BunnyStableLMForCausalLM(StableLMEpochForCausalLM, BunnyMetaForCausalLM):
27
+ config_class = BunnyStableLMConfig
28
+
29
+ def __init__(self, config):
30
+ super(StableLMEpochForCausalLM, self).__init__(config)
31
+ self.model = BunnyStableLMModel(config)
32
+ self.vocab_size = config.vocab_size
33
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
+
35
+ # Initialize weights and apply final processing
36
+ self.post_init()
37
+
38
+ def get_model(self):
39
+ return self.model
40
+
41
+ def forward(
42
+ self,
43
+ input_ids: torch.LongTensor = None,
44
+ attention_mask: Optional[torch.Tensor] = None,
45
+ position_ids: Optional[torch.LongTensor] = None,
46
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
47
+ inputs_embeds: Optional[torch.FloatTensor] = None,
48
+ labels: Optional[torch.LongTensor] = None,
49
+ use_cache: Optional[bool] = None,
50
+ output_attentions: Optional[bool] = None,
51
+ output_hidden_states: Optional[bool] = None,
52
+ images: Optional[torch.FloatTensor] = None,
53
+ return_dict: Optional[bool] = None,
54
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
55
+ if inputs_embeds is None:
56
+ (
57
+ input_ids,
58
+ position_ids,
59
+ attention_mask,
60
+ past_key_values,
61
+ inputs_embeds,
62
+ labels
63
+ ) = self.prepare_inputs_labels_for_multimodal(
64
+ input_ids,
65
+ position_ids,
66
+ attention_mask,
67
+ past_key_values,
68
+ labels,
69
+ images
70
+ )
71
+
72
+ return super().forward(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ position_ids=position_ids,
76
+ past_key_values=past_key_values,
77
+ inputs_embeds=inputs_embeds,
78
+ labels=labels,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict
83
+ )
84
+
85
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
+ **kwargs):
87
+ images = kwargs.pop("images", None)
88
+
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
+ **kwargs
92
+ )
93
+
94
+ if images is not None:
95
+ _inputs['images'] = images
96
+ return _inputs
97
+
98
+
99
+ AutoConfig.register("bunny-stablelm", BunnyStableLMConfig)
100
+ AutoModelForCausalLM.register(BunnyStableLMConfig, BunnyStableLMForCausalLM)
bunny/model/language_model/llama/__init__.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_flax_available,
20
+ is_sentencepiece_available,
21
+ is_tokenizers_available,
22
+ is_torch_available,
23
+ )
24
+
25
+
26
+ _import_structure = {
27
+ "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
28
+ }
29
+
30
+ try:
31
+ if not is_sentencepiece_available():
32
+ raise OptionalDependencyNotAvailable()
33
+ except OptionalDependencyNotAvailable:
34
+ pass
35
+ else:
36
+ _import_structure["tokenization_llama"] = ["LlamaTokenizer"]
37
+
38
+ try:
39
+ if not is_tokenizers_available():
40
+ raise OptionalDependencyNotAvailable()
41
+ except OptionalDependencyNotAvailable:
42
+ pass
43
+ else:
44
+ _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
45
+
46
+ try:
47
+ if not is_torch_available():
48
+ raise OptionalDependencyNotAvailable()
49
+ except OptionalDependencyNotAvailable:
50
+ pass
51
+ else:
52
+ _import_structure["modeling_llama"] = [
53
+ "LlamaForCausalLM",
54
+ "LlamaModel",
55
+ "LlamaPreTrainedModel",
56
+ "LlamaForSequenceClassification",
57
+ "LlamaForQuestionAnswering",
58
+ ]
59
+
60
+ try:
61
+ if not is_flax_available():
62
+ raise OptionalDependencyNotAvailable()
63
+ except OptionalDependencyNotAvailable:
64
+ pass
65
+ else:
66
+ _import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]
67
+
68
+
69
+ if TYPE_CHECKING:
70
+ from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
71
+
72
+ try:
73
+ if not is_sentencepiece_available():
74
+ raise OptionalDependencyNotAvailable()
75
+ except OptionalDependencyNotAvailable:
76
+ pass
77
+ else:
78
+ from .tokenization_llama import LlamaTokenizer
79
+
80
+ try:
81
+ if not is_tokenizers_available():
82
+ raise OptionalDependencyNotAvailable()
83
+ except OptionalDependencyNotAvailable:
84
+ pass
85
+ else:
86
+ from .tokenization_llama_fast import LlamaTokenizerFast
87
+
88
+ try:
89
+ if not is_torch_available():
90
+ raise OptionalDependencyNotAvailable()
91
+ except OptionalDependencyNotAvailable:
92
+ pass
93
+ else:
94
+ from .modeling_llama import (
95
+ LlamaForCausalLM,
96
+ LlamaForQuestionAnswering,
97
+ LlamaForSequenceClassification,
98
+ LlamaModel,
99
+ LlamaPreTrainedModel,
100
+ )
101
+
102
+ try:
103
+ if not is_flax_available():
104
+ raise OptionalDependencyNotAvailable()
105
+ except OptionalDependencyNotAvailable:
106
+ pass
107
+ else:
108
+ from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel
109
+
110
+
111
+ else:
112
+ import sys
113
+
114
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
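
This mirrors the lazy-import pattern used inside transformers: at import time the module object is swapped for a _LazyModule, and the heavy submodules are only imported when one of the listed attributes is first accessed. A small sketch of the resulting behaviour:

import importlib

llama_pkg = importlib.import_module('bunny.model.language_model.llama')
# No modeling code has been loaded yet; attribute access triggers the real import:
LlamaConfig = llama_pkg.LlamaConfig             # imports .configuration_llama on first access
LlamaForCausalLM = llama_pkg.LlamaForCausalLM   # imports .modeling_llama (requires torch)
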
bunny/model/language_model/llama/configuration_llama.py ADDED
@@ -0,0 +1,191 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # from ..deprecated._archive_maps import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
30
+
31
+
32
+ class LlamaConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the LLaMA-7B.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+
42
+ Args:
43
+ vocab_size (`int`, *optional*, defaults to 32000):
44
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
45
+ `input_ids` passed when calling [`LlamaModel`]
46
+ hidden_size (`int`, *optional*, defaults to 4096):
47
+ Dimension of the hidden representations.
48
+ intermediate_size (`int`, *optional*, defaults to 11008):
49
+ Dimension of the MLP representations.
50
+ num_hidden_layers (`int`, *optional*, defaults to 32):
51
+ Number of hidden layers in the Transformer decoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 32):
53
+ Number of attention heads for each attention layer in the Transformer decoder.
54
+ num_key_value_heads (`int`, *optional*):
55
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
57
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
58
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
+ by meanpooling all the original heads within that group. For more details checkout [this
60
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
+ `num_attention_heads`.
62
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
63
+ The non-linear activation function (function or string) in the decoder.
64
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
65
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
66
+ Llama 2 up to 4096, CodeLlama up to 16384.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
70
+ The epsilon used by the rms normalization layers.
71
+ use_cache (`bool`, *optional*, defaults to `True`):
72
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
73
+ relevant if `config.is_decoder=True`.
74
+ pad_token_id (`int`, *optional*):
75
+ Padding token id.
76
+ bos_token_id (`int`, *optional*, defaults to 1):
77
+ Beginning of stream token id.
78
+ eos_token_id (`int`, *optional*, defaults to 2):
79
+ End of stream token id.
80
+ pretraining_tp (`int`, *optional*, defaults to 1):
81
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
82
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
83
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
84
+ issue](https://github.com/pytorch/pytorch/issues/76232).
85
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
86
+ Whether to tie weight embeddings
87
+ rope_theta (`float`, *optional*, defaults to 10000.0):
88
+ The base period of the RoPE embeddings.
89
+ rope_scaling (`Dict`, *optional*):
90
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
91
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
92
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
93
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
94
+ these scaling strategies behave:
95
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
96
+ experimental feature, subject to breaking API changes in future versions.
97
+ attention_bias (`bool`, *optional*, defaults to `False`):
98
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
99
+ attention_dropout (`float`, *optional*, defaults to 0.0):
100
+ The dropout ratio for the attention probabilities.
101
+
102
+ ```python
103
+ >>> from transformers import LlamaModel, LlamaConfig
104
+
105
+ >>> # Initializing a LLaMA llama-7b style configuration
106
+ >>> configuration = LlamaConfig()
107
+
108
+ >>> # Initializing a model from the llama-7b style configuration
109
+ >>> model = LlamaModel(configuration)
110
+
111
+ >>> # Accessing the model configuration
112
+ >>> configuration = model.config
113
+ ```"""
114
+
115
+ model_type = "llama"
116
+ keys_to_ignore_at_inference = ["past_key_values"]
117
+
118
+ def __init__(
119
+ self,
120
+ vocab_size=32000,
121
+ hidden_size=4096,
122
+ intermediate_size=11008,
123
+ num_hidden_layers=32,
124
+ num_attention_heads=32,
125
+ num_key_value_heads=None,
126
+ hidden_act="silu",
127
+ max_position_embeddings=2048,
128
+ initializer_range=0.02,
129
+ rms_norm_eps=1e-6,
130
+ use_cache=True,
131
+ pad_token_id=None,
132
+ bos_token_id=1,
133
+ eos_token_id=2,
134
+ pretraining_tp=1,
135
+ tie_word_embeddings=False,
136
+ rope_theta=10000.0,
137
+ rope_scaling=None,
138
+ attention_bias=False,
139
+ attention_dropout=0.0,
140
+ **kwargs,
141
+ ):
142
+ self.vocab_size = vocab_size
143
+ self.max_position_embeddings = max_position_embeddings
144
+ self.hidden_size = hidden_size
145
+ self.intermediate_size = intermediate_size
146
+ self.num_hidden_layers = num_hidden_layers
147
+ self.num_attention_heads = num_attention_heads
148
+
149
+ # for backward compatibility
150
+ if num_key_value_heads is None:
151
+ num_key_value_heads = num_attention_heads
152
+
153
+ self.num_key_value_heads = num_key_value_heads
154
+ self.hidden_act = hidden_act
155
+ self.initializer_range = initializer_range
156
+ self.rms_norm_eps = rms_norm_eps
157
+ self.pretraining_tp = pretraining_tp
158
+ self.use_cache = use_cache
159
+ self.rope_theta = rope_theta
160
+ self.rope_scaling = rope_scaling
161
+ self._rope_scaling_validation()
162
+ self.attention_bias = attention_bias
163
+ self.attention_dropout = attention_dropout
164
+
165
+ super().__init__(
166
+ pad_token_id=pad_token_id,
167
+ bos_token_id=bos_token_id,
168
+ eos_token_id=eos_token_id,
169
+ tie_word_embeddings=tie_word_embeddings,
170
+ **kwargs,
171
+ )
172
+
173
+ def _rope_scaling_validation(self):
174
+ """
175
+ Validate the `rope_scaling` configuration.
176
+ """
177
+ if self.rope_scaling is None:
178
+ return
179
+
180
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
181
+ raise ValueError(
182
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
183
+ )
184
+ rope_scaling_type = self.rope_scaling.get("type", None)
185
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
186
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
187
+ raise ValueError(
188
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
189
+ )
190
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
191
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
bunny/model/language_model/llama/modeling_llama.py ADDED
@@ -0,0 +1,1844 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch LLaMA model."""
21
+
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
+ # from transformers.modeling_attn_mask_utils import AttentionMaskConverter
35
+ from dataclasses import dataclass
36
+ @dataclass
37
+ class AttentionMaskConverter:
38
+ """
39
+ A utility attention mask class that allows one to:
40
+ - Create a causal 4d mask
41
+ - Create a causal 4d mask with a sliding window
42
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
43
+ key_value_length) that can be multiplied with attention scores
44
+
45
+ Examples:
46
+
47
+ ```python
48
+ >>> import torch
49
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
50
+
51
+ >>> converter = AttentionMaskConverter(True)
52
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
53
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
54
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
55
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
56
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
57
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
58
+ ```
59
+
60
+ Parameters:
61
+ is_causal (`bool`):
62
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
63
+
64
+ sliding_window (`int`, *optional*):
65
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
66
+ """
67
+
68
+ is_causal: bool
69
+ sliding_window: int
70
+
71
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
72
+ self.is_causal = is_causal
73
+ self.sliding_window = sliding_window
74
+
75
+ if self.sliding_window is not None and self.sliding_window <= 0:
76
+ raise ValueError(
77
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
78
+ )
79
+
80
+ def to_causal_4d(
81
+ self,
82
+ batch_size: int,
83
+ query_length: int,
84
+ key_value_length: int,
85
+ dtype: torch.dtype,
86
+ device: Union[torch.device, "str"] = "cpu",
87
+ ) -> Optional[torch.Tensor]:
88
+ """
89
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
90
+ bias to upper right hand triangular matrix (causal mask).
91
+ """
92
+ if not self.is_causal:
93
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
94
+
95
+ # If shape is not cached, create a new causal mask and cache it
96
+ input_shape = (batch_size, query_length)
97
+ past_key_values_length = key_value_length - query_length
98
+
99
+ # create causal mask
100
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
101
+ causal_4d_mask = None
102
+ if input_shape[-1] > 1 or self.sliding_window is not None:
103
+ causal_4d_mask = self._make_causal_mask(
104
+ input_shape,
105
+ dtype,
106
+ device=device,
107
+ past_key_values_length=past_key_values_length,
108
+ sliding_window=self.sliding_window,
109
+ )
110
+
111
+ return causal_4d_mask
112
+
113
+ def to_4d(
114
+ self,
115
+ attention_mask_2d: torch.Tensor,
116
+ query_length: int,
117
+ dtype: torch.dtype,
118
+ key_value_length: Optional[int] = None,
119
+ ) -> torch.Tensor:
120
+ """
121
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
122
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
123
+ causal, a causal mask will be added.
124
+ """
125
+ input_shape = (attention_mask_2d.shape[0], query_length)
126
+
127
+ # create causal mask
128
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
129
+ causal_4d_mask = None
130
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
131
+ if key_value_length is None:
132
+ raise ValueError(
133
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
134
+ )
135
+
136
+ past_key_values_length = key_value_length - query_length
137
+ causal_4d_mask = self._make_causal_mask(
138
+ input_shape,
139
+ dtype,
140
+ device=attention_mask_2d.device,
141
+ past_key_values_length=past_key_values_length,
142
+ sliding_window=self.sliding_window,
143
+ )
144
+ elif self.sliding_window is not None:
145
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
146
+
147
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
148
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
149
+ attention_mask_2d.device
150
+ )
151
+
152
+ if causal_4d_mask is not None:
153
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
154
+
155
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
156
+ expanded_4d_mask = expanded_attn_mask
157
+
158
+ return expanded_4d_mask
159
+
160
+ @staticmethod
161
+ def _make_causal_mask(
162
+ input_ids_shape: torch.Size,
163
+ dtype: torch.dtype,
164
+ device: torch.device,
165
+ past_key_values_length: int = 0,
166
+ sliding_window: Optional[int] = None,
167
+ ):
168
+ """
169
+ Make the causal mask used for uni-directional (causal) self-attention.
170
+ """
171
+ bsz, tgt_len = input_ids_shape
172
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
173
+ mask_cond = torch.arange(mask.size(-1), device=device)
174
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
175
+
176
+ mask = mask.to(dtype)
177
+
178
+ if past_key_values_length > 0:
179
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
180
+
181
+ # add lower triangular sliding window mask if necessary
182
+ if sliding_window is not None:
183
+ diagonal = past_key_values_length - sliding_window - 1
184
+
185
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
186
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
187
+
188
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
189
+
190
+ @staticmethod
191
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
192
+ """
193
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
194
+ """
195
+ bsz, src_len = mask.size()
196
+ tgt_len = tgt_len if tgt_len is not None else src_len
197
+
198
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
199
+
200
+ inverted_mask = 1.0 - expanded_mask
201
+
202
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
203
+
204
+ @staticmethod
205
+ def _unmask_unattended(
206
+ expanded_mask: torch.FloatTensor,
207
+ min_dtype: float,
208
+ ):
209
+ # fmt: off
210
+ """
211
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
212
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
213
+ Details: https://github.com/pytorch/pytorch/issues/110213
214
+
215
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
216
+ `attention_mask` is [bsz, src_seq_len].
217
+
218
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
219
+
220
+ For example, if `expanded_mask` is (e.g. here left-padding case)
221
+ ```
222
+ [[[[0, 0, 0],
223
+ [0, 0, 0],
224
+ [0, 0, 1]]],
225
+ [[[1, 0, 0],
226
+ [1, 1, 0],
227
+ [1, 1, 1]]],
228
+ [[[0, 0, 0],
229
+ [0, 1, 0],
230
+ [0, 1, 1]]]]
231
+ ```
232
+ then the modified `expanded_mask` will be
233
+ ```
234
+ [[[[1, 1, 1], <-- modified
235
+ [1, 1, 1], <-- modified
236
+ [0, 0, 1]]],
237
+ [[[1, 0, 0],
238
+ [1, 1, 0],
239
+ [1, 1, 1]]],
240
+ [[[1, 1, 1], <-- modified
241
+ [0, 1, 0],
242
+ [0, 1, 1]]]]
243
+ ```
244
+ """
245
+ # fmt: on
246
+ if expanded_mask.dtype == torch.bool:
247
+ raise ValueError(
248
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
249
+ )
250
+
251
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
252
+
253
+ @staticmethod
254
+ def _ignore_causal_mask_sdpa(
255
+ attention_mask: Optional[torch.Tensor],
256
+ inputs_embeds: torch.Tensor,
257
+ past_key_values_length: int,
258
+ sliding_window: Optional[int] = None,
259
+ ) -> bool:
260
+ """
261
+ Detects whether the optional user-specified attention_mask and the automatically created causal mask can be ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
262
+
263
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
264
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
265
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
266
+ """
267
+
268
+ batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
269
+ key_value_length = query_length + past_key_values_length
270
+
271
+ is_tracing = (
272
+ torch.jit.is_tracing()
273
+ or isinstance(inputs_embeds, torch.fx.Proxy)
274
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
275
+ )
276
+
277
+ ignore_causal_mask = False
278
+
279
+ if attention_mask is None:
280
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
281
+ # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
282
+ # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
283
+ #
284
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
285
+ if (
286
+ not is_tracing
287
+ and (query_length == 1 or key_value_length == query_length)
288
+ and (sliding_window is None or key_value_length < sliding_window)
289
+ ):
290
+ ignore_causal_mask = True
291
+ elif sliding_window is None or key_value_length < sliding_window:
292
+ if len(attention_mask.shape) == 4:
293
+ expected_shape = (batch_size, 1, query_length, key_value_length)
294
+ if tuple(attention_mask.shape) != expected_shape:
295
+ raise ValueError(
296
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
297
+ )
298
+ elif not is_tracing and torch.all(attention_mask == 1):
299
+ if query_length == 1 or key_value_length == query_length:
300
+ # For query_length == 1, causal attention and bi-directional attention are the same.
301
+ ignore_causal_mask = True
302
+
303
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
304
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
305
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
306
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
307
+
308
+ return ignore_causal_mask
309
+
310
+
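To make the converter's behaviour concrete, here is a small shape-level sketch using the class defined just above; the import path is an assumption based on this file's location, and it presumes torch and a recent transformers are installed:

```python
import torch

from bunny.model.language_model.llama.modeling_llama import AttentionMaskConverter  # assumed path

converter = AttentionMaskConverter(is_causal=True)
# 3 query tokens attending over 5 keys (2 of them cached): result is [bsz, 1, q_len, kv_len]
mask = converter.to_causal_4d(batch_size=1, query_length=3, key_value_length=5,
                              dtype=torch.float32, device="cpu")
print(mask.shape)   # torch.Size([1, 1, 3, 5]); future positions hold the large negative bias
```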
311
+ from transformers.modeling_outputs import (
312
+ BaseModelOutputWithPast,
313
+ CausalLMOutputWithPast,
314
+ QuestionAnsweringModelOutput,
315
+ SequenceClassifierOutputWithPast,
316
+ )
317
+ from transformers.modeling_utils import PreTrainedModel
318
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
319
+ from transformers.utils import (
320
+ add_start_docstrings,
321
+ add_start_docstrings_to_model_forward,
322
+ is_flash_attn_2_available,
323
+ is_flash_attn_greater_or_equal_2_10,
324
+ logging,
325
+ replace_return_docstrings,
326
+ )
327
+ from .configuration_llama import LlamaConfig
328
+
329
+
330
+ if is_flash_attn_2_available():
331
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
332
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
333
+
334
+
335
+ logger = logging.get_logger(__name__)
336
+
337
+ _CONFIG_FOR_DOC = "LlamaConfig"
338
+
339
+
340
+ def _get_unpad_data(attention_mask):
341
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
344
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
+ return (
346
+ indices,
347
+ cu_seqlens,
348
+ max_seqlen_in_batch,
349
+ )
350
+
351
+
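A quick sketch of what `_get_unpad_data` returns for a left-padded batch of two sequences (lengths 2 and 3); this triple is what feeds `flash_attn_varlen_func` further down. The import path is an assumption based on this file's location:

```python
import torch

from bunny.model.language_model.llama.modeling_llama import _get_unpad_data  # assumed path

attention_mask = torch.tensor([[0, 1, 1],
                               [1, 1, 1]])
indices, cu_seqlens, max_len = _get_unpad_data(attention_mask)
print(indices)     # tensor([1, 2, 3, 4, 5]) -- flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -- cumulative sequence lengths
print(max_len)     # 3
```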
352
+ class LlamaRMSNorm(nn.Module):
353
+ def __init__(self, hidden_size, eps=1e-6):
354
+ """
355
+ LlamaRMSNorm is equivalent to T5LayerNorm
356
+ """
357
+ super().__init__()
358
+ self.weight = nn.Parameter(torch.ones(hidden_size))
359
+ self.variance_epsilon = eps
360
+
361
+ def forward(self, hidden_states):
362
+ input_dtype = hidden_states.dtype
363
+ hidden_states = hidden_states.to(torch.float32)
364
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
365
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
366
+ return self.weight * hidden_states.to(input_dtype)
367
+
368
+
369
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
370
+
371
+
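The RMS normalization above can be checked numerically: with the default all-ones weight, the output is just the input rescaled by the reciprocal root-mean-square over the last dimension. A tiny sketch (assumed import path):

```python
import torch

from bunny.model.language_model.llama.modeling_llama import LlamaRMSNorm  # assumed path

norm = LlamaRMSNorm(hidden_size=4, eps=1e-6)
x = torch.randn(2, 3, 4)
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), expected, atol=1e-6))   # True, since weight is initialised to ones
```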
372
+ class LlamaRotaryEmbedding(nn.Module):
373
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
374
+ super().__init__()
375
+ self.scaling_factor = scaling_factor
376
+ self.dim = dim
377
+ self.max_position_embeddings = max_position_embeddings
378
+ self.base = base
379
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
380
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
381
+ # For BC we register cos and sin cached
382
+ self.max_seq_len_cached = max_position_embeddings
383
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
384
+ t = t / self.scaling_factor
385
+ freqs = torch.outer(t, self.inv_freq)
386
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
387
+ emb = torch.cat((freqs, freqs), dim=-1)
388
+ self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
389
+ self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
390
+
391
+ @property
392
+ def sin_cached(self):
393
+ logger.warning_once(
394
+ "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
395
+ "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
396
+ )
397
+ return self._sin_cached
398
+
399
+ @property
400
+ def cos_cached(self):
401
+ logger.warning_once(
402
+ "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
403
+ "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
404
+ )
405
+ return self._cos_cached
406
+
407
+ @torch.no_grad()
408
+ def forward(self, x, position_ids):
409
+ # x: [bs, num_attention_heads, seq_len, head_size]
410
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
411
+ position_ids_expanded = position_ids[:, None, :].float()
412
+ # Force float32 since bfloat16 loses precision on long contexts
413
+ # See https://github.com/huggingface/transformers/pull/29285
414
+ device_type = x.device.type
415
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
416
+ with torch.autocast(device_type=device_type, enabled=False):
417
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
418
+ emb = torch.cat((freqs, freqs), dim=-1)
419
+ cos = emb.cos()
420
+ sin = emb.sin()
421
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
422
+
423
+
424
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
425
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
426
+
427
+ def forward(self, x, position_ids):
428
+ # difference to the original RoPE: a scaling factor is applied to the position ids
429
+ position_ids = position_ids.float() / self.scaling_factor
430
+ cos, sin = super().forward(x, position_ids)
431
+ return cos, sin
432
+
433
+
434
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
435
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
436
+
437
+ def forward(self, x, position_ids):
438
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
439
+ seq_len = torch.max(position_ids) + 1
440
+ if seq_len > self.max_position_embeddings:
441
+ base = self.base * (
442
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
443
+ ) ** (self.dim / (self.dim - 2))
444
+ inv_freq = 1.0 / (
445
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
446
+ )
447
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
448
+
449
+ cos, sin = super().forward(x, position_ids)
450
+ return cos, sin
451
+
452
+
453
+ def rotate_half(x):
454
+ """Rotates half the hidden dims of the input."""
455
+ x1 = x[..., : x.shape[-1] // 2]
456
+ x2 = x[..., x.shape[-1] // 2 :]
457
+ return torch.cat((-x2, x1), dim=-1)
458
+
459
+
460
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
461
+ """Applies Rotary Position Embedding to the query and key tensors.
462
+
463
+ Args:
464
+ q (`torch.Tensor`): The query tensor.
465
+ k (`torch.Tensor`): The key tensor.
466
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
467
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
468
+ position_ids (`torch.Tensor`, *optional*):
469
+ Deprecated and unused.
470
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
471
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
472
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
473
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
474
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
475
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
476
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
477
+ Returns:
478
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
479
+ """
480
+ cos = cos.unsqueeze(unsqueeze_dim)
481
+ sin = sin.unsqueeze(unsqueeze_dim)
482
+ q_embed = (q * cos) + (rotate_half(q) * sin)
483
+ k_embed = (k * cos) + (rotate_half(k) * sin)
484
+ return q_embed, k_embed
485
+
486
+
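A shape-level sketch of the RoPE helpers above: the rotary module produces per-position cos/sin tables of shape [bsz, seq_len, head_dim], which `apply_rotary_pos_emb` then broadcasts over the head dimension. Sizes are illustrative and the import path is assumed from this file's location:

```python
import torch

from bunny.model.language_model.llama.modeling_llama import (  # assumed path
    LlamaRotaryEmbedding, apply_rotary_pos_emb,
)

rope = LlamaRotaryEmbedding(dim=8, max_position_embeddings=32)
q = torch.randn(1, 2, 5, 8)                   # [bsz, num_heads, seq_len, head_dim]
position_ids = torch.arange(5).unsqueeze(0)   # [bsz, seq_len]
cos, sin = rope(q, position_ids)              # each [1, 5, 8]
q_rot, k_rot = apply_rotary_pos_emb(q, q, cos, sin)
print(cos.shape, q_rot.shape)                 # torch.Size([1, 5, 8]) torch.Size([1, 2, 5, 8])
```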
487
+ class LlamaMLP(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.hidden_size = config.hidden_size
492
+ self.intermediate_size = config.intermediate_size
493
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
494
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
495
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
496
+ self.act_fn = ACT2FN[config.hidden_act]
497
+
498
+ def forward(self, x):
499
+ if self.config.pretraining_tp > 1:
500
+ slice = self.intermediate_size // self.config.pretraining_tp
501
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
502
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
503
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
504
+
505
+ gate_proj = torch.cat(
506
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
507
+ )
508
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
509
+
510
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
511
+ down_proj = [
512
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
513
+ ]
514
+ down_proj = sum(down_proj)
515
+ else:
516
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
517
+
518
+ return down_proj
519
+
520
+
521
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
522
+ """
523
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
524
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
525
+ """
526
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
527
+ if n_rep == 1:
528
+ return hidden_states
529
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
530
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
531
+
532
+
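`repeat_kv` is what realises grouped-query attention at runtime: with, say, 8 query heads and 2 KV heads, each KV head is tiled n_rep = 8 // 2 = 4 times so keys and values line up with the query heads. A small sketch (assumed import path):

```python
import torch

from bunny.model.language_model.llama.modeling_llama import repeat_kv  # assumed path

key_states = torch.randn(1, 2, 5, 64)          # [bsz, num_key_value_heads, seq_len, head_dim]
print(repeat_kv(key_states, n_rep=4).shape)    # torch.Size([1, 8, 5, 64])
```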
533
+ class LlamaAttention(nn.Module):
534
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
535
+
536
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
537
+ super().__init__()
538
+ self.config = config
539
+ self.layer_idx = layer_idx
540
+ if layer_idx is None:
541
+ logger.warning_once(
542
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
543
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
544
+ "when creating this class."
545
+ )
546
+
547
+ self.attention_dropout = config.attention_dropout
548
+ self.hidden_size = config.hidden_size
549
+ self.num_heads = config.num_attention_heads
550
+ self.head_dim = self.hidden_size // self.num_heads
551
+ self.num_key_value_heads = config.num_key_value_heads
552
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
553
+ self.max_position_embeddings = config.max_position_embeddings
554
+ self.rope_theta = config.rope_theta
555
+ self.is_causal = True
556
+
557
+ if (self.head_dim * self.num_heads) != self.hidden_size:
558
+ raise ValueError(
559
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
560
+ f" and `num_heads`: {self.num_heads})."
561
+ )
562
+
563
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
564
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
565
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
566
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
567
+ self._init_rope()
568
+
569
+ def _init_rope(self):
570
+ if self.config.rope_scaling is None:
571
+ self.rotary_emb = LlamaRotaryEmbedding(
572
+ self.head_dim,
573
+ max_position_embeddings=self.max_position_embeddings,
574
+ base=self.rope_theta,
575
+ )
576
+ else:
577
+ scaling_type = self.config.rope_scaling["type"]
578
+ scaling_factor = self.config.rope_scaling["factor"]
579
+ if scaling_type == "linear":
580
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
581
+ self.head_dim,
582
+ max_position_embeddings=self.max_position_embeddings,
583
+ scaling_factor=scaling_factor,
584
+ base=self.rope_theta,
585
+ )
586
+ elif scaling_type == "dynamic":
587
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
588
+ self.head_dim,
589
+ max_position_embeddings=self.max_position_embeddings,
590
+ scaling_factor=scaling_factor,
591
+ base=self.rope_theta,
592
+ )
593
+ else:
594
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
595
+
596
+ def forward(
597
+ self,
598
+ hidden_states: torch.Tensor,
599
+ attention_mask: Optional[torch.Tensor] = None,
600
+ position_ids: Optional[torch.LongTensor] = None,
601
+ past_key_value: Optional[Cache] = None,
602
+ output_attentions: bool = False,
603
+ use_cache: bool = False,
604
+ cache_position: Optional[torch.LongTensor] = None,
605
+ **kwargs,
606
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
607
+ bsz, q_len, _ = hidden_states.size()
608
+
609
+ if self.config.pretraining_tp > 1:
610
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
611
+ query_slices = self.q_proj.weight.split(
612
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
613
+ )
614
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
615
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
616
+
617
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
618
+ query_states = torch.cat(query_states, dim=-1)
619
+
620
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
621
+ key_states = torch.cat(key_states, dim=-1)
622
+
623
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
624
+ value_states = torch.cat(value_states, dim=-1)
625
+
626
+ else:
627
+ query_states = self.q_proj(hidden_states)
628
+ key_states = self.k_proj(hidden_states)
629
+ value_states = self.v_proj(hidden_states)
630
+
631
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
632
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
633
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
634
+
635
+ past_key_value = getattr(self, "past_key_value", past_key_value)
636
+ cos, sin = self.rotary_emb(value_states, position_ids)
637
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
638
+
639
+ if past_key_value is not None:
640
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
641
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
642
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
643
+
644
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
645
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
646
+
647
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
648
+
649
+ if attention_mask is not None: # no matter the length, we just slice it
650
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
651
+ attn_weights = attn_weights + causal_mask
652
+
653
+ # upcast attention to fp32
654
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
655
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
656
+ attn_output = torch.matmul(attn_weights, value_states)
657
+
658
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
659
+ raise ValueError(
660
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
661
+ f" {attn_output.size()}"
662
+ )
663
+
664
+ attn_output = attn_output.transpose(1, 2).contiguous()
665
+
666
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
667
+
668
+ if self.config.pretraining_tp > 1:
669
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
670
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
671
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
672
+ else:
673
+ attn_output = self.o_proj(attn_output)
674
+
675
+ if not output_attentions:
676
+ attn_weights = None
677
+
678
+ return attn_output, attn_weights, past_key_value
679
+
680
+
681
+ class LlamaFlashAttention2(LlamaAttention):
682
+ """
683
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
684
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
685
+ flash attention and deal with padding tokens in case the input contains any of them.
686
+ """
687
+
688
+ def __init__(self, *args, **kwargs):
689
+ super().__init__(*args, **kwargs)
690
+
691
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
692
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
693
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
694
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states: torch.Tensor,
699
+ attention_mask: Optional[torch.LongTensor] = None,
700
+ position_ids: Optional[torch.LongTensor] = None,
701
+ past_key_value: Optional[Cache] = None,
702
+ output_attentions: bool = False,
703
+ use_cache: bool = False,
704
+ cache_position: Optional[torch.LongTensor] = None,
705
+ **kwargs,
706
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
707
+ output_attentions = False
708
+
709
+ bsz, q_len, _ = hidden_states.size()
710
+
711
+ query_states = self.q_proj(hidden_states)
712
+ key_states = self.k_proj(hidden_states)
713
+ value_states = self.v_proj(hidden_states)
714
+
715
+ # Flash attention requires the input to have the shape
716
+ # batch_size x seq_length x num_heads x head_dim
717
+ # therefore we just need to keep the original shape
718
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
719
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
720
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
721
+
722
+ cos, sin = self.rotary_emb(value_states, position_ids)
723
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
724
+
725
+ past_key_value = getattr(self, "past_key_value", past_key_value)
726
+
727
+ if past_key_value is not None:
728
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
729
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
730
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
731
+
732
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
733
+ # to be able to avoid many of these transpose/reshape/view.
734
+ query_states = query_states.transpose(1, 2)
735
+ key_states = key_states.transpose(1, 2)
736
+ value_states = value_states.transpose(1, 2)
737
+
738
+ dropout_rate = self.attention_dropout if self.training else 0.0
739
+
740
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
741
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
742
+ # cast them back in the correct dtype just to be sure everything works as expected.
743
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
744
+ # in fp32. (LlamaRMSNorm handles it correctly)
745
+
746
+ input_dtype = query_states.dtype
747
+ if input_dtype == torch.float32:
748
+ if torch.is_autocast_enabled():
749
+ target_dtype = torch.get_autocast_gpu_dtype()
750
+ # Handle the case where the model is quantized
751
+ elif hasattr(self.config, "_pre_quantization_dtype"):
752
+ target_dtype = self.config._pre_quantization_dtype
753
+ else:
754
+ target_dtype = self.q_proj.weight.dtype
755
+
756
+ logger.warning_once(
757
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
758
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
759
+ f" {target_dtype}."
760
+ )
761
+
762
+ query_states = query_states.to(target_dtype)
763
+ key_states = key_states.to(target_dtype)
764
+ value_states = value_states.to(target_dtype)
765
+
766
+ attn_output = self._flash_attention_forward(
767
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
768
+ )
769
+
770
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
771
+ attn_output = self.o_proj(attn_output)
772
+
773
+ if not output_attentions:
774
+ attn_weights = None
775
+
776
+ return attn_output, attn_weights, past_key_value
777
+
778
+ def _flash_attention_forward(
779
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
780
+ ):
781
+ """
782
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token, it
784
+ first unpads the input, then computes the attention scores, and finally pads the attention output back.
784
+
785
+ Args:
786
+ query_states (`torch.Tensor`):
787
+ Input query states to be passed to Flash Attention API
788
+ key_states (`torch.Tensor`):
789
+ Input key states to be passed to Flash Attention API
790
+ value_states (`torch.Tensor`):
791
+ Input value states to be passed to Flash Attention API
792
+ attention_mask (`torch.Tensor`):
793
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
794
+ position of padding tokens and 1 for the position of non-padding tokens.
795
+ dropout (`float`):
796
+ Attention dropout
797
+ softmax_scale (`float`, *optional*):
798
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
799
+ """
800
+ if not self._flash_attn_uses_top_left_mask:
801
+ causal = self.is_causal
802
+ else:
803
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
804
+ causal = self.is_causal and query_length != 1
805
+
806
+ # Contains at least one padding token in the sequence
807
+ if attention_mask is not None:
808
+ batch_size = query_states.shape[0]
809
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
810
+ query_states, key_states, value_states, attention_mask, query_length
811
+ )
812
+
813
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
814
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
815
+
816
+ attn_output_unpad = flash_attn_varlen_func(
817
+ query_states,
818
+ key_states,
819
+ value_states,
820
+ cu_seqlens_q=cu_seqlens_q,
821
+ cu_seqlens_k=cu_seqlens_k,
822
+ max_seqlen_q=max_seqlen_in_batch_q,
823
+ max_seqlen_k=max_seqlen_in_batch_k,
824
+ dropout_p=dropout,
825
+ softmax_scale=softmax_scale,
826
+ causal=causal,
827
+ )
828
+
829
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
830
+ else:
831
+ attn_output = flash_attn_func(
832
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
833
+ )
834
+
835
+ return attn_output
836
+
837
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
838
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
839
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
840
+
841
+ key_layer = index_first_axis(
842
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
843
+ )
844
+ value_layer = index_first_axis(
845
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
846
+ )
847
+ if query_length == kv_seq_len:
848
+ query_layer = index_first_axis(
849
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
850
+ )
851
+ cu_seqlens_q = cu_seqlens_k
852
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
853
+ indices_q = indices_k
854
+ elif query_length == 1:
855
+ max_seqlen_in_batch_q = 1
856
+ cu_seqlens_q = torch.arange(
857
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
858
+ ) # There is a memcpy here, that is very bad.
859
+ indices_q = cu_seqlens_q[:-1]
860
+ query_layer = query_layer.squeeze(1)
861
+ else:
862
+ # The -q_len: slice assumes left padding.
863
+ attention_mask = attention_mask[:, -query_length:]
864
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
865
+
866
+ return (
867
+ query_layer,
868
+ key_layer,
869
+ value_layer,
870
+ indices_q,
871
+ (cu_seqlens_q, cu_seqlens_k),
872
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
873
+ )
874
+
875
+
876
+ class LlamaSdpaAttention(LlamaAttention):
877
+ """
878
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
879
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
880
+ SDPA API.
881
+ """
882
+
883
+ # Adapted from LlamaAttention.forward
884
+ def forward(
885
+ self,
886
+ hidden_states: torch.Tensor,
887
+ attention_mask: Optional[torch.Tensor] = None,
888
+ position_ids: Optional[torch.LongTensor] = None,
889
+ past_key_value: Optional[Cache] = None,
890
+ output_attentions: bool = False,
891
+ use_cache: bool = False,
892
+ cache_position: Optional[torch.LongTensor] = None,
893
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
894
+ if output_attentions:
895
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
896
+ logger.warning_once(
897
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
898
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
899
+ )
900
+ return super().forward(
901
+ hidden_states=hidden_states,
902
+ attention_mask=attention_mask,
903
+ position_ids=position_ids,
904
+ past_key_value=past_key_value,
905
+ output_attentions=output_attentions,
906
+ use_cache=use_cache,
907
+ cache_position=cache_position,
908
+ )
909
+
910
+ bsz, q_len, _ = hidden_states.size()
911
+
912
+ query_states = self.q_proj(hidden_states)
913
+ key_states = self.k_proj(hidden_states)
914
+ value_states = self.v_proj(hidden_states)
915
+
916
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
917
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
918
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
919
+
920
+ cos, sin = self.rotary_emb(value_states, position_ids)
921
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
922
+
923
+ # In case static cache is used, it is an instance attribute.
924
+ past_key_value = getattr(self, "past_key_value", past_key_value)
925
+
926
+ if past_key_value is not None:
927
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
928
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
929
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
930
+
931
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
932
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
933
+
934
+ causal_mask = attention_mask
935
+ if attention_mask is not None:
936
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
937
+
938
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
939
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
940
+ if query_states.device.type == "cuda" and causal_mask is not None:
941
+ query_states = query_states.contiguous()
942
+ key_states = key_states.contiguous()
943
+ value_states = value_states.contiguous()
944
+
945
+ # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, relying
946
+ # instead on the `is_causal` argument.
947
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
948
+ query_states,
949
+ key_states,
950
+ value_states,
951
+ attn_mask=causal_mask,
952
+ dropout_p=self.attention_dropout if self.training else 0.0,
953
+ is_causal=causal_mask is None and q_len > 1,
954
+ )
955
+
956
+ attn_output = attn_output.transpose(1, 2).contiguous()
957
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
958
+
959
+ attn_output = self.o_proj(attn_output)
960
+
961
+ return attn_output, None, past_key_value
962
+
963
+
964
+ LLAMA_ATTENTION_CLASSES = {
965
+ "eager": LlamaAttention,
966
+ "flash_attention_2": LlamaFlashAttention2,
967
+ "sdpa": LlamaSdpaAttention,
968
+ }
969
+
970
+
971
+ class LlamaDecoderLayer(nn.Module):
972
+ def __init__(self, config: LlamaConfig, layer_idx: int):
973
+ super().__init__()
974
+ self.hidden_size = config.hidden_size
975
+
976
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
977
+
978
+ self.mlp = LlamaMLP(config)
979
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
980
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
981
+
982
+ def forward(
983
+ self,
984
+ hidden_states: torch.Tensor,
985
+ attention_mask: Optional[torch.Tensor] = None,
986
+ position_ids: Optional[torch.LongTensor] = None,
987
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
988
+ output_attentions: Optional[bool] = False,
989
+ use_cache: Optional[bool] = False,
990
+ cache_position: Optional[torch.LongTensor] = None,
991
+ **kwargs,
992
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
993
+ """
994
+ Args:
995
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
996
+ attention_mask (`torch.FloatTensor`, *optional*):
997
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
998
+ query_sequence_length, key_sequence_length)` if default attention is used.
999
+ output_attentions (`bool`, *optional*):
1000
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1001
+ returned tensors for more detail.
1002
+ use_cache (`bool`, *optional*):
1003
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1004
+ (see `past_key_values`).
1005
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1006
+ """
1007
+ if "padding_mask" in kwargs:
1008
+ warnings.warn(
1009
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
1010
+ )
1011
+
1012
+ residual = hidden_states
1013
+
1014
+ hidden_states = self.input_layernorm(hidden_states)
1015
+
1016
+ # Self Attention
1017
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1018
+ hidden_states=hidden_states,
1019
+ attention_mask=attention_mask,
1020
+ position_ids=position_ids,
1021
+ past_key_value=past_key_value,
1022
+ output_attentions=output_attentions,
1023
+ use_cache=use_cache,
1024
+ cache_position=cache_position,
1025
+ **kwargs,
1026
+ )
1027
+ hidden_states = residual + hidden_states
1028
+
1029
+ # Fully Connected
1030
+ residual = hidden_states
1031
+ hidden_states = self.post_attention_layernorm(hidden_states)
1032
+ hidden_states = self.mlp(hidden_states)
1033
+ hidden_states = residual + hidden_states
1034
+
1035
+ outputs = (hidden_states,)
1036
+
1037
+ if output_attentions:
1038
+ outputs += (self_attn_weights,)
1039
+
1040
+ if use_cache:
1041
+ outputs += (present_key_value,)
1042
+
1043
+ return outputs
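# A toy sketch of the pre-norm residual pattern implemented by LlamaDecoderLayer above
# (normalize, apply the sublayer, add the residual, twice per layer). The stand-in modules
# and sizes below are assumptions for illustration, not the real Llama components.
import torch
import torch.nn as nn

class TinyPreNormBlock(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)            # Llama uses RMSNorm here
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)        # stand-in for self-attention
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # residual around the (stand-in) attention, then around the feed-forward block
        hidden_states = hidden_states + self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states

print(TinyPreNormBlock()(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])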
1044
+
1045
+
1046
+ LLAMA_START_DOCSTRING = r"""
1047
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1048
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1049
+ etc.)
1050
+
1051
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1052
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1053
+ and behavior.
1054
+
1055
+ Parameters:
1056
+ config ([`LlamaConfig`]):
1057
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1058
+ load the weights associated with the model, only the configuration. Check out the
1059
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1060
+ """
1061
+
1062
+
1063
+ @add_start_docstrings(
1064
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1065
+ LLAMA_START_DOCSTRING,
1066
+ )
1067
+ class LlamaPreTrainedModel(PreTrainedModel):
1068
+ config_class = LlamaConfig
1069
+ base_model_prefix = "model"
1070
+ supports_gradient_checkpointing = True
1071
+ _no_split_modules = ["LlamaDecoderLayer"]
1072
+ _skip_keys_device_placement = ["past_key_values"]
1073
+ _supports_flash_attn_2 = True
1074
+ _supports_sdpa = True
1075
+ _supports_cache_class = True
1076
+
1077
+ def _init_weights(self, module):
1078
+ std = self.config.initializer_range
1079
+ if isinstance(module, nn.Linear):
1080
+ module.weight.data.normal_(mean=0.0, std=std)
1081
+ if module.bias is not None:
1082
+ module.bias.data.zero_()
1083
+ elif isinstance(module, nn.Embedding):
1084
+ module.weight.data.normal_(mean=0.0, std=std)
1085
+ if module.padding_idx is not None:
1086
+ module.weight.data[module.padding_idx].zero_()
1087
+
1088
+ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
1089
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
1090
+ raise ValueError(
1091
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
1092
+ "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
1093
+ )
1094
+
1095
+ for layer in self.model.layers:
1096
+ device = layer.input_layernorm.weight.device
1097
+ if hasattr(self.config, "_pre_quantization_dtype"):
1098
+ dtype = self.config._pre_quantization_dtype
1099
+ else:
1100
+ dtype = layer.self_attn.o_proj.weight.dtype
1101
+ layer.self_attn.past_key_value = cache_cls(
1102
+ self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
1103
+ )
1104
+
1105
+ def _reset_cache(self):
1106
+ for layer in self.model.layers:
1107
+ layer.self_attn.past_key_value = None
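# A rough, standalone sketch (assumed sizes; not the real `StaticCache` API) of what the
# per-layer setup above amounts to: pre-allocating fixed-shape key/value buffers that
# decoding writes into at `cache_position`, instead of growing tensors step by step.
import torch

max_batch, num_kv_heads, max_cache_len, head_dim = 1, 8, 128, 64
k_cache = torch.zeros(max_batch, num_kv_heads, max_cache_len, head_dim)
v_cache = torch.zeros(max_batch, num_kv_heads, max_cache_len, head_dim)

def write_step(new_k, new_v, cache_position):
    # new_k / new_v: (batch, num_kv_heads, q_len, head_dim); cache_position: (q_len,)
    k_cache[:, :, cache_position] = new_k
    v_cache[:, :, cache_position] = new_v

write_step(torch.randn(max_batch, num_kv_heads, 1, head_dim),
           torch.randn(max_batch, num_kv_heads, 1, head_dim),
           torch.tensor([5]))
print(k_cache[0, 0, 5].abs().sum() > 0)  # tensor(True): slot 5 now holds the new key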
1108
+
1109
+
1110
+ LLAMA_INPUTS_DOCSTRING = r"""
1111
+ Args:
1112
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1113
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1114
+ it.
1115
+
1116
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1117
+ [`PreTrainedTokenizer.__call__`] for details.
1118
+
1119
+ [What are input IDs?](../glossary#input-ids)
1120
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1121
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1122
+
1123
+ - 1 for tokens that are **not masked**,
1124
+ - 0 for tokens that are **masked**.
1125
+
1126
+ [What are attention masks?](../glossary#attention-mask)
1127
+
1128
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1129
+ [`PreTrainedTokenizer.__call__`] for details.
1130
+
1131
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1132
+ `past_key_values`).
1133
+
1134
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1135
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1136
+ information on the default strategy.
1137
+
1138
+ - 1 indicates the head is **not masked**,
1139
+ - 0 indicates the head is **masked**.
1140
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1141
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1142
+ config.n_positions - 1]`.
1143
+
1144
+ [What are position IDs?](../glossary#position-ids)
1145
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1146
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1147
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1148
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1149
+
1150
+ Two formats are allowed:
1151
+ - a [`~cache_utils.Cache`] instance;
1152
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1153
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1154
+ cache format.
1155
+
1156
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1157
+ legacy cache format will be returned.
1158
+
1159
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1160
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1161
+ of shape `(batch_size, sequence_length)`.
1162
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1163
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1164
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1165
+ model's internal embedding lookup matrix.
1166
+ use_cache (`bool`, *optional*):
1167
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1168
+ `past_key_values`).
1169
+ output_attentions (`bool`, *optional*):
1170
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1171
+ tensors for more detail.
1172
+ output_hidden_states (`bool`, *optional*):
1173
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1174
+ more detail.
1175
+ return_dict (`bool`, *optional*):
1176
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1177
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1178
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1179
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1180
+ the complete sequence length.
1181
+ """
1182
+
1183
+
1184
+ @add_start_docstrings(
1185
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1186
+ LLAMA_START_DOCSTRING,
1187
+ )
1188
+ class LlamaModel(LlamaPreTrainedModel):
1189
+ """
1190
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1191
+
1192
+ Args:
1193
+ config: LlamaConfig
1194
+ """
1195
+
1196
+ def __init__(self, config: LlamaConfig):
1197
+ super().__init__(config)
1198
+ self.padding_idx = config.pad_token_id
1199
+ self.vocab_size = config.vocab_size
1200
+
1201
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1202
+ self.layers = nn.ModuleList(
1203
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1204
+ )
1205
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1206
+ self.gradient_checkpointing = False
1207
+
1208
+ # Initialize weights and apply final processing
1209
+ self.post_init()
1210
+
1211
+ def get_input_embeddings(self):
1212
+ return self.embed_tokens
1213
+
1214
+ def set_input_embeddings(self, value):
1215
+ self.embed_tokens = value
1216
+
1217
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1218
+ def forward(
1219
+ self,
1220
+ input_ids: torch.LongTensor = None,
1221
+ attention_mask: Optional[torch.Tensor] = None,
1222
+ position_ids: Optional[torch.LongTensor] = None,
1223
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1224
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1225
+ use_cache: Optional[bool] = None,
1226
+ output_attentions: Optional[bool] = None,
1227
+ output_hidden_states: Optional[bool] = None,
1228
+ return_dict: Optional[bool] = None,
1229
+ cache_position: Optional[torch.LongTensor] = None,
1230
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1231
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1232
+ output_hidden_states = (
1233
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1234
+ )
1235
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
+
1238
+ if (input_ids is None) ^ (inputs_embeds is not None):
1239
+ raise ValueError(
1240
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1241
+ )
1242
+
1243
+ if self.gradient_checkpointing and self.training and use_cache:
1244
+ logger.warning_once(
1245
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1246
+ )
1247
+ use_cache = False
1248
+
1249
+ if inputs_embeds is None:
1250
+ inputs_embeds = self.embed_tokens(input_ids)
1251
+
1252
+ past_seen_tokens = 0
1253
+ if use_cache: # kept for BC (cache positions)
1254
+ if not isinstance(past_key_values, StaticCache):
1255
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1256
+ past_seen_tokens = past_key_values.get_seq_length()
1257
+
1258
+ if cache_position is None:
1259
+ if isinstance(past_key_values, StaticCache):
1260
+ raise ValueError("cache_position is a required argument when using StaticCache.")
1261
+ cache_position = torch.arange(
1262
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1263
+ )
1264
+
1265
+ if position_ids is None:
1266
+ position_ids = cache_position.unsqueeze(0)
1267
+
1268
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
1269
+
1270
+ # embed positions
1271
+ hidden_states = inputs_embeds
1272
+
1273
+ # decoder layers
1274
+ all_hidden_states = () if output_hidden_states else None
1275
+ all_self_attns = () if output_attentions else None
1276
+ next_decoder_cache = None
1277
+
1278
+ for decoder_layer in self.layers:
1279
+ if output_hidden_states:
1280
+ all_hidden_states += (hidden_states,)
1281
+
1282
+ if self.gradient_checkpointing and self.training:
1283
+ layer_outputs = self._gradient_checkpointing_func(
1284
+ decoder_layer.__call__,
1285
+ hidden_states,
1286
+ causal_mask,
1287
+ position_ids,
1288
+ past_key_values,
1289
+ output_attentions,
1290
+ use_cache,
1291
+ cache_position,
1292
+ )
1293
+ else:
1294
+ layer_outputs = decoder_layer(
1295
+ hidden_states,
1296
+ attention_mask=causal_mask,
1297
+ position_ids=position_ids,
1298
+ past_key_value=past_key_values,
1299
+ output_attentions=output_attentions,
1300
+ use_cache=use_cache,
1301
+ cache_position=cache_position,
1302
+ )
1303
+
1304
+ hidden_states = layer_outputs[0]
1305
+
1306
+ if use_cache:
1307
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1308
+
1309
+ if output_attentions:
1310
+ all_self_attns += (layer_outputs[1],)
1311
+
1312
+ hidden_states = self.norm(hidden_states)
1313
+
1314
+ # add hidden states from the last decoder layer
1315
+ if output_hidden_states:
1316
+ all_hidden_states += (hidden_states,)
1317
+
1318
+ next_cache = None
1319
+ if use_cache:
1320
+ next_cache = (
1321
+ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1322
+ )
1323
+ if not return_dict:
1324
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1325
+ return BaseModelOutputWithPast(
1326
+ last_hidden_state=hidden_states,
1327
+ past_key_values=next_cache,
1328
+ hidden_states=all_hidden_states,
1329
+ attentions=all_self_attns,
1330
+ )
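# A small sketch (token counts assumed) of the cache bookkeeping in the forward above:
# with `past_seen_tokens` already cached, the new tokens are assigned cache positions
# continuing from that offset, and (absent padding) the same values serve as position_ids.
import torch

past_seen_tokens, new_tokens = 7, 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)
print(cache_position)      # tensor([7, 8, 9])
print(position_ids.shape)  # torch.Size([1, 3])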
1331
+
1332
+ def _update_causal_mask(
1333
+ self,
1334
+ attention_mask: torch.Tensor,
1335
+ input_tensor: torch.Tensor,
1336
+ cache_position: torch.Tensor,
1337
+ past_seen_tokens: int,
1338
+ ):
1339
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1340
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
1341
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1342
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1343
+
1344
+ if self.config._attn_implementation == "flash_attention_2":
1345
+ if attention_mask is not None and 0.0 in attention_mask:
1346
+ return attention_mask
1347
+ return None
1348
+
1349
+ if self.config._attn_implementation == "sdpa":
1350
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
1351
+ # in order to dispatch on Flash Attention 2.
1352
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1353
+ attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
1354
+ ):
1355
+ return None
1356
+
1357
+ dtype, device = input_tensor.dtype, input_tensor.device
1358
+ min_dtype = torch.finfo(dtype).min
1359
+ sequence_length = input_tensor.shape[1]
1360
+ if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
1361
+ target_length = self.config.max_position_embeddings
1362
+ else: # dynamic cache
1363
+ target_length = (
1364
+ attention_mask.shape[-1]
1365
+ if isinstance(attention_mask, torch.Tensor)
1366
+ else past_seen_tokens + sequence_length + 1
1367
+ )
1368
+
1369
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1370
+ if sequence_length != 1:
1371
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1372
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1373
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1374
+ if attention_mask is not None:
1375
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1376
+ if attention_mask.dim() == 2:
1377
+ mask_length = attention_mask.shape[-1]
1378
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1379
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1380
+ elif attention_mask.dim() == 4:
1381
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1382
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1383
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1384
+ offset = cache_position[0]
1385
+ else:
1386
+ offset = 0
1387
+ mask_shape = attention_mask.shape
1388
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1389
+ causal_mask[
1390
+ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
1391
+ ] = mask_slice
1392
+
1393
+ if (
1394
+ self.config._attn_implementation == "sdpa"
1395
+ and attention_mask is not None
1396
+ and attention_mask.device.type == "cuda"
1397
+ ):
1398
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1399
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1400
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1401
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1402
+
1403
+ return causal_mask
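# A condensed, standalone sketch (dtype, lengths, and the padding pattern are assumptions)
# of the additive mask built above: (query, key) pairs a query may not attend to are set to
# the dtype minimum, and a 2D padding mask additionally blocks padded key positions.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 4, 6
cache_position = torch.arange(2, 2 + sequence_length)   # the new tokens sit at positions 2..5

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :]              # broadcastable (batch, head, query, key)

attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])      # last key position is padding
causal_mask = causal_mask.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)
print((causal_mask == min_dtype).int()[0, 0])            # 1 marks a blocked (query, key) pair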
1404
+
1405
+
1406
+ class LlamaForCausalLM(LlamaPreTrainedModel):
1407
+ _tied_weights_keys = ["lm_head.weight"]
1408
+
1409
+ def __init__(self, config):
1410
+ super().__init__(config)
1411
+ self.model = LlamaModel(config)
1412
+ self.vocab_size = config.vocab_size
1413
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1414
+
1415
+ # Initialize weights and apply final processing
1416
+ self.post_init()
1417
+
1418
+ def get_input_embeddings(self):
1419
+ return self.model.embed_tokens
1420
+
1421
+ def set_input_embeddings(self, value):
1422
+ self.model.embed_tokens = value
1423
+
1424
+ def get_output_embeddings(self):
1425
+ return self.lm_head
1426
+
1427
+ def set_output_embeddings(self, new_embeddings):
1428
+ self.lm_head = new_embeddings
1429
+
1430
+ def set_decoder(self, decoder):
1431
+ self.model = decoder
1432
+
1433
+ def get_decoder(self):
1434
+ return self.model
1435
+
1436
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1437
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1438
+ def forward(
1439
+ self,
1440
+ input_ids: torch.LongTensor = None,
1441
+ attention_mask: Optional[torch.Tensor] = None,
1442
+ position_ids: Optional[torch.LongTensor] = None,
1443
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1444
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1445
+ labels: Optional[torch.LongTensor] = None,
1446
+ use_cache: Optional[bool] = None,
1447
+ output_attentions: Optional[bool] = None,
1448
+ output_hidden_states: Optional[bool] = None,
1449
+ return_dict: Optional[bool] = None,
1450
+ cache_position: Optional[torch.LongTensor] = None,
1451
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1452
+ r"""
1453
+ Args:
1454
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1455
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1456
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1457
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1458
+
1459
+ Returns:
1460
+
1461
+ Example:
1462
+
1463
+ ```python
1464
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1465
+
1466
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1467
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1468
+
1469
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1470
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1471
+
1472
+ >>> # Generate
1473
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1474
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1475
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1476
+ ```"""
1477
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1478
+ output_hidden_states = (
1479
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1480
+ )
1481
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1482
+
1483
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1484
+ outputs = self.model(
1485
+ input_ids=input_ids,
1486
+ attention_mask=attention_mask,
1487
+ position_ids=position_ids,
1488
+ past_key_values=past_key_values,
1489
+ inputs_embeds=inputs_embeds,
1490
+ use_cache=use_cache,
1491
+ output_attentions=output_attentions,
1492
+ output_hidden_states=output_hidden_states,
1493
+ return_dict=return_dict,
1494
+ cache_position=cache_position,
1495
+ )
1496
+
1497
+ hidden_states = outputs[0]
1498
+ if self.config.pretraining_tp > 1:
1499
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1500
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1501
+ logits = torch.cat(logits, dim=-1)
1502
+ else:
1503
+ logits = self.lm_head(hidden_states)
1504
+ logits = logits.float()
1505
+
1506
+ loss = None
1507
+ if labels is not None:
1508
+ # Shift so that tokens < n predict n
1509
+ shift_logits = logits[..., :-1, :].contiguous()
1510
+ shift_labels = labels[..., 1:].contiguous()
1511
+ # Flatten the tokens
1512
+ loss_fct = CrossEntropyLoss()
1513
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1514
+ shift_labels = shift_labels.view(-1)
1515
+ # Enable model parallelism
1516
+ shift_labels = shift_labels.to(shift_logits.device)
1517
+ loss = loss_fct(shift_logits, shift_labels)
1518
+
1519
+ if not return_dict:
1520
+ output = (logits,) + outputs[1:]
1521
+ return (loss,) + output if loss is not None else output
1522
+
1523
+ return CausalLMOutputWithPast(
1524
+ loss=loss,
1525
+ logits=logits,
1526
+ past_key_values=outputs.past_key_values,
1527
+ hidden_states=outputs.hidden_states,
1528
+ attentions=outputs.attentions,
1529
+ )
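# A small worked example (toy vocabulary and lengths assumed) of the loss computed above:
# the logit at position t is scored against the label at position t + 1, so the last logit
# and the first label are dropped before the flattened cross-entropy; `-100` labels are ignored.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)           # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))    # typically a copy of input_ids
labels[0, 1] = -100                              # masked-out position (e.g. a prompt token)

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)    # ignore_index defaults to -100
print(loss)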
1530
+
1531
+ def prepare_inputs_for_generation(
1532
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
1533
+ ):
1534
+ # With static cache, the `past_key_values` is None
1535
+ # TODO joao: standardize interface for the different Cache classes and remove this if
1536
+ has_static_cache = False
1537
+ if past_key_values is None:
1538
+ past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
1539
+ has_static_cache = past_key_values is not None
1540
+
1541
+ past_length = 0
1542
+ if past_key_values is not None:
1543
+ if isinstance(past_key_values, Cache):
1544
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1545
+ max_cache_length = (
1546
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1547
+ if past_key_values.get_max_length() is not None
1548
+ else None
1549
+ )
1550
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1551
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1552
+ else:
1553
+ cache_length = past_length = past_key_values[0][0].shape[2]
1554
+ max_cache_length = None
1555
+
1556
+ # Keep only the unprocessed tokens:
1557
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1558
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1559
+ # input)
1560
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1561
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1562
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1563
+ # input_ids based on the past_length.
1564
+ elif past_length < input_ids.shape[1]:
1565
+ input_ids = input_ids[:, past_length:]
1566
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1567
+ else:
1568
+ remove_prefix_length = input_ids.shape[1] - 1
1569
+ input_ids = input_ids[:, remove_prefix_length:]
1570
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1571
+ if (
1572
+ max_cache_length is not None
1573
+ and attention_mask is not None
1574
+ and cache_length + input_ids.shape[1] > max_cache_length
1575
+ ):
1576
+ attention_mask = attention_mask[:, -max_cache_length:]
1577
+
1578
+ position_ids = kwargs.get("position_ids", None)
1579
+ if attention_mask is not None and position_ids is None:
1580
+ # create position_ids on the fly for batch generation
1581
+ position_ids = attention_mask.long().cumsum(-1) - 1
1582
+ position_ids.masked_fill_(attention_mask == 0, 1)
1583
+ if past_key_values:
1584
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1585
+
1586
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1587
+ if inputs_embeds is not None and past_key_values is None:
1588
+ model_inputs = {"inputs_embeds": inputs_embeds}
1589
+ else:
1590
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1591
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1592
+ # TODO: use `next_tokens` directly instead.
1593
+ model_inputs = {"input_ids": input_ids.contiguous()}
1594
+
1595
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1596
+ if cache_position is None:
1597
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1598
+ else:
1599
+ cache_position = cache_position[-input_length:]
1600
+
1601
+ if has_static_cache:
1602
+ past_key_values = None
1603
+
1604
+ model_inputs.update(
1605
+ {
1606
+ "position_ids": position_ids,
1607
+ "cache_position": cache_position,
1608
+ "past_key_values": past_key_values,
1609
+ "use_cache": kwargs.get("use_cache"),
1610
+ "attention_mask": attention_mask,
1611
+ }
1612
+ )
1613
+ return model_inputs
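# A quick sketch of the position-id trick used above, with an assumed left-padded batch:
# a cumulative sum over the attention mask gives 0-based positions for real tokens, and
# padded slots receive a dummy value of 1.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])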
1614
+
1615
+ @staticmethod
1616
+ def _reorder_cache(past_key_values, beam_idx):
1617
+ reordered_past = ()
1618
+ for layer_past in past_key_values:
1619
+ reordered_past += (
1620
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1621
+ )
1622
+ return reordered_past
1623
+
1624
+
1625
+ @add_start_docstrings(
1626
+ """
1627
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1628
+
1629
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1630
+ (e.g. GPT-2) do.
1631
+
1632
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1633
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1634
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1635
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1636
+ each row of the batch).
1637
+ """,
1638
+ LLAMA_START_DOCSTRING,
1639
+ )
1640
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1641
+ def __init__(self, config):
1642
+ super().__init__(config)
1643
+ self.num_labels = config.num_labels
1644
+ self.model = LlamaModel(config)
1645
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1646
+
1647
+ # Initialize weights and apply final processing
1648
+ self.post_init()
1649
+
1650
+ def get_input_embeddings(self):
1651
+ return self.model.embed_tokens
1652
+
1653
+ def set_input_embeddings(self, value):
1654
+ self.model.embed_tokens = value
1655
+
1656
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1657
+ def forward(
1658
+ self,
1659
+ input_ids: torch.LongTensor = None,
1660
+ attention_mask: Optional[torch.Tensor] = None,
1661
+ position_ids: Optional[torch.LongTensor] = None,
1662
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1663
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1664
+ labels: Optional[torch.LongTensor] = None,
1665
+ use_cache: Optional[bool] = None,
1666
+ output_attentions: Optional[bool] = None,
1667
+ output_hidden_states: Optional[bool] = None,
1668
+ return_dict: Optional[bool] = None,
1669
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1670
+ r"""
1671
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1672
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1673
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1674
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1675
+ """
1676
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1677
+
1678
+ transformer_outputs = self.model(
1679
+ input_ids,
1680
+ attention_mask=attention_mask,
1681
+ position_ids=position_ids,
1682
+ past_key_values=past_key_values,
1683
+ inputs_embeds=inputs_embeds,
1684
+ use_cache=use_cache,
1685
+ output_attentions=output_attentions,
1686
+ output_hidden_states=output_hidden_states,
1687
+ return_dict=return_dict,
1688
+ )
1689
+ hidden_states = transformer_outputs[0]
1690
+ logits = self.score(hidden_states)
1691
+
1692
+ if input_ids is not None:
1693
+ batch_size = input_ids.shape[0]
1694
+ else:
1695
+ batch_size = inputs_embeds.shape[0]
1696
+
1697
+ if self.config.pad_token_id is None and batch_size != 1:
1698
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1699
+ if self.config.pad_token_id is None:
1700
+ sequence_lengths = -1
1701
+ else:
1702
+ if input_ids is not None:
1703
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1704
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1705
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1706
+ sequence_lengths = sequence_lengths.to(logits.device)
1707
+ else:
1708
+ sequence_lengths = -1
1709
+
1710
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1711
+
1712
+ loss = None
1713
+ if labels is not None:
1714
+ labels = labels.to(logits.device)
1715
+ if self.config.problem_type is None:
1716
+ if self.num_labels == 1:
1717
+ self.config.problem_type = "regression"
1718
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1719
+ self.config.problem_type = "single_label_classification"
1720
+ else:
1721
+ self.config.problem_type = "multi_label_classification"
1722
+
1723
+ if self.config.problem_type == "regression":
1724
+ loss_fct = MSELoss()
1725
+ if self.num_labels == 1:
1726
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1727
+ else:
1728
+ loss = loss_fct(pooled_logits, labels)
1729
+ elif self.config.problem_type == "single_label_classification":
1730
+ loss_fct = CrossEntropyLoss()
1731
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1732
+ elif self.config.problem_type == "multi_label_classification":
1733
+ loss_fct = BCEWithLogitsLoss()
1734
+ loss = loss_fct(pooled_logits, labels)
1735
+ if not return_dict:
1736
+ output = (pooled_logits,) + transformer_outputs[1:]
1737
+ return ((loss,) + output) if loss is not None else output
1738
+
1739
+ return SequenceClassifierOutputWithPast(
1740
+ loss=loss,
1741
+ logits=pooled_logits,
1742
+ past_key_values=transformer_outputs.past_key_values,
1743
+ hidden_states=transformer_outputs.hidden_states,
1744
+ attentions=transformer_outputs.attentions,
1745
+ )
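# A standalone sketch (toy ids assumed) of the pooling above: the index of the first pad
# token minus one gives the last real token of each row, and the modulo keeps the index
# valid for rows that contain no padding at all.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],      # right-padded row
                          [8, 9, 1, 2, 3]])     # full row, no padding
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])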
1746
+
1747
+
1748
+ @add_start_docstrings(
1749
+ """
1750
+ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1751
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1752
+ """,
1753
+ LLAMA_START_DOCSTRING,
1754
+ )
1755
+ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1756
+ base_model_prefix = "transformer"
1757
+
1758
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1759
+ def __init__(self, config):
1760
+ super().__init__(config)
1761
+ self.transformer = LlamaModel(config)
1762
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1763
+
1764
+ # Initialize weights and apply final processing
1765
+ self.post_init()
1766
+
1767
+ def get_input_embeddings(self):
1768
+ return self.transformer.embed_tokens
1769
+
1770
+ def set_input_embeddings(self, value):
1771
+ self.transformer.embed_tokens = value
1772
+
1773
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1774
+ def forward(
1775
+ self,
1776
+ input_ids: Optional[torch.LongTensor] = None,
1777
+ attention_mask: Optional[torch.FloatTensor] = None,
1778
+ position_ids: Optional[torch.LongTensor] = None,
1779
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1780
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1781
+ start_positions: Optional[torch.LongTensor] = None,
1782
+ end_positions: Optional[torch.LongTensor] = None,
1783
+ output_attentions: Optional[bool] = None,
1784
+ output_hidden_states: Optional[bool] = None,
1785
+ return_dict: Optional[bool] = None,
1786
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1787
+ r"""
1788
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1789
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1790
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1791
+ are not taken into account for computing the loss.
1792
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1793
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1794
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1795
+ are not taken into account for computing the loss.
1796
+ """
1797
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1798
+
1799
+ outputs = self.transformer(
1800
+ input_ids,
1801
+ attention_mask=attention_mask,
1802
+ position_ids=position_ids,
1803
+ past_key_values=past_key_values,
1804
+ inputs_embeds=inputs_embeds,
1805
+ output_attentions=output_attentions,
1806
+ output_hidden_states=output_hidden_states,
1807
+ return_dict=return_dict,
1808
+ )
1809
+
1810
+ sequence_output = outputs[0]
1811
+
1812
+ logits = self.qa_outputs(sequence_output)
1813
+ start_logits, end_logits = logits.split(1, dim=-1)
1814
+ start_logits = start_logits.squeeze(-1).contiguous()
1815
+ end_logits = end_logits.squeeze(-1).contiguous()
1816
+
1817
+ total_loss = None
1818
+ if start_positions is not None and end_positions is not None:
1819
+ # If we are on multi-GPU, splitting may add an extra dimension; squeeze it
1820
+ if len(start_positions.size()) > 1:
1821
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1822
+ if len(end_positions.size()) > 1:
1823
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1824
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1825
+ ignored_index = start_logits.size(1)
1826
+ start_positions = start_positions.clamp(0, ignored_index)
1827
+ end_positions = end_positions.clamp(0, ignored_index)
1828
+
1829
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1830
+ start_loss = loss_fct(start_logits, start_positions)
1831
+ end_loss = loss_fct(end_logits, end_positions)
1832
+ total_loss = (start_loss + end_loss) / 2
1833
+
1834
+ if not return_dict:
1835
+ output = (start_logits, end_logits) + outputs[2:]
1836
+ return ((total_loss,) + output) if total_loss is not None else output
1837
+
1838
+ return QuestionAnsweringModelOutput(
1839
+ loss=total_loss,
1840
+ start_logits=start_logits,
1841
+ end_logits=end_logits,
1842
+ hidden_states=outputs.hidden_states,
1843
+ attentions=outputs.attentions,
1844
+ )
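# A tiny sketch (hidden size and lengths are assumptions) of the span head above: one Linear
# layer with two outputs is split into per-token start and end logits, each trained with its
# own cross-entropy against the clamped start/end positions.
import torch
import torch.nn as nn

hidden_states = torch.randn(2, 7, 32)            # (batch, seq_len, hidden_size)
qa_outputs = nn.Linear(32, 2)
start_logits, end_logits = qa_outputs(hidden_states).split(1, dim=-1)
start_logits = start_logits.squeeze(-1)          # (batch, seq_len)
end_logits = end_logits.squeeze(-1)
print(start_logits.shape, end_logits.shape)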
bunny/model/language_model/llama/tokenization_llama.py ADDED
@@ -0,0 +1,471 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ """Tokenization classes for LLaMA."""
22
+ import os
23
+ from shutil import copyfile
24
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25
+
26
+ import sentencepiece as spm
27
+
28
+ from transformers.convert_slow_tokenizer import import_protobuf
29
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30
+ from transformers.utils import logging
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers.tokenization_utils_base import TextInput
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
39
+
40
+ SPIECE_UNDERLINE = "▁"
41
+
42
+ B_INST, E_INST = "[INST]", "[/INST]"
43
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
44
+
45
+ # fmt: off
46
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
47
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
48
+ that your responses are socially unbiased and positive in nature.
49
+
50
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
51
+ correct. If you don't know the answer to a question, please don't share false information."""
52
+ # fmt: on
53
+
54
+
55
+ class LlamaTokenizer(PreTrainedTokenizer):
56
+ """
57
+ Construct a Llama tokenizer, based on a SentencePiece BPE model. The default padding token is unset as there is
58
+ no padding token in the original model.
59
+
60
+ Args:
61
+ vocab_file (`str`):
62
+ Path to the vocabulary file.
63
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
64
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
65
+ token instead.
66
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
67
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
68
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
69
+ The end of sequence token.
70
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
71
+ A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
72
+ attention mechanisms or loss computation.
73
+ sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
74
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
75
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
76
+ to set:
77
+
78
+ - `enable_sampling`: Enable subword regularization.
79
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
80
+
81
+ - `nbest_size = {0,1}`: No sampling is performed.
82
+ - `nbest_size > 1`: samples from the nbest_size results.
83
+ - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
84
+ using forward-filtering-and-backward-sampling algorithm.
85
+
86
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
87
+ BPE-dropout.
88
+
89
+ add_bos_token (`bool`, *optional*, defaults to `True`):
90
+ Whether or not to add a `bos_token` at the start of sequences.
91
+ add_eos_token (`bool`, *optional*, defaults to `False`):
92
+ Whether or not to add an `eos_token` at the end of sequences.
93
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
94
+ Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
95
+ extra spaces.
96
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
97
+ Whether or not the default system prompt for Llama should be used.
98
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
99
+ Whether or not to add spaces between special tokens.
100
+ legacy (`bool`, *optional*):
101
+ Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
102
+ and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
103
+ example:
104
+
105
+ - `legacy=True`:
106
+ ```python
107
+ >>> from transformers import T5Tokenizer
108
+
109
+ >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
110
+ >>> tokenizer.encode("Hello <extra_id_0>.")
111
+ [8774, 32099, 3, 5, 1]
112
+ ```
113
+ - `legacy=False`:
114
+ ```python
115
+ >>> from transformers import T5Tokenizer
116
+
117
+ >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
118
+ >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
119
+ [8774, 32099, 5, 1]
120
+ ```
121
+ Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
122
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
123
+ Whether or not to add an initial space to the input. This allows treating the leading word just as any
124
+ other word.
125
+
126
+ """
127
+
128
+ vocab_files_names = VOCAB_FILES_NAMES
129
+ model_input_names = ["input_ids", "attention_mask"]
130
+
131
+ def __init__(
132
+ self,
133
+ vocab_file,
134
+ unk_token="<unk>",
135
+ bos_token="<s>",
136
+ eos_token="</s>",
137
+ pad_token=None,
138
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
139
+ add_bos_token=True,
140
+ add_eos_token=False,
141
+ clean_up_tokenization_spaces=False,
142
+ use_default_system_prompt=False,
143
+ spaces_between_special_tokens=False,
144
+ legacy=None,
145
+ add_prefix_space=True,
146
+ **kwargs,
147
+ ):
148
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
149
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
150
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
151
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
152
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
153
+
154
+ if legacy is None:
155
+ logger.warning_once(
156
+ f"You are using the default legacy behaviour of the {self.__class__}. This is"
157
+ " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
158
+ " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
159
+ " means, and thoroughly read the reason why this was added as explained in"
160
+ " https://github.com/huggingface/transformers/pull/24565"
161
+ )
162
+ legacy = True
163
+
164
+ self.legacy = legacy
165
+ self.vocab_file = vocab_file
166
+ self.add_bos_token = add_bos_token
167
+ self.add_eos_token = add_eos_token
168
+ self.use_default_system_prompt = use_default_system_prompt
169
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
170
+ self.add_prefix_space = add_prefix_space
171
+
172
+ super().__init__(
173
+ bos_token=bos_token,
174
+ eos_token=eos_token,
175
+ unk_token=unk_token,
176
+ pad_token=pad_token,
177
+ add_bos_token=add_bos_token,
178
+ add_eos_token=add_eos_token,
179
+ sp_model_kwargs=self.sp_model_kwargs,
180
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
181
+ use_default_system_prompt=use_default_system_prompt,
182
+ spaces_between_special_tokens=spaces_between_special_tokens,
183
+ legacy=legacy,
184
+ add_prefix_space=add_prefix_space,
185
+ **kwargs,
186
+ )
187
+
188
+ @property
189
+ def unk_token_length(self):
190
+ return len(self.sp_model.encode(str(self.unk_token)))
191
+
192
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
193
+ def get_spm_processor(self, from_slow=False):
194
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
195
+ if self.legacy or from_slow: # no dependency on protobuf
196
+ tokenizer.Load(self.vocab_file)
197
+ return tokenizer
198
+
199
+ with open(self.vocab_file, "rb") as f:
200
+ sp_model = f.read()
201
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
202
+ model = model_pb2.ModelProto.FromString(sp_model)
203
+ normalizer_spec = model_pb2.NormalizerSpec()
204
+ normalizer_spec.add_dummy_prefix = False
205
+ model.normalizer_spec.MergeFrom(normalizer_spec)
206
+ sp_model = model.SerializeToString()
207
+ tokenizer.LoadFromSerializedProto(sp_model)
208
+ return tokenizer
209
+
210
+ def __getstate__(self):
211
+ state = self.__dict__.copy()
212
+ state["sp_model"] = None
213
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
214
+ return state
215
+
216
+ def __setstate__(self, d):
217
+ self.__dict__ = d
218
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
219
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
220
+
221
+ @property
222
+ def vocab_size(self):
223
+ """Returns vocab size"""
224
+ return self.sp_model.get_piece_size()
225
+
226
+ def get_vocab(self):
227
+ """Returns vocab as a dict"""
228
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
229
+ vocab.update(self.added_tokens_encoder)
230
+ return vocab
231
+
232
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
233
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
234
+ """
235
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
236
+ first token is special.
237
+ """
238
+ if self.legacy or len(text) == 0:
239
+ return super().tokenize(text, **kwargs)
240
+
241
+ text = text.replace(SPIECE_UNDERLINE, " ")
242
+ if self.add_prefix_space:
243
+ text = SPIECE_UNDERLINE + text
244
+
245
+ tokens = super().tokenize(text, **kwargs)
246
+
247
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
248
+ tokens = tokens[1:]
249
+ return tokens
250
+
251
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
252
+ def _tokenize(self, text, **kwargs):
253
+ """
254
+ Returns a tokenized string.
255
+
256
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
257
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
258
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
259
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
260
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
261
+ """
262
+ tokens = self.sp_model.encode(text, out_type=str)
263
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
264
+ return tokens
265
+
266
+ # 1. Encode string + prefix ex: "<unk> Hey"
267
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
268
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
269
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
270
+
271
+ def _convert_token_to_id(self, token):
272
+ """Converts a token (str) in an id using the vocab."""
273
+ return self.sp_model.piece_to_id(token)
274
+
275
+ def _convert_id_to_token(self, index):
276
+ """Converts an index (integer) in a token (str) using the vocab."""
277
+ token = self.sp_model.IdToPiece(index)
278
+ return token
279
+
280
+ def convert_tokens_to_string(self, tokens):
281
+ """Converts a sequence of tokens (string) in a single string."""
282
+ # since we manually add the prefix space, we have to remove it when decoding
283
+ if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
284
+ tokens[0] = tokens[0][1:]
285
+
286
+ current_sub_tokens = []
287
+ out_string = ""
288
+ prev_is_special = False
289
+ for i, token in enumerate(tokens):
290
+ # make sure that special tokens are not decoded using sentencepiece model
291
+ if token in self.all_special_tokens:
292
+ if not prev_is_special and i != 0 and self.legacy:
293
+ out_string += " "
294
+ out_string += self.sp_model.decode(current_sub_tokens) + token
295
+ prev_is_special = True
296
+ current_sub_tokens = []
297
+ else:
298
+ if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
299
+ out_string += " "
300
+ current_sub_tokens.append(token)
301
+ prev_is_special = False
302
+ out_string += self.sp_model.decode(current_sub_tokens)
303
+ return out_string
304
+
305
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
306
+ """
307
+ Save the vocabulary and special tokens file to a directory.
308
+
309
+ Args:
310
+ save_directory (`str`):
311
+ The directory in which to save the vocabulary.
312
+
313
+ Returns:
314
+ `Tuple(str)`: Paths to the files saved.
315
+ """
316
+ if not os.path.isdir(save_directory):
317
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
318
+ return
319
+ out_vocab_file = os.path.join(
320
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
321
+ )
322
+
323
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
324
+ copyfile(self.vocab_file, out_vocab_file)
325
+ elif not os.path.isfile(self.vocab_file):
326
+ with open(out_vocab_file, "wb") as fi:
327
+ content_spiece_model = self.sp_model.serialized_model_proto()
328
+ fi.write(content_spiece_model)
329
+
330
+ return (out_vocab_file,)
331
+
332
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
333
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
334
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
335
+
336
+ output = bos_token_id + token_ids_0 + eos_token_id
337
+
338
+ if token_ids_1 is not None:
339
+ output = output + bos_token_id + token_ids_1 + eos_token_id
340
+
341
+ return output
342
+
343
+ def get_special_tokens_mask(
344
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
345
+ ) -> List[int]:
346
+ """
347
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
348
+ special tokens using the tokenizer `prepare_for_model` method.
349
+
350
+ Args:
351
+ token_ids_0 (`List[int]`):
352
+ List of IDs.
353
+ token_ids_1 (`List[int]`, *optional*):
354
+ Optional second list of IDs for sequence pairs.
355
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
356
+ Whether or not the token list is already formatted with special tokens for the model.
357
+
358
+ Returns:
359
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
360
+ """
361
+ if already_has_special_tokens:
362
+ return super().get_special_tokens_mask(
363
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
364
+ )
365
+
366
+ bos_token_id = [1] if self.add_bos_token else []
367
+ eos_token_id = [1] if self.add_eos_token else []
368
+
369
+ if token_ids_1 is None:
370
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
371
+ return (
372
+ bos_token_id
373
+ + ([0] * len(token_ids_0))
374
+ + eos_token_id
375
+ + bos_token_id
376
+ + ([0] * len(token_ids_1))
377
+ + eos_token_id
378
+ )
379
+
380
+ def create_token_type_ids_from_sequences(
381
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
382
+ ) -> List[int]:
383
+ """
384
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
385
+ sequence pair mask has the following format:
386
+
387
+ ```
388
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
389
+ | first sequence | second sequence |
390
+ ```
391
+
392
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
393
+
394
+ Args:
395
+ token_ids_0 (`List[int]`):
396
+ List of ids.
397
+ token_ids_1 (`List[int]`, *optional*):
398
+ Optional second list of IDs for sequence pairs.
399
+
400
+ Returns:
401
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
402
+ """
403
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
404
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
405
+
406
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
407
+
408
+ if token_ids_1 is not None:
409
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
410
+
411
+ return output
412
+
413
+ @property
414
+ def default_chat_template(self):
415
+ """
416
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
417
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
418
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
419
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
420
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
421
+ to fine-tune a model with more flexible role ordering!
422
+
423
+ The output should look something like:
424
+
425
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
426
+ <bos>[INST] Prompt [/INST]
427
+
428
+ The reference for this chat template is [this code
429
+ snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
430
+ in the original repository.
431
+ """
432
+ logger.warning_once(
433
+ "\nNo chat template is defined for this tokenizer - using the default template "
434
+ f"for the {self.__class__.__name__} class. If the default is not appropriate for "
435
+ "your model, please set `tokenizer.chat_template` to an appropriate template. "
436
+ "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
437
+ )
438
+ template = (
439
+ "{% if messages[0]['role'] == 'system' %}"
440
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
441
+ "{% set system_message = messages[0]['content'] %}"
442
+ "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
443
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
444
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
445
+ "{% else %}"
446
+ "{% set loop_messages = messages %}"
447
+ "{% set system_message = false %}"
448
+ "{% endif %}"
449
+ "{% for message in loop_messages %}" # Loop over all non-system messages
450
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
451
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
452
+ "{% endif %}"
453
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
454
+ "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
455
+ "{% else %}"
456
+ "{% set content = message['content'] %}"
457
+ "{% endif %}"
458
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
459
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
460
+ "{% elif message['role'] == 'system' %}"
461
+ "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
462
+ "{% elif message['role'] == 'assistant' %}"
463
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
464
+ "{% endif %}"
465
+ "{% endfor %}"
466
+ )
467
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
468
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
469
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
470
+
471
+ return template
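
As a concrete illustration of the template above, the following is a minimal sketch of how a short conversation is rendered. It assumes a transformers version that still falls back to `default_chat_template` when no `chat_template` is set; the `hf-internal-testing/llama-tokenizer` checkpoint is used purely as an example (any LLaMA-style tokenizer would do).

```python
from transformers import AutoTokenizer

# Example checkpoint only; downloading it requires network access on first use.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
    {"role": "user", "content": "How are you?"},
]

# With no chat_template set, apply_chat_template falls back to default_chat_template,
# so the system prompt is folded into the first [INST] block.
print(tokenizer.apply_chat_template(messages, tokenize=False))
# Roughly: <s>[INST] <<SYS>>\nYou are a terse assistant.\n<</SYS>>\n\nHello! [/INST] Hi there. </s><s>[INST] How are you? [/INST]
```
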
bunny/model/language_model/llama/tokenization_llama_fast.py ADDED
@@ -0,0 +1,281 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from shutil import copyfile
17
+ from typing import Optional, Tuple
18
+
19
+ from tokenizers import processors
20
+
21
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
22
+ from transformers.utils import is_sentencepiece_available, logging
23
+ from transformers.utils.versions import require_version
24
+
25
+
26
+ require_version("tokenizers>=0.13.3")
27
+
28
+ if is_sentencepiece_available():
29
+ from .tokenization_llama import LlamaTokenizer
30
+ else:
31
+ LlamaTokenizer = None
32
+
33
+ logger = logging.get_logger(__name__)
34
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
35
+
36
+ B_INST, E_INST = "[INST]", "[/INST]"
37
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
38
+
39
+ # fmt: off
40
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
41
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
42
+ that your responses are socially unbiased and positive in nature.
43
+
44
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
45
+ correct. If you don't know the answer to a question, please don't share false information."""
46
+ # fmt: on
47
+
48
+
49
+ class LlamaTokenizerFast(PreTrainedTokenizerFast):
50
+ """
51
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
52
+
53
+ Notably, this uses ByteFallback and no normalization.
54
+
55
+ ```python
56
+ >>> from transformers import LlamaTokenizerFast
57
+
58
+ >>> tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
59
+ >>> tokenizer.encode("Hello this is a test")
60
+ [1, 15043, 445, 338, 263, 1243]
61
+ ```
62
+
63
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
64
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
65
+ values of the first token and final token of an encoded sequence will not be correct). For more details, check out the
66
+ [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
67
+
68
+
69
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
70
+ refer to this superclass for more information regarding those methods.
71
+
72
+ Args:
73
+ vocab_file (`str`, *optional*):
74
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
75
+ contains the vocabulary necessary to instantiate a tokenizer.
76
+ tokenizer_file (`str`, *optional*):
77
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
78
+ contains everything needed to load the tokenizer.
79
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
80
+ Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
81
+ extra spaces.
82
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
83
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
84
+ token instead.
85
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
86
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
87
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
88
+ The end of sequence token.
89
+ add_bos_token (`bool`, *optional*, defaults to `True`):
90
+ Whether or not to add a `bos_token` at the start of sequences.
91
+ add_eos_token (`bool`, *optional*, defaults to `False`):
92
+ Whether or not to add an `eos_token` at the end of sequences.
93
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
94
+ Whether or not the default system prompt for Llama should be used.
95
+ add_prefix_space (`bool`, *optional*):
96
+ Whether or not the tokenizer should automatically add a prefix space.
97
+ """
98
+
99
+ vocab_files_names = VOCAB_FILES_NAMES
100
+ slow_tokenizer_class = LlamaTokenizer
101
+ padding_side = "left"
102
+ model_input_names = ["input_ids", "attention_mask"]
103
+
104
+ def __init__(
105
+ self,
106
+ vocab_file=None,
107
+ tokenizer_file=None,
108
+ clean_up_tokenization_spaces=False,
109
+ unk_token="<unk>",
110
+ bos_token="<s>",
111
+ eos_token="</s>",
112
+ add_bos_token=True,
113
+ add_eos_token=False,
114
+ use_default_system_prompt=False,
115
+ add_prefix_space=None,
116
+ **kwargs,
117
+ ):
118
+ if add_prefix_space is not None:
119
+ logger.warning_once(
120
+ "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizer"
121
+ )
122
+ kwargs["from_slow"] = True
123
+
124
+ super().__init__(
125
+ vocab_file=vocab_file,
126
+ tokenizer_file=tokenizer_file,
127
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
128
+ unk_token=unk_token,
129
+ bos_token=bos_token,
130
+ eos_token=eos_token,
131
+ add_bos_token=add_bos_token,
132
+ add_eos_token=add_eos_token,
133
+ use_default_system_prompt=use_default_system_prompt,
134
+ **kwargs,
135
+ )
136
+ self._add_bos_token = add_bos_token
137
+ self._add_eos_token = add_eos_token
138
+ self.update_post_processor()
139
+ self.use_default_system_prompt = use_default_system_prompt
140
+ self.vocab_file = vocab_file
141
+
142
+ @property
143
+ def can_save_slow_tokenizer(self) -> bool:
144
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
145
+
146
+ def update_post_processor(self):
147
+ """
148
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
149
+ """
150
+ bos = self.bos_token
151
+ bos_token_id = self.bos_token_id
152
+ if bos is None and self.add_bos_token:
153
+ raise ValueError("add_bos_token = True but bos_token = None")
154
+
155
+ eos = self.eos_token
156
+ eos_token_id = self.eos_token_id
157
+ if eos is None and self.add_eos_token:
158
+ raise ValueError("add_eos_token = True but eos_token = None")
159
+
160
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
161
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
162
+
163
+ special_tokens = []
164
+ if self.add_bos_token:
165
+ special_tokens.append((bos, bos_token_id))
166
+ if self.add_eos_token:
167
+ special_tokens.append((eos, eos_token_id))
168
+ self._tokenizer.post_processor = processors.TemplateProcessing(
169
+ single=single, pair=pair, special_tokens=special_tokens
170
+ )
171
+
172
+ @property
173
+ def add_eos_token(self):
174
+ return self._add_eos_token
175
+
176
+ @property
177
+ def add_bos_token(self):
178
+ return self._add_bos_token
179
+
180
+ @add_eos_token.setter
181
+ def add_eos_token(self, value):
182
+ self._add_eos_token = value
183
+ self.update_post_processor()
184
+
185
+ @add_bos_token.setter
186
+ def add_bos_token(self, value):
187
+ self._add_bos_token = value
188
+ self.update_post_processor()
189
+
190
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
191
+ if not self.can_save_slow_tokenizer:
192
+ raise ValueError(
193
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
194
+ "tokenizer."
195
+ )
196
+
197
+ if not os.path.isdir(save_directory):
198
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
199
+ return
200
+ out_vocab_file = os.path.join(
201
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
202
+ )
203
+
204
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
205
+ copyfile(self.vocab_file, out_vocab_file)
206
+
207
+ return (out_vocab_file,)
208
+
209
+ @property
210
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
211
+ def default_chat_template(self):
212
+ """
213
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
214
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
215
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
216
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
217
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
218
+ to fine-tune a model with more flexible role ordering!
219
+
220
+ The output should look something like:
221
+
222
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
223
+ <bos>[INST] Prompt [/INST]
224
+
225
+ The reference for this chat template is [this code
226
+ snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
227
+ in the original repository.
228
+ """
229
+ logger.warning_once(
230
+ "\nNo chat template is defined for this tokenizer - using the default template "
231
+ f"for the {self.__class__.__name__} class. If the default is not appropriate for "
232
+ "your model, please set `tokenizer.chat_template` to an appropriate template. "
233
+ "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
234
+ )
235
+ template = (
236
+ "{% if messages[0]['role'] == 'system' %}"
237
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
238
+ "{% set system_message = messages[0]['content'] %}"
239
+ "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
240
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
241
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
242
+ "{% else %}"
243
+ "{% set loop_messages = messages %}"
244
+ "{% set system_message = false %}"
245
+ "{% endif %}"
246
+ "{% for message in loop_messages %}" # Loop over all non-system messages
247
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
248
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
249
+ "{% endif %}"
250
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
251
+ "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
252
+ "{% else %}"
253
+ "{% set content = message['content'] %}"
254
+ "{% endif %}"
255
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
256
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
257
+ "{% elif message['role'] == 'system' %}"
258
+ "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
259
+ "{% elif message['role'] == 'assistant' %}"
260
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
261
+ "{% endif %}"
262
+ "{% endfor %}"
263
+ )
264
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
265
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
266
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
267
+
268
+ return template
269
+
270
+ # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
271
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
272
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
273
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
274
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
275
+
276
+ output = bos_token_id + token_ids_0 + eos_token_id
277
+
278
+ if token_ids_1 is not None:
279
+ output = output + bos_token_id + token_ids_1 + eos_token_id
280
+
281
+ return output
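
Because the fast tokenizer bakes special-token handling into its `tokenizers` post-processor, flipping `add_eos_token` through the property setter above (which re-runs `update_post_processor()`) is what actually changes subsequent encodes. A minimal sketch, using the example checkpoint from the docstring; the exact token ids depend on the vocabulary:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=True)

print(tok.encode("Hello this is a test"))
# [1, 15043, 445, 338, 263, 1243]      -> <s> (id 1) prepended, no </s>

tok.add_eos_token = True               # setter triggers update_post_processor()
print(tok.encode("Hello this is a test"))
# [1, 15043, 445, 338, 263, 1243, 2]   -> </s> (id 2) is now appended as well
```
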
bunny/model/language_model/minicpm/configuration_minicpm.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ MiniCPM model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29
+
30
+
31
+ class MiniCPMConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
34
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
+ defaults will yield a similar configuration to that of the MiniCPM-7B.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`MiniCPMModel`]
45
+ hidden_size (`int`, *optional*, defaults to 4096):
46
+ Dimension of the hidden representations.
47
+ intermediate_size (`int`, *optional*, defaults to 11008):
48
+ Dimension of the MLP representations.
49
+ num_hidden_layers (`int`, *optional*, defaults to 32):
50
+ Number of hidden layers in the Transformer decoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 32):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ num_key_value_heads (`int`, *optional*):
54
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
57
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
+ by meanpooling all the original heads within that group. For more details checkout [this
59
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
+ `num_attention_heads`.
61
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
+ The non-linear activation function (function or string) in the decoder.
63
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
64
+ The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
65
+ MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69
+ The epsilon used by the rms normalization layers.
70
+ use_cache (`bool`, *optional*, defaults to `True`):
71
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
72
+ relevant if `config.is_decoder=True`.
73
+ pad_token_id (`int`, *optional*):
74
+ Padding token id.
75
+ bos_token_id (`int`, *optional*, defaults to 1):
76
+ Beginning of stream token id.
77
+ eos_token_id (`int`, *optional*, defaults to 2):
78
+ End of stream token id.
79
+ pretraining_tp (`int`, *optional*, defaults to 1):
80
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
81
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
82
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
83
+ issue](https://github.com/pytorch/pytorch/issues/76232).
84
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
85
+ Whether to tie weight embeddings
86
+ rope_theta (`float`, *optional*, defaults to 10000.0):
87
+ The base period of the RoPE embeddings.
88
+ rope_scaling (`Dict`, *optional*):
89
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
90
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
91
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
92
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
93
+ these scaling strategies behave:
94
+ https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
95
+ experimental feature, subject to breaking API changes in future versions.
96
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
97
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
98
+ attention_dropout (`float`, *optional*, defaults to 0.0):
99
+ The dropout ratio for the attention probabilities.
100
+
101
+ ```python
102
+ >>> from transformers import MiniCPMModel, MiniCPMConfig
103
+
104
+ >>> # Initializing a MiniCPM minicpm-7b style configuration
105
+ >>> configuration = MiniCPMConfig()
106
+
107
+ >>> # Initializing a model from the minicpm-7b style configuration
108
+ >>> model = MiniCPMModel(configuration)
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "minicpm"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=32000,
120
+ hidden_size=4096,
121
+ intermediate_size=11008,
122
+ num_hidden_layers=32,
123
+ num_attention_heads=32,
124
+ num_key_value_heads=None,
125
+ hidden_act="silu",
126
+ max_position_embeddings=2048,
127
+ initializer_range=0.02,
128
+ rms_norm_eps=1e-6,
129
+ use_cache=True,
130
+ pad_token_id=None,
131
+ bos_token_id=1,
132
+ eos_token_id=2,
133
+ pretraining_tp=1,
134
+ tie_word_embeddings=True,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ attention_bias=False,
138
+ attention_dropout=0.0,
139
+ scale_emb=1,
140
+ dim_model_base=1,
141
+ scale_depth=1,
142
+ **kwargs,
143
+ ):
144
+ self.vocab_size = vocab_size
145
+ self.max_position_embeddings = max_position_embeddings
146
+ self.hidden_size = hidden_size
147
+ self.intermediate_size = intermediate_size
148
+ self.num_hidden_layers = num_hidden_layers
149
+ self.num_attention_heads = num_attention_heads
150
+
151
+ # for backward compatibility
152
+ if num_key_value_heads is None:
153
+ num_key_value_heads = num_attention_heads
154
+
155
+ self.num_key_value_heads = num_key_value_heads
156
+ self.hidden_act = hidden_act
157
+ self.initializer_range = initializer_range
158
+ self.rms_norm_eps = rms_norm_eps
159
+ self.pretraining_tp = pretraining_tp
160
+ self.use_cache = use_cache
161
+ self.rope_theta = rope_theta
162
+ self.rope_scaling = rope_scaling
163
+ self._rope_scaling_validation()
164
+ self.attention_bias = attention_bias
165
+ self.attention_dropout = attention_dropout
166
+ self.scale_emb = scale_emb
167
+ self.dim_model_base = dim_model_base
168
+ self.scale_depth = scale_depth
169
+
170
+ super().__init__(
171
+ pad_token_id=pad_token_id,
172
+ bos_token_id=bos_token_id,
173
+ eos_token_id=eos_token_id,
174
+ tie_word_embeddings=tie_word_embeddings,
175
+ **kwargs,
176
+ )
177
+ try:
178
+ import flash_attn
179
+ self._attn_implementation = "flash_attention_2"
180
+ except:
181
+ pass
182
+
183
+ def _rope_scaling_validation(self):
184
+ """
185
+ Validate the `rope_scaling` configuration.
186
+ """
187
+ if self.rope_scaling is None:
188
+ return
189
+
190
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
191
+ raise ValueError(
192
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
193
+ f"got {self.rope_scaling}"
194
+ )
195
+ rope_scaling_type = self.rope_scaling.get("type", None)
196
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
197
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
198
+ raise ValueError(
199
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
200
+ )
201
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
202
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
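
A short sketch of the `rope_scaling` contract enforced by `_rope_scaling_validation` above, assuming the `bunny` package is importable; the scaling factors are arbitrary:

```python
from bunny.model.language_model.minicpm.configuration_minicpm import MiniCPMConfig

# Accepted: a dict with exactly the fields `type` ("linear" or "dynamic") and `factor` (float > 1).
cfg = MiniCPMConfig(rope_scaling={"type": "dynamic", "factor": 2.0})

# Rejected: unknown scaling type.
try:
    MiniCPMConfig(rope_scaling={"type": "yarn", "factor": 2.0})
except ValueError as err:
    print(err)

# Rejected: wrong shape (missing `type`) and a non-float factor.
try:
    MiniCPMConfig(rope_scaling={"factor": 2})
except ValueError as err:
    print(err)
```
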
bunny/model/language_model/minicpm/modeling_minicpm.py ADDED
@@ -0,0 +1,1456 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch MiniCPM model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union, Dict
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import (
34
+ AttentionMaskConverter,
35
+ _prepare_4d_attention_mask,
36
+ _prepare_4d_causal_attention_mask,
37
+ _prepare_4d_causal_attention_mask_for_sdpa,
38
+ )
39
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
42
+ from transformers.utils import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ is_flash_attn_2_available,
46
+ is_flash_attn_greater_or_equal_2_10,
47
+ logging,
48
+ replace_return_docstrings,
49
+ )
50
+ from transformers.utils.import_utils import is_torch_fx_available
51
+ from .configuration_minicpm import MiniCPMConfig
52
+ import re
53
+
54
+ try:
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+ except ImportError:
58
+ pass
59
+
60
+
61
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
+ # It means that the function will not be traced through and simply appear as a node in the graph.
63
+ if is_torch_fx_available():
64
+ if not is_torch_greater_or_equal_than_1_13:
65
+ import torch.fx
66
+
67
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
68
+
69
+
70
+ logger = logging.get_logger(__name__)
71
+
72
+ _CONFIG_FOR_DOC = "MiniCPMConfig"
73
+
74
+
75
+ def _get_unpad_data(attention_mask):
76
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
77
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
78
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
79
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
80
+ return (
81
+ indices,
82
+ cu_seqlens,
83
+ max_seqlen_in_batch,
84
+ )
85
+
86
+
87
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
88
+ warnings.warn(
89
+ "Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask`"
90
+ )
91
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
92
+
93
+
94
+ def _make_causal_mask(
95
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
96
+ ):
97
+ warnings.warn(
98
+ "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask`"
99
+ )
100
+ return AttentionMaskConverter._make_causal_mask(
101
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
102
+ )
103
+
104
+ # @torch.jit.script # type: ignore
105
+ def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
106
+ old_dtype = hidden.dtype
107
+ variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
108
+ hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
109
+ return hidden * weight
110
+
111
+
112
+ class MiniCPMRMSNorm(nn.Module):
113
+ def __init__(self, hidden_size, eps=1e-6):
114
+ """
115
+ MiniCPMRMSNorm is equivalent to T5LayerNorm
116
+ """
117
+ super().__init__()
118
+ self.weight = nn.Parameter(torch.ones(hidden_size))
119
+ self.variance_epsilon = eps
120
+
121
+ def forward(self, hidden_states):
122
+ return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
123
+
124
+
125
+ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
126
+
127
+
128
+ class MiniCPMRotaryEmbedding(nn.Module):
129
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
130
+ super().__init__()
131
+
132
+ self.dim = dim
133
+ self.max_position_embeddings = max_position_embeddings
134
+ self.base = base
135
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
136
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
137
+
138
+ # Build here to make `torch.jit.trace` work.
139
+ self._set_cos_sin_cache(
140
+ # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
141
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
142
+ )
143
+
144
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
145
+ self.max_seq_len_cached = seq_len
146
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
147
+ freqs = torch.outer(t, self.inv_freq)
148
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
149
+ emb = torch.cat((freqs, freqs), dim=-1)
150
+
151
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
152
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
153
+
154
+ def forward(self, x, seq_len=None):
155
+ # x: [bs, num_attention_heads, seq_len, head_size]
156
+ if seq_len > self.max_seq_len_cached:
157
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
158
+
159
+ return (
160
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
161
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
162
+ )
163
+
164
+
165
+ class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
166
+ """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
167
+
168
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
169
+ self.scaling_factor = scaling_factor
170
+ super().__init__(dim, max_position_embeddings, base, device)
171
+
172
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
173
+ self.max_seq_len_cached = seq_len
174
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
175
+ t = t / self.scaling_factor
176
+
177
+ freqs = torch.outer(t, self.inv_freq)
178
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
179
+ emb = torch.cat((freqs, freqs), dim=-1)
180
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
181
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
182
+
183
+
184
+ class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
185
+ """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
186
+
187
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
188
+ self.scaling_factor = scaling_factor
189
+ super().__init__(dim, max_position_embeddings, base, device)
190
+
191
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
192
+ self.max_seq_len_cached = seq_len
193
+
194
+ if seq_len > self.max_position_embeddings:
195
+ base = self.base * (
196
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
197
+ ) ** (self.dim / (self.dim - 2))
198
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
199
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
200
+
201
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
202
+
203
+ freqs = torch.outer(t, self.inv_freq)
204
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
205
+ emb = torch.cat((freqs, freqs), dim=-1)
206
+
207
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
208
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
209
+
210
+
211
+ def rotate_half(x):
212
+ """Rotates half the hidden dims of the input."""
213
+ x1 = x[..., : x.shape[-1] // 2]
214
+ x2 = x[..., x.shape[-1] // 2 :]
215
+ return torch.cat((-x2, x1), dim=-1)
216
+
217
+
218
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
219
+ """Applies Rotary Position Embedding to the query and key tensors.
220
+
221
+ Args:
222
+ q (`torch.Tensor`): The query tensor.
223
+ k (`torch.Tensor`): The key tensor.
224
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
225
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
226
+ position_ids (`torch.Tensor`):
227
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
228
+ used to pass offsetted position ids when working with a KV-cache.
229
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
230
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
231
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
232
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
233
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
234
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
235
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
236
+ Returns:
237
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
238
+ """
239
+ # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
240
+ # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
241
+ # q_embed = (q * cos) + (rotate_half(q) * sin)
242
+ # k_embed = (k * cos) + (rotate_half(k) * sin)
243
+ orig_dtype = k.dtype
244
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
245
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
246
+ q_fp32 = q.to(dtype=torch.float32, device=q.device)
247
+ k_fp32 = k.to(dtype=torch.float32, device=k.device)
248
+ q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
249
+ k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
250
+ return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
251
+
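
The docstring above is easier to follow with concrete shapes; here is a small sketch, assuming the rotary-embedding classes and `apply_rotary_pos_emb` defined above are in scope and using arbitrary small dimensions:

```python
import torch

bsz, n_heads, n_kv_heads, seq_len, head_dim = 2, 8, 4, 16, 32

rope = MiniCPMRotaryEmbedding(head_dim, max_position_embeddings=seq_len)
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

cos, sin = rope(k, seq_len=seq_len)                        # each: [seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape   # rotation preserves shapes
```
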
252
+ class MiniCPMMLP(nn.Module):
253
+ def __init__(self, config):
254
+ super().__init__()
255
+ self.config = config
256
+ self.hidden_size = config.hidden_size
257
+ self.intermediate_size = config.intermediate_size
258
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
259
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
260
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
261
+ self.act_fn = ACT2FN[config.hidden_act]
262
+
263
+ def forward(self, x):
264
+ if self.config.pretraining_tp > 1:
265
+ slice = self.intermediate_size // self.config.pretraining_tp
266
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
267
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
268
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
269
+
270
+ gate_proj = torch.cat(
271
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
272
+ )
273
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
274
+
275
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
276
+ down_proj = [
277
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
278
+ ]
279
+ down_proj = sum(down_proj)
280
+ else:
281
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
282
+
283
+ return down_proj
284
+
285
+
286
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
287
+ """
288
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
289
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
290
+ """
291
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
292
+ if n_rep == 1:
293
+ return hidden_states
294
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
295
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
296
+
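
A one-line shape check makes the grouped-query expansion above concrete (a sketch, assuming `repeat_kv` as defined above; the sizes are arbitrary):

```python
import torch

kv = torch.randn(2, 4, 16, 32)           # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=8 // 4)    # each KV head is repeated for its group of query heads
assert expanded.shape == (2, 8, 16, 32)   # (batch, num_attention_heads, seq_len, head_dim)
```
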
297
+
298
+
299
+ class MiniCPMAttention(nn.Module):
300
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
301
+
302
+ def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
303
+ super().__init__()
304
+ self.config = config
305
+ self.layer_idx = layer_idx
306
+ if layer_idx is None:
307
+ logger.warning_once(
308
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
309
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
310
+ "when creating this class."
311
+ )
312
+
313
+ self.attention_dropout = config.attention_dropout
314
+ self.hidden_size = config.hidden_size
315
+ self.num_heads = config.num_attention_heads
316
+ self.head_dim = self.hidden_size // self.num_heads
317
+ self.num_key_value_heads = config.num_key_value_heads
318
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
319
+ self.max_position_embeddings = config.max_position_embeddings
320
+ self.rope_theta = config.rope_theta
321
+ self.is_causal = True
322
+
323
+ if (self.head_dim * self.num_heads) != self.hidden_size:
324
+ raise ValueError(
325
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
326
+ f" and `num_heads`: {self.num_heads})."
327
+ )
328
+
329
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
330
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
331
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
332
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
333
+ self._init_rope()
334
+
335
+ def _init_rope(self):
336
+ if self.config.rope_scaling is None:
337
+ self.rotary_emb = MiniCPMRotaryEmbedding(
338
+ self.head_dim,
339
+ max_position_embeddings=self.max_position_embeddings,
340
+ base=self.rope_theta,
341
+ )
342
+ else:
343
+ scaling_type = self.config.rope_scaling["type"]
344
+ scaling_factor = self.config.rope_scaling["factor"]
345
+ if scaling_type == "linear":
346
+ self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
347
+ self.head_dim,
348
+ max_position_embeddings=self.max_position_embeddings,
349
+ scaling_factor=scaling_factor,
350
+ base=self.rope_theta,
351
+ )
352
+ elif scaling_type == "dynamic":
353
+ self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
354
+ self.head_dim,
355
+ max_position_embeddings=self.max_position_embeddings,
356
+ scaling_factor=scaling_factor,
357
+ base=self.rope_theta,
358
+ )
359
+ else:
360
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
361
+
362
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
363
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
364
+
365
+ def forward(
366
+ self,
367
+ hidden_states: torch.Tensor,
368
+ attention_mask: Optional[torch.Tensor] = None,
369
+ position_ids: Optional[torch.LongTensor] = None,
370
+ past_key_value: Optional[Cache] = None,
371
+ output_attentions: bool = False,
372
+ use_cache: bool = False,
373
+ **kwargs,
374
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
375
+ if "padding_mask" in kwargs:
376
+ warnings.warn(
377
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
378
+ )
379
+
380
+ bsz, q_len, _ = hidden_states.size()
381
+
382
+ if self.config.pretraining_tp > 1:
383
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
384
+ query_slices = self.q_proj.weight.split(
385
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
386
+ )
387
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
388
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
389
+
390
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
391
+ query_states = torch.cat(query_states, dim=-1)
392
+
393
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
394
+ key_states = torch.cat(key_states, dim=-1)
395
+
396
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
397
+ value_states = torch.cat(value_states, dim=-1)
398
+
399
+ else:
400
+ query_states = self.q_proj(hidden_states)
401
+ key_states = self.k_proj(hidden_states)
402
+ value_states = self.v_proj(hidden_states)
403
+
404
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
405
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
406
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
407
+
408
+ kv_seq_len = key_states.shape[-2]
409
+ if past_key_value is not None:
410
+ if self.layer_idx is None:
411
+ raise ValueError(
412
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
413
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
414
+ "with a layer index."
415
+ )
416
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
417
+ cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
418
+
419
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
420
+
421
+ if past_key_value is not None:
422
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
423
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
424
+
425
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
426
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
427
+
428
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
429
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
430
+ raise ValueError(
431
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
432
+ f" {attn_weights.size()}"
433
+ )
434
+
435
+ if attention_mask is not None:
436
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
437
+ raise ValueError(
438
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
439
+ )
440
+ attn_weights = attn_weights + attention_mask
441
+
442
+ # upcast attention to fp32
443
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
444
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
445
+ attn_output = torch.matmul(attn_weights, value_states)
446
+
447
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
448
+ raise ValueError(
449
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
450
+ f" {attn_output.size()}"
451
+ )
452
+
453
+ attn_output = attn_output.transpose(1, 2).contiguous()
454
+
455
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
456
+
457
+ if self.config.pretraining_tp > 1:
458
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
459
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
460
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
461
+ else:
462
+ attn_output = self.o_proj(attn_output)
463
+
464
+ if not output_attentions:
465
+ attn_weights = None
466
+
467
+ return attn_output, attn_weights, past_key_value
468
+
469
+
470
+ class MiniCPMFlashAttention2(MiniCPMAttention):
471
+ """
472
+ MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stay
473
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
474
+ flash attention and deal with padding tokens in case the input contains any of them.
475
+ """
476
+
477
+ def __init__(self, *args, **kwargs):
478
+ super().__init__(*args, **kwargs)
479
+
480
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
481
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
482
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
483
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ attention_mask: Optional[torch.LongTensor] = None,
489
+ position_ids: Optional[torch.LongTensor] = None,
490
+ past_key_value: Optional[Cache] = None,
491
+ output_attentions: bool = False,
492
+ use_cache: bool = False,
493
+ **kwargs,
494
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
495
+ # MiniCPMFlashAttention2 attention does not support output_attentions
496
+ if "padding_mask" in kwargs:
497
+ warnings.warn(
498
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
499
+ )
500
+
501
+ # overwrite attention_mask with padding_mask
502
+ attention_mask = kwargs.pop("padding_mask")
503
+
504
+ output_attentions = False
505
+
506
+ bsz, q_len, _ = hidden_states.size()
507
+
508
+ query_states = self.q_proj(hidden_states)
509
+ key_states = self.k_proj(hidden_states)
510
+ value_states = self.v_proj(hidden_states)
511
+
512
+ # Flash attention requires the input to have the shape
513
+ # batch_size x seq_length x head_dim x hidden_dim
514
+ # therefore we just need to keep the original shape
515
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
516
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
517
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
518
+
519
+ kv_seq_len = key_states.shape[-2]
520
+ if past_key_value is not None:
521
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
522
+ cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
523
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
524
+
525
+ if past_key_value is not None:
526
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
527
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
528
+
529
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
530
+ # to be able to avoid many of these transpose/reshape/view.
531
+ query_states = query_states.transpose(1, 2)
532
+ key_states = key_states.transpose(1, 2)
533
+ value_states = value_states.transpose(1, 2)
534
+
535
+ dropout_rate = self.attention_dropout if self.training else 0.0
536
+
537
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
538
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
539
+ # cast them back to the correct dtype just to be sure everything works as expected.
540
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
541
+ # in fp32. (MiniCPMRMSNorm handles it correctly)
542
+
543
+ input_dtype = query_states.dtype
544
+ if input_dtype == torch.float32:
545
+ # Handle the case where the model is quantized
546
+ if hasattr(self.config, "_pre_quantization_dtype"):
547
+ target_dtype = self.config._pre_quantization_dtype
548
+ else:
549
+ target_dtype = self.q_proj.weight.dtype
550
+
551
+ logger.warning_once(
552
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
553
+ f" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to"
554
+ f" {target_dtype}."
555
+ )
556
+
557
+ query_states = query_states.to(target_dtype)
558
+ key_states = key_states.to(target_dtype)
559
+ value_states = value_states.to(target_dtype)
560
+
561
+ attn_output = self._flash_attention_forward(
562
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
563
+ )
564
+
565
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
566
+ attn_output = self.o_proj(attn_output)
567
+
568
+ if not output_attentions:
569
+ attn_weights = None
570
+
571
+ return attn_output, attn_weights, past_key_value
572
+
573
+ def _flash_attention_forward(
574
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
575
+ ):
576
+ """
577
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
578
+ it first unpads the input, then computes the attention scores, and pads the final attention scores back.
579
+
580
+ Args:
581
+ query_states (`torch.Tensor`):
582
+ Input query states to be passed to Flash Attention API
583
+ key_states (`torch.Tensor`):
584
+ Input key states to be passed to Flash Attention API
585
+ value_states (`torch.Tensor`):
586
+ Input value states to be passed to Flash Attention API
587
+ attention_mask (`torch.Tensor`):
588
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
589
+ position of padding tokens and 1 for the position of non-padding tokens.
590
+ dropout (`float`, *optional*):
591
+ Attention dropout
592
+ softmax_scale (`float`, *optional*):
593
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
594
+ """
595
+ if not self._flash_attn_uses_top_left_mask:
596
+ causal = self.is_causal
597
+ else:
598
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
599
+ causal = self.is_causal and query_length != 1
600
+ # Contains at least one padding token in the sequence
601
+ if attention_mask is not None:
602
+ batch_size = query_states.shape[0]
603
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
604
+ query_states, key_states, value_states, attention_mask, query_length
605
+ )
606
+
607
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
608
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
609
+ attn_output_unpad = flash_attn_varlen_func(
610
+ query_states,
611
+ key_states,
612
+ value_states,
613
+ cu_seqlens_q=cu_seqlens_q,
614
+ cu_seqlens_k=cu_seqlens_k,
615
+ max_seqlen_q=max_seqlen_in_batch_q,
616
+ max_seqlen_k=max_seqlen_in_batch_k,
617
+ dropout_p=dropout,
618
+ softmax_scale=softmax_scale,
619
+ causal=causal,
620
+ )
621
+
622
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
623
+ else:
624
+ attn_output = flash_attn_func(
625
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
626
+ )
627
+
628
+ return attn_output
629
+
630
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
631
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
632
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
633
+
634
+ key_layer = index_first_axis(
635
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
636
+ )
637
+ value_layer = index_first_axis(
638
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
639
+ )
640
+ if query_length == kv_seq_len:
641
+ query_layer = index_first_axis(
642
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
643
+ )
644
+ cu_seqlens_q = cu_seqlens_k
645
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
646
+ indices_q = indices_k
647
+ elif query_length == 1:
648
+ max_seqlen_in_batch_q = 1
649
+ cu_seqlens_q = torch.arange(
650
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
651
+ ) # There is a memcpy here, that is very bad.
652
+ indices_q = cu_seqlens_q[:-1]
653
+ query_layer = query_layer.squeeze(1)
654
+ else:
655
+ # The -q_len: slice assumes left padding.
656
+ attention_mask = attention_mask[:, -query_length:]
657
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
658
+
659
+ return (
660
+ query_layer,
661
+ key_layer,
662
+ value_layer,
663
+ indices_q,
664
+ (cu_seqlens_q, cu_seqlens_k),
665
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
666
+ )
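`_upad_input` above feeds `flash_attn_varlen_func` with the cumulative-sequence-length format produced by `_get_unpad_data` (defined earlier in this file). A minimal editorial sketch of that conversion in plain PyTorch; the function and variable names here are illustrative, not part of the commit:

```python
import torch
import torch.nn.functional as F

def unpad_data_sketch(attention_mask: torch.Tensor):
    """Turn a 0/1 padding mask into (indices, cu_seqlens, max_seqlen) for varlen flash attention."""
    # Per-row count of real tokens, e.g. [[1,1,1,0],[1,1,0,0]] -> [3, 2]
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the non-padding positions in the flattened (batch * seq_len) layout
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative offsets with a leading zero: [3, 2] -> [0, 3, 5]
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(unpad_data_sketch(mask))
# -> indices [0, 1, 2, 4, 5], cu_seqlens [0, 3, 5], max_seqlen 3
```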
667
+
668
+
669
+ class MiniCPMSdpaAttention(MiniCPMAttention):
670
+ """
671
+ MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
672
+ `MiniCPMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
673
+ SDPA API.
674
+ """
675
+
676
+ # Adapted from MiniCPMAttention.forward
677
+ def forward(
678
+ self,
679
+ hidden_states: torch.Tensor,
680
+ attention_mask: Optional[torch.Tensor] = None,
681
+ position_ids: Optional[torch.LongTensor] = None,
682
+ past_key_value: Optional[Cache] = None,
683
+ output_attentions: bool = False,
684
+ use_cache: bool = False,
685
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
686
+ if output_attentions:
687
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
688
+ logger.warning_once(
689
+ "MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
690
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
691
+ )
692
+ return super().forward(
693
+ hidden_states=hidden_states,
694
+ attention_mask=attention_mask,
695
+ position_ids=position_ids,
696
+ past_key_value=past_key_value,
697
+ output_attentions=output_attentions,
698
+ use_cache=use_cache,
699
+ )
700
+
701
+ bsz, q_len, _ = hidden_states.size()
702
+
703
+ query_states = self.q_proj(hidden_states)
704
+ key_states = self.k_proj(hidden_states)
705
+ value_states = self.v_proj(hidden_states)
706
+
707
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
708
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
709
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
710
+
711
+ kv_seq_len = key_states.shape[-2]
712
+ if past_key_value is not None:
713
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
714
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
715
+
716
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
717
+
718
+ if past_key_value is not None:
719
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
720
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
721
+
722
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
723
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
724
+
725
+ if attention_mask is not None:
726
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
727
+ raise ValueError(
728
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
729
+ )
730
+
731
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
732
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
733
+ if query_states.device.type == "cuda" and attention_mask is not None:
734
+ query_states = query_states.contiguous()
735
+ key_states = key_states.contiguous()
736
+ value_states = value_states.contiguous()
737
+
738
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
739
+ query_states,
740
+ key_states,
741
+ value_states,
742
+ attn_mask=attention_mask,
743
+ dropout_p=self.attention_dropout if self.training else 0.0,
744
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
745
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
746
+ )
747
+
748
+ attn_output = attn_output.transpose(1, 2).contiguous()
749
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
750
+
751
+ attn_output = self.o_proj(attn_output)
752
+
753
+ return attn_output, None, past_key_value
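The SDPA forward above expands the key/value heads with `repeat_kv` (defined earlier in this file, following the upstream Llama implementation) so grouped-query attention can run through a standard kernel. A sketch of that expansion, assuming the usual `(batch, num_kv_heads, seq_len, head_dim)` layout:

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# 2 KV heads serving 8 query heads: each KV head is shared by 4 query heads.
kv = torch.randn(1, 2, 5, 16)
print(repeat_kv_sketch(kv, n_rep=4).shape)  # torch.Size([1, 8, 5, 16])
```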
754
+
755
+
756
+ MINICPM_ATTENTION_CLASSES = {
757
+ "eager": MiniCPMAttention,
758
+ "flash_attention_2": MiniCPMFlashAttention2,
759
+ "sdpa": MiniCPMSdpaAttention,
760
+ }
761
+
762
+
763
+ class MiniCPMDecoderLayer(nn.Module):
764
+ def __init__(self, config: MiniCPMConfig, layer_idx: int):
765
+ super().__init__()
766
+ self.hidden_size = config.hidden_size
767
+ self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
768
+
769
+ self.mlp = MiniCPMMLP(config)
770
+ self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
771
+ self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
772
+
773
+ self.scale_depth = config.scale_depth
774
+ self.num_hidden_layers = config.num_hidden_layers
775
+
776
+ def forward(
777
+ self,
778
+ hidden_states: torch.Tensor,
779
+ attention_mask: Optional[torch.Tensor] = None,
780
+ position_ids: Optional[torch.LongTensor] = None,
781
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
782
+ output_attentions: Optional[bool] = False,
783
+ use_cache: Optional[bool] = False,
784
+ **kwargs,
785
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
786
+ """
787
+ Args:
788
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
789
+ attention_mask (`torch.FloatTensor`, *optional*):
790
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
791
+ query_sequence_length, key_sequence_length)` if default attention is used.
792
+ output_attentions (`bool`, *optional*):
793
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
794
+ returned tensors for more detail.
795
+ use_cache (`bool`, *optional*):
796
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
797
+ (see `past_key_values`).
798
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
799
+ """
800
+ if "padding_mask" in kwargs:
801
+ warnings.warn(
802
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
803
+ )
804
+
805
+ residual = hidden_states
806
+ hidden_states = self.input_layernorm(hidden_states)
807
+ # Self Attention
808
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
809
+ hidden_states=hidden_states,
810
+ attention_mask=attention_mask,
811
+ position_ids=position_ids,
812
+ past_key_value=past_key_value,
813
+ output_attentions=output_attentions,
814
+ use_cache=use_cache,
815
+ **kwargs,
816
+ )
817
+
818
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
819
+
820
+ # Fully Connected
821
+ residual = hidden_states
822
+ hidden_states = self.post_attention_layernorm(hidden_states)
823
+
824
+ hidden_states = self.mlp(hidden_states)
825
+ hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
826
+
827
+ outputs = (hidden_states,)
828
+
829
+ if output_attentions:
830
+ outputs += (self_attn_weights,)
831
+
832
+ if use_cache:
833
+ outputs += (present_key_value,)
834
+
835
+ return outputs
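Unlike a plain Llama block, the decoder layer above scales every residual branch by `scale_depth / sqrt(num_hidden_layers)` before adding it back. A small numeric sketch of that scaling; the config values are illustrative, not read from a released MiniCPM checkpoint:

```python
import math
import torch

scale_depth = 1.4        # illustrative stand-in for config.scale_depth
num_hidden_layers = 40   # illustrative stand-in for config.num_hidden_layers
residual_scale = scale_depth / math.sqrt(num_hidden_layers)

residual = torch.randn(1, 3, 8)       # (batch, seq_len, hidden_size)
branch_output = torch.randn(1, 3, 8)  # attention or MLP output for the same positions

# The layer adds a down-scaled branch instead of the raw branch output.
hidden_states = residual + branch_output * residual_scale
print(round(residual_scale, 4))  # ~0.2214 with the illustrative numbers above
```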
836
+
837
+
838
+ MINICPM_START_DOCSTRING = r"""
839
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
840
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
841
+ etc.)
842
+
843
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
844
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
845
+ and behavior.
846
+
847
+ Parameters:
848
+ config ([`MiniCPMConfig`]):
849
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
850
+ load the weights associated with the model, only the configuration. Check out the
851
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
852
+ """
853
+
854
+
855
+ @add_start_docstrings(
856
+ "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
857
+ MINICPM_START_DOCSTRING,
858
+ )
859
+ class MiniCPMPreTrainedModel(PreTrainedModel):
860
+ config_class = MiniCPMConfig
861
+ base_model_prefix = "model"
862
+ supports_gradient_checkpointing = True
863
+ _no_split_modules = ["MiniCPMDecoderLayer"]
864
+ _skip_keys_device_placement = "past_key_values"
865
+ _supports_flash_attn_2 = True
866
+ _supports_sdpa = True
867
+ _supports_cache_class = True
868
+
869
+ def _init_weights(self, module):
870
+ std = self.config.initializer_range
871
+ if isinstance(module, nn.Linear):
872
+ module.weight.data.normal_(mean=0.0, std=std)
873
+ if module.bias is not None:
874
+ module.bias.data.zero_()
875
+ elif isinstance(module, nn.Embedding):
876
+ module.weight.data.normal_(mean=0.0, std=std)
877
+ if module.padding_idx is not None:
878
+ module.weight.data[module.padding_idx].zero_()
879
+
880
+
881
+ MINICPM_INPUTS_DOCSTRING = r"""
882
+ Args:
883
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
884
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
885
+ it.
886
+
887
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
888
+ [`PreTrainedTokenizer.__call__`] for details.
889
+
890
+ [What are input IDs?](../glossary#input-ids)
891
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
892
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
893
+
894
+ - 1 for tokens that are **not masked**,
895
+ - 0 for tokens that are **masked**.
896
+
897
+ [What are attention masks?](../glossary#attention-mask)
898
+
899
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
900
+ [`PreTrainedTokenizer.__call__`] for details.
901
+
902
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
903
+ `past_key_values`).
904
+
905
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
906
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
907
+ information on the default strategy.
908
+
909
+ - 1 indicates the head is **not masked**,
910
+ - 0 indicates the head is **masked**.
911
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
912
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
913
+ config.n_positions - 1]`.
914
+
915
+ [What are position IDs?](../glossary#position-ids)
916
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
917
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
918
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
919
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
920
+
921
+ Two formats are allowed:
922
+ - a [`~cache_utils.Cache`] instance;
923
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
924
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
925
+ cache format.
926
+
927
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
928
+ legacy cache format will be returned.
929
+
930
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
931
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
932
+ of shape `(batch_size, sequence_length)`.
933
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
934
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
935
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
936
+ model's internal embedding lookup matrix.
937
+ use_cache (`bool`, *optional*):
938
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
939
+ `past_key_values`).
940
+ output_attentions (`bool`, *optional*):
941
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
942
+ tensors for more detail.
943
+ output_hidden_states (`bool`, *optional*):
944
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
945
+ more detail.
946
+ return_dict (`bool`, *optional*):
947
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
948
+ """
949
+
950
+
951
+ @add_start_docstrings(
952
+ "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
953
+ MINICPM_START_DOCSTRING,
954
+ )
955
+ class MiniCPMModel(MiniCPMPreTrainedModel):
956
+ """
957
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
958
+
959
+ Args:
960
+ config: MiniCPMConfig
961
+ """
962
+
963
+ def __init__(self, config: MiniCPMConfig):
964
+ super().__init__(config)
965
+ self.padding_idx = config.pad_token_id
966
+ self.vocab_size = config.vocab_size
967
+
968
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
969
+ self.layers = nn.ModuleList(
970
+ [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
971
+ )
972
+ self._use_sdpa = config._attn_implementation == "sdpa"
973
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
974
+
975
+ self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
976
+
977
+ self.gradient_checkpointing = False
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.embed_tokens
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.embed_tokens = value
986
+
987
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
988
+ def forward(
989
+ self,
990
+ input_ids: torch.LongTensor = None,
991
+ attention_mask: Optional[torch.Tensor] = None,
992
+ position_ids: Optional[torch.LongTensor] = None,
993
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
994
+ inputs_embeds: Optional[torch.FloatTensor] = None,
995
+ use_cache: Optional[bool] = None,
996
+ output_attentions: Optional[bool] = None,
997
+ output_hidden_states: Optional[bool] = None,
998
+ return_dict: Optional[bool] = None,
999
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1000
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1001
+ output_hidden_states = (
1002
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1003
+ )
1004
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1005
+
1006
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
+
1008
+ # retrieve input_ids and inputs_embeds
1009
+ if input_ids is not None and inputs_embeds is not None:
1010
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1011
+ elif input_ids is not None:
1012
+ batch_size, seq_length = input_ids.shape[:2]
1013
+ elif inputs_embeds is not None:
1014
+ batch_size, seq_length = inputs_embeds.shape[:2]
1015
+ else:
1016
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1017
+
1018
+ if self.gradient_checkpointing and self.training:
1019
+ if use_cache:
1020
+ logger.warning_once(
1021
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1022
+ )
1023
+ use_cache = False
1024
+
1025
+ past_key_values_length = 0
1026
+ if use_cache:
1027
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1028
+ if use_legacy_cache:
1029
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1030
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1031
+
1032
+ if position_ids is None:
1033
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1034
+ position_ids = torch.arange(
1035
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1036
+ )
1037
+ position_ids = position_ids.unsqueeze(0)
1038
+
1039
+ if inputs_embeds is None:
1040
+ inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
1041
+
1042
+
1043
+ if self._use_flash_attention_2:
1044
+ # 2d mask is passed through the layers
1045
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1046
+ elif self._use_sdpa and not output_attentions:
1047
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1048
+ # the manual implementation that requires a 4D causal mask in all cases.
1049
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1050
+ attention_mask,
1051
+ (batch_size, seq_length),
1052
+ inputs_embeds,
1053
+ past_key_values_length,
1054
+ )
1055
+ else:
1056
+ # 4d mask is passed through the layers
1057
+ attention_mask = _prepare_4d_causal_attention_mask(
1058
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1059
+ )
1060
+
1061
+ # embed positions
1062
+ hidden_states = inputs_embeds
1063
+
1064
+ # decoder layers
1065
+ all_hidden_states = () if output_hidden_states else None
1066
+ all_self_attns = () if output_attentions else None
1067
+ next_decoder_cache = None
1068
+
1069
+ for decoder_layer in self.layers:
1070
+ if output_hidden_states:
1071
+ all_hidden_states += (hidden_states,)
1072
+
1073
+ if self.gradient_checkpointing and self.training:
1074
+ layer_outputs = self._gradient_checkpointing_func(
1075
+ decoder_layer.__call__,
1076
+ hidden_states,
1077
+ attention_mask,
1078
+ position_ids,
1079
+ past_key_values,
1080
+ output_attentions,
1081
+ use_cache,
1082
+ )
1083
+ else:
1084
+ layer_outputs = decoder_layer(
1085
+ hidden_states,
1086
+ attention_mask=attention_mask,
1087
+ position_ids=position_ids,
1088
+ past_key_value=past_key_values,
1089
+ output_attentions=output_attentions,
1090
+ use_cache=use_cache,
1091
+ )
1092
+
1093
+ hidden_states = layer_outputs[0]
1094
+
1095
+ if use_cache:
1096
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1097
+
1098
+ if output_attentions:
1099
+ all_self_attns += (layer_outputs[1],)
1100
+
1101
+ hidden_states = self.norm(hidden_states)
1102
+
1103
+ # add hidden states from the last decoder layer
1104
+ if output_hidden_states:
1105
+ all_hidden_states += (hidden_states,)
1106
+
1107
+ next_cache = None
1108
+ if use_cache:
1109
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1110
+ if not return_dict:
1111
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1112
+ return BaseModelOutputWithPast(
1113
+ last_hidden_state=hidden_states,
1114
+ past_key_values=next_cache,
1115
+ hidden_states=all_hidden_states,
1116
+ attentions=all_self_attns,
1117
+ )
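The forward pass above accepts either a `Cache` instance or the legacy tuple-of-tuples cache and converts between the two with `DynamicCache.from_legacy_cache` / `to_legacy_cache`. A brief sketch of that round trip using the public `transformers` cache utilities; the tensor shapes are illustrative:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 4, 6, 16), torch.randn(1, 4, 6, 16)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 6 tokens already cached
print(len(cache.to_legacy_cache()))  # 2 layers, back in tuple form
```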
1118
+
1119
+
1120
+ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1121
+ _tied_weights_keys = ["lm_head.weight"]
1122
+
1123
+ def __init__(self, config):
1124
+ super().__init__(config)
1125
+ self.model = MiniCPMModel(config)
1126
+ self.vocab_size = config.vocab_size
1127
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1128
+
1129
+ # Initialize weights and apply final processing
1130
+ self.post_init()
1131
+
1132
+ def get_input_embeddings(self):
1133
+ return self.model.embed_tokens
1134
+
1135
+ def set_input_embeddings(self, value):
1136
+ self.model.embed_tokens = value
1137
+
1138
+ def get_output_embeddings(self):
1139
+ return self.lm_head
1140
+
1141
+ def set_output_embeddings(self, new_embeddings):
1142
+ self.lm_head = new_embeddings
1143
+
1144
+ def set_decoder(self, decoder):
1145
+ self.model = decoder
1146
+
1147
+ def get_decoder(self):
1148
+ return self.model
1149
+
1150
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1151
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1152
+ def forward(
1153
+ self,
1154
+ input_ids: torch.LongTensor = None,
1155
+ attention_mask: Optional[torch.Tensor] = None,
1156
+ position_ids: Optional[torch.LongTensor] = None,
1157
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1158
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1159
+ labels: Optional[torch.LongTensor] = None,
1160
+ use_cache: Optional[bool] = None,
1161
+ output_attentions: Optional[bool] = None,
1162
+ output_hidden_states: Optional[bool] = None,
1163
+ return_dict: Optional[bool] = None,
1164
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1165
+ r"""
1166
+ Args:
1167
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1168
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1169
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1171
+
1172
+ Returns:
1173
+
1174
+ Example:
1175
+
1176
+ ```python
1177
+ >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
1178
+
1179
+ >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1180
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1181
+
1182
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1183
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1184
+
1185
+ >>> # Generate
1186
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1187
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1188
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1189
+ ```"""
1190
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1191
+ output_hidden_states = (
1192
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1193
+ )
1194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1195
+
1196
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1197
+ outputs = self.model(
1198
+ input_ids=input_ids,
1199
+ attention_mask=attention_mask,
1200
+ position_ids=position_ids,
1201
+ past_key_values=past_key_values,
1202
+ inputs_embeds=inputs_embeds,
1203
+ use_cache=use_cache,
1204
+ output_attentions=output_attentions,
1205
+ output_hidden_states=output_hidden_states,
1206
+ return_dict=return_dict,
1207
+ )
1208
+
1209
+ hidden_states = outputs[0]
1210
+ if self.config.pretraining_tp > 1:
1211
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1212
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1213
+ logits = torch.cat(logits, dim=-1)
1214
+ else:
1215
+ logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
1216
+ logits = logits.float()
1217
+
1218
+ loss = None
1219
+ if labels is not None:
1220
+ # Shift so that tokens < n predict n
1221
+ shift_logits = logits[..., :-1, :].contiguous()
1222
+ shift_labels = labels[..., 1:].contiguous()
1223
+ # Flatten the tokens
1224
+ loss_fct = CrossEntropyLoss()
1225
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1226
+ shift_labels = shift_labels.view(-1)
1227
+ # Enable model parallelism
1228
+ shift_labels = shift_labels.to(shift_logits.device)
1229
+ loss = loss_fct(shift_logits, shift_labels)
1230
+
1231
+ if not return_dict:
1232
+ output = (logits,) + outputs[1:]
1233
+ return (loss,) + output if loss is not None else output
1234
+
1235
+ return CausalLMOutputWithPast(
1236
+ loss=loss,
1237
+ logits=logits,
1238
+ past_key_values=outputs.past_key_values,
1239
+ hidden_states=outputs.hidden_states,
1240
+ attentions=outputs.attentions,
1241
+ )
1242
+
1243
+ def prepare_inputs_for_generation(
1244
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1245
+ ):
1246
+ if past_key_values is not None:
1247
+ if isinstance(past_key_values, Cache):
1248
+ cache_length = past_key_values.get_seq_length()
1249
+ past_length = past_key_values.seen_tokens
1250
+ max_cache_length = past_key_values.get_max_length()
1251
+ else:
1252
+ cache_length = past_length = past_key_values[0][0].shape[2]
1253
+ max_cache_length = None
1254
+
1255
+ # Keep only the unprocessed tokens:
1256
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1257
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1258
+ # input)
1259
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1260
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1261
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1262
+ # input_ids based on the past_length.
1263
+ elif past_length < input_ids.shape[1]:
1264
+ input_ids = input_ids[:, past_length:]
1265
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1266
+ else:
1267
+ remove_prefix_length = input_ids.shape[1] - 1
1268
+ input_ids = input_ids[:, remove_prefix_length:]
1269
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1270
+ if (
1271
+ max_cache_length is not None
1272
+ and attention_mask is not None
1273
+ and cache_length + input_ids.shape[1] > max_cache_length
1274
+ ):
1275
+ attention_mask = attention_mask[:, -max_cache_length:]
1276
+
1277
+ position_ids = kwargs.get("position_ids", None)
1278
+ if attention_mask is not None and position_ids is None:
1279
+ # create position_ids on the fly for batch generation
1280
+ position_ids = attention_mask.long().cumsum(-1) - 1
1281
+ position_ids.masked_fill_(attention_mask == 0, 1)
1282
+ if past_key_values:
1283
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1284
+
1285
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1286
+ if inputs_embeds is not None and past_key_values is None:
1287
+ model_inputs = {"inputs_embeds": inputs_embeds}
1288
+ else:
1289
+ model_inputs = {"input_ids": input_ids}
1290
+
1291
+ model_inputs.update(
1292
+ {
1293
+ "position_ids": position_ids,
1294
+ "past_key_values": past_key_values,
1295
+ "use_cache": kwargs.get("use_cache"),
1296
+ "attention_mask": attention_mask,
1297
+ }
1298
+ )
1299
+ return model_inputs
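`prepare_inputs_for_generation` rebuilds `position_ids` from the attention mask with a cumulative-sum trick, so left-padded rows still start counting from 0 at their first real token. A standalone sketch of that computation:

```python
import torch

# Two left-padded rows: row 0 has one pad token, row 1 has two.
attention_mask = torch.tensor([[0, 1, 1, 1],
                               [0, 0, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # padded slots get a harmless dummy value
print(position_ids)
# tensor([[1, 0, 1, 2],
#         [1, 1, 0, 1]])
```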
1300
+
1301
+ @staticmethod
1302
+ def _reorder_cache(past_key_values, beam_idx):
1303
+ reordered_past = ()
1304
+ for layer_past in past_key_values:
1305
+ reordered_past += (
1306
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1307
+ )
1308
+ return reordered_past
1309
+
1310
+ @torch.inference_mode()
1311
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1312
+ max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1313
+ **kwargs):
1314
+ if history is None:
1315
+ history = []
1316
+ if logits_processor:
1317
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1318
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1319
+ else:
1320
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1321
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1322
+
1323
+ history.append({"role": role, "content": query})
1324
+ history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
1325
+ inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1326
+ outputs = self.generate(**inputs, **gen_kwargs)
1327
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1328
+ response = tokenizer.decode(outputs)
1329
+ pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
1330
+ matches = pattern.findall(response)
1331
+ if len(matches) > 0:
1332
+ response = matches[0]
1333
+ history.append({"role": "assistant", "content": response})
1334
+ return response, history
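A hedged usage sketch for the `chat` helper above. The checkpoint path is a placeholder, and the example assumes the checkpoint ships this modeling file as its remote code (hence `trust_remote_code=True`); it is not a workflow prescribed by this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/minicpm-checkpoint"  # placeholder, not a specific released model
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

response, history = model.chat(tokenizer, "Introduce yourself in one sentence.")
follow_up, history = model.chat(tokenizer, "Now say it in French.", history=history)
print(response)
print(follow_up)
```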
1335
+
1336
+
1337
+ @add_start_docstrings(
1338
+ """
1339
+ The MiniCPM Model transformer with a sequence classification head on top (linear layer).
1340
+
1341
+ [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1342
+ (e.g. GPT-2) do.
1343
+
1344
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1345
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1346
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1347
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1348
+ each row of the batch).
1349
+ """,
1350
+ MINICPM_START_DOCSTRING,
1351
+ )
1352
+ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
1353
+ def __init__(self, config):
1354
+ super().__init__(config)
1355
+ self.num_labels = config.num_labels
1356
+ self.model = MiniCPMModel(config)
1357
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1358
+
1359
+ # Initialize weights and apply final processing
1360
+ self.post_init()
1361
+
1362
+ def get_input_embeddings(self):
1363
+ return self.model.embed_tokens
1364
+
1365
+ def set_input_embeddings(self, value):
1366
+ self.model.embed_tokens = value
1367
+
1368
+ @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1369
+ def forward(
1370
+ self,
1371
+ input_ids: torch.LongTensor = None,
1372
+ attention_mask: Optional[torch.Tensor] = None,
1373
+ position_ids: Optional[torch.LongTensor] = None,
1374
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1375
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1376
+ labels: Optional[torch.LongTensor] = None,
1377
+ use_cache: Optional[bool] = None,
1378
+ output_attentions: Optional[bool] = None,
1379
+ output_hidden_states: Optional[bool] = None,
1380
+ return_dict: Optional[bool] = None,
1381
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1382
+ r"""
1383
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1384
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1385
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1386
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1387
+ """
1388
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1389
+
1390
+ transformer_outputs = self.model(
1391
+ input_ids,
1392
+ attention_mask=attention_mask,
1393
+ position_ids=position_ids,
1394
+ past_key_values=past_key_values,
1395
+ inputs_embeds=inputs_embeds,
1396
+ use_cache=use_cache,
1397
+ output_attentions=output_attentions,
1398
+ output_hidden_states=output_hidden_states,
1399
+ return_dict=return_dict,
1400
+ )
1401
+ hidden_states = transformer_outputs[0]
1402
+ logits = self.score(hidden_states)
1403
+
1404
+ if input_ids is not None:
1405
+ batch_size = input_ids.shape[0]
1406
+ else:
1407
+ batch_size = inputs_embeds.shape[0]
1408
+
1409
+ if self.config.pad_token_id is None and batch_size != 1:
1410
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1411
+ if self.config.pad_token_id is None:
1412
+ sequence_lengths = -1
1413
+ else:
1414
+ if input_ids is not None:
1415
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1416
+ logits.device
1417
+ )
1418
+ else:
1419
+ sequence_lengths = -1
1420
+
1421
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1422
+
1423
+ loss = None
1424
+ if labels is not None:
1425
+ labels = labels.to(logits.device)
1426
+ if self.config.problem_type is None:
1427
+ if self.num_labels == 1:
1428
+ self.config.problem_type = "regression"
1429
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1430
+ self.config.problem_type = "single_label_classification"
1431
+ else:
1432
+ self.config.problem_type = "multi_label_classification"
1433
+
1434
+ if self.config.problem_type == "regression":
1435
+ loss_fct = MSELoss()
1436
+ if self.num_labels == 1:
1437
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1438
+ else:
1439
+ loss = loss_fct(pooled_logits, labels)
1440
+ elif self.config.problem_type == "single_label_classification":
1441
+ loss_fct = CrossEntropyLoss()
1442
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1443
+ elif self.config.problem_type == "multi_label_classification":
1444
+ loss_fct = BCEWithLogitsLoss()
1445
+ loss = loss_fct(pooled_logits, labels)
1446
+ if not return_dict:
1447
+ output = (pooled_logits,) + transformer_outputs[1:]
1448
+ return ((loss,) + output) if loss is not None else output
1449
+
1450
+ return SequenceClassifierOutputWithPast(
1451
+ loss=loss,
1452
+ logits=pooled_logits,
1453
+ past_key_values=transformer_outputs.past_key_values,
1454
+ hidden_states=transformer_outputs.hidden_states,
1455
+ attentions=transformer_outputs.attentions,
1456
+ )
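The class docstring above explains that classification is read off the last non-padding token. A small sketch of how the `sequence_lengths` indexing in the forward pass locates that token when a `pad_token_id` is configured; the values are illustrative:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # 3 real tokens, then padding
                          [3, 4, 6, 8, 2]])   # no padding at all

# argmax finds the first pad position; subtracting 1 points at the last real token.
# Rows with no padding yield -1, i.e. the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)  # tensor([ 2, -1])

logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)
pooled = logits[torch.arange(2), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 3])
```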
bunny/model/language_model/phi/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2023 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ from transformers.utils import (
19
+ OptionalDependencyNotAvailable,
20
+ _LazyModule,
21
+ is_sentencepiece_available,
22
+ is_tokenizers_available,
23
+ is_torch_available,
24
+ )
25
+
26
+
27
+ _import_structure = {
28
+ "configuration_phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
29
+ }
30
+
31
+ try:
32
+ if not is_torch_available():
33
+ raise OptionalDependencyNotAvailable()
34
+ except OptionalDependencyNotAvailable:
35
+ pass
36
+ else:
37
+ _import_structure["modeling_phi"] = [
38
+ "PHI_PRETRAINED_MODEL_ARCHIVE_LIST",
39
+ "PhiPreTrainedModel",
40
+ "PhiModel",
41
+ "PhiForCausalLM",
42
+ "PhiForSequenceClassification",
43
+ "PhiForTokenClassification",
44
+ ]
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from .configuration_phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig
49
+
50
+ try:
51
+ if not is_torch_available():
52
+ raise OptionalDependencyNotAvailable()
53
+ except OptionalDependencyNotAvailable:
54
+ pass
55
+ else:
56
+ from .modeling_phi import (
57
+ PHI_PRETRAINED_MODEL_ARCHIVE_LIST,
58
+ PhiForCausalLM,
59
+ PhiForSequenceClassification,
60
+ PhiForTokenClassification,
61
+ PhiModel,
62
+ PhiPreTrainedModel,
63
+ )
64
+
65
+
66
+ else:
67
+ import sys
68
+
69
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
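This `_LazyModule` boilerplate mirrors upstream `transformers` packaging: the torch-dependent classes are only imported when first accessed. A sketch of how the package could be exercised directly; the import path assumes the directory layout added in this commit, and the tiny config is only for a smoke test, not how the surrounding Bunny code wires the model in:

```python
# Nothing torch-related is imported until these names are actually touched.
from bunny.model.language_model.phi import PhiConfig, PhiForCausalLM

config = PhiConfig(num_hidden_layers=2, hidden_size=256, intermediate_size=1024,
                   num_attention_heads=8)  # deliberately tiny, randomly initialized
model = PhiForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))
```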
bunny/model/language_model/phi/configuration_phi.py ADDED
@@ -0,0 +1,195 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
27
+ "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
28
+ "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ class PhiConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the Phi
37
+ [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ vocab_size (`int`, *optional*, defaults to 51200):
44
+ Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
45
+ `inputs_ids` passed when calling [`PhiModel`].
46
+ hidden_size (`int`, *optional*, defaults to 2048):
47
+ Dimension of the hidden representations.
48
+ intermediate_size (`int`, *optional*, defaults to 8192):
49
+ Dimension of the MLP representations.
50
+ num_hidden_layers (`int`, *optional*, defaults to 24):
51
+ Number of hidden layers in the Transformer decoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 32):
53
+ Number of attention heads for each attention layer in the Transformer decoder.
54
+ num_key_value_heads (`int`, *optional*):
55
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
58
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
+ by mean-pooling all the original heads within that group. For more details, check out [this
60
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
+ `num_attention_heads`.
62
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
63
+ Dropout probability for mlp outputs.
64
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the embeddings.
66
+ attention_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio after computing the attention scores.
68
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
69
+ The non-linear activation function (function or string) in the decoder.
70
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
71
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 support up to 2048
72
+ tokens.
73
+ initializer_range (`float`, *optional*, defaults to 0.02):
74
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
76
+ The epsilon used by the layer normalization layers.
77
+ use_cache (`bool`, *optional*, defaults to `True`):
78
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
79
+ relevant if `config.is_decoder=True`.
80
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
81
+ Whether to tie weight embeddings
82
+ rope_theta (`float`, *optional*, defaults to 10000.0):
83
+ The base period of the RoPE embeddings.
84
+ rope_scaling (`Dict`, *optional*):
85
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
86
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
87
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
88
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
89
+ these scaling strategies behave:
90
+ https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
91
+ is an experimental feature, subject to breaking API changes in future versions.
92
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
93
+ Percentage of the query and keys which will have rotary embedding.
94
+ qk_layernorm (`bool`, *optional*, defaults to `False`):
95
+ Whether or not to normalize the Queries and Keys after projecting the hidden states.
96
+ bos_token_id (`int`, *optional*, defaults to 1):
97
+ Denotes beginning of sequences token id.
98
+ eos_token_id (`int`, *optional*, defaults to 2):
99
+ Denotes end of sequences token id.
100
+
101
+ Example:
102
+
103
+ ```python
104
+ >>> from transformers import PhiModel, PhiConfig
105
+
106
+ >>> # Initializing a Phi-1 style configuration
107
+ >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
108
+
109
+ >>> # Initializing a model from the configuration
110
+ >>> model = PhiModel(configuration)
111
+
112
+ >>> # Accessing the model configuration
113
+ >>> configuration = model.config
114
+ ```"""
115
+
116
+ model_type = "phi"
117
+ keys_to_ignore_at_inference = ["past_key_values"]
118
+
119
+ def __init__(
120
+ self,
121
+ vocab_size=51200,
122
+ hidden_size=2048,
123
+ intermediate_size=8192,
124
+ num_hidden_layers=24,
125
+ num_attention_heads=32,
126
+ num_key_value_heads=None,
127
+ resid_pdrop=0.0,
128
+ embd_pdrop=0.0,
129
+ attention_dropout=0.0,
130
+ hidden_act="gelu_new",
131
+ max_position_embeddings=2048,
132
+ initializer_range=0.02,
133
+ layer_norm_eps=1e-5,
134
+ use_cache=True,
135
+ tie_word_embeddings=False,
136
+ rope_theta=10000.0,
137
+ rope_scaling=None,
138
+ partial_rotary_factor=0.5,
139
+ qk_layernorm=False,
140
+ bos_token_id=1,
141
+ eos_token_id=2,
142
+ **kwargs,
143
+ ):
144
+ self.vocab_size = vocab_size
145
+ self.hidden_size = hidden_size
146
+ self.intermediate_size = intermediate_size
147
+ self.num_hidden_layers = num_hidden_layers
148
+ self.num_attention_heads = num_attention_heads
149
+
150
+ if num_key_value_heads is None:
151
+ num_key_value_heads = num_attention_heads
152
+
153
+ self.num_key_value_heads = num_key_value_heads
154
+ self.resid_pdrop = resid_pdrop
155
+ self.embd_pdrop = embd_pdrop
156
+ self.attention_dropout = attention_dropout
157
+ self.hidden_act = hidden_act
158
+ self.max_position_embeddings = max_position_embeddings
159
+ self.initializer_range = initializer_range
160
+ self.layer_norm_eps = layer_norm_eps
161
+ self.use_cache = use_cache
162
+ self.rope_theta = rope_theta
163
+ self.rope_scaling = rope_scaling
164
+ self.partial_rotary_factor = partial_rotary_factor
165
+ self.qk_layernorm = qk_layernorm
166
+ self._rope_scaling_validation()
167
+
168
+ super().__init__(
169
+ bos_token_id=bos_token_id,
170
+ eos_token_id=eos_token_id,
171
+ tie_word_embeddings=tie_word_embeddings,
172
+ **kwargs,
173
+ )
174
+
175
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
176
+ def _rope_scaling_validation(self):
177
+ """
178
+ Validate the `rope_scaling` configuration.
179
+ """
180
+ if self.rope_scaling is None:
181
+ return
182
+
183
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
184
+ raise ValueError(
185
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
186
+ f"got {self.rope_scaling}"
187
+ )
188
+ rope_scaling_type = self.rope_scaling.get("type", None)
189
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
190
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
191
+ raise ValueError(
192
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
193
+ )
194
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
195
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
bunny/model/language_model/phi/modeling_phi.py ADDED
@@ -0,0 +1,1374 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi model."""
17
+
18
+
19
+ import math
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ TokenClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_flash_attn_2_available,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_phi import PhiConfig
48
+
49
+
50
+ if is_flash_attn_2_available():
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
58
+ _CONFIG_FOR_DOC = "PhiConfig"
59
+
60
+ PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
61
+ "microsoft/phi-1",
62
+ "microsoft/phi-1_5",
63
+ "microsoft/phi-2",
64
+ # See all Phi models at https://huggingface.co/models?filter=phi
65
+ ]
66
+
67
+
68
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
+ def _get_unpad_data(attention_mask):
70
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
73
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
+ return (
75
+ indices,
76
+ cu_seqlens,
77
+ max_seqlen_in_batch,
78
+ )
79
+
80
+
81
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
82
+ class PhiRotaryEmbedding(nn.Module):
83
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
84
+ super().__init__()
85
+
86
+ self.dim = dim
87
+ self.max_position_embeddings = max_position_embeddings
88
+ self.base = base
89
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
90
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
91
+
92
+ # Build here to make `torch.jit.trace` work.
93
+ self._set_cos_sin_cache(
94
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
95
+ )
96
+
97
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
98
+ self.max_seq_len_cached = seq_len
99
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
100
+
101
+ freqs = torch.outer(t, self.inv_freq)
102
+ # Differs from the paper: the two halves are concatenated rather than interleaved, but the matching permutation in `rotate_half` yields the same computation
103
+ emb = torch.cat((freqs, freqs), dim=-1)
104
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
105
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
106
+
107
+ def forward(self, x, seq_len=None):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ if seq_len > self.max_seq_len_cached:
110
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
111
+
112
+ return (
113
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
114
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
115
+ )
116
+
117
+
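# Illustrative sketch (not part of the committed file); it only uses the class above.
# The rotary cache is a pair of (seq_len, dim) tables: one row of cos/sin values per
# position, already duplicated across the two halves by the `torch.cat` above.
import torch

_rope = PhiRotaryEmbedding(dim=32, max_position_embeddings=2048)
_x = torch.zeros(1, 4, 16, 32)  # [bs, num_attention_heads, seq_len, head_size]
_cos, _sin = _rope(_x, seq_len=16)
# _cos.shape == _sin.shape == (16, 32)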
118
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
119
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
120
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
121
+
122
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
123
+ self.scaling_factor = scaling_factor
124
+ super().__init__(dim, max_position_embeddings, base, device)
125
+
126
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
127
+ self.max_seq_len_cached = seq_len
128
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
129
+ t = t / self.scaling_factor
130
+
131
+ freqs = torch.outer(t, self.inv_freq)
132
+ # Differs from the paper: the two halves are concatenated rather than interleaved, but the matching permutation in `rotate_half` yields the same computation
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
135
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
136
+
137
+
138
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
139
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
140
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
141
+
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
143
+ self.scaling_factor = scaling_factor
144
+ super().__init__(dim, max_position_embeddings, base, device)
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+
149
+ if seq_len > self.max_position_embeddings:
150
+ base = self.base * (
151
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
152
+ ) ** (self.dim / (self.dim - 2))
153
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
154
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
155
+
156
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
157
+
158
+ freqs = torch.outer(t, self.inv_freq)
159
+ # Differs from the paper: the two halves are concatenated rather than interleaved, but the matching permutation in `rotate_half` yields the same computation
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
162
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
+
164
+
165
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
166
+ def rotate_half(x):
167
+ """Rotates half the hidden dims of the input."""
168
+ x1 = x[..., : x.shape[-1] // 2]
169
+ x2 = x[..., x.shape[-1] // 2 :]
170
+ return torch.cat((-x2, x1), dim=-1)
171
+
172
+
173
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
175
+ """Applies Rotary Position Embedding to the query and key tensors.
176
+
177
+ Args:
178
+ q (`torch.Tensor`): The query tensor.
179
+ k (`torch.Tensor`): The key tensor.
180
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
181
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
182
+ position_ids (`torch.Tensor`):
183
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
184
+ used to pass offset position ids when working with a KV-cache.
185
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
186
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
187
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
188
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
189
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
190
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
191
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
192
+ Returns:
193
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
194
+ """
195
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
196
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
197
+ q_embed = (q * cos) + (rotate_half(q) * sin)
198
+ k_embed = (k * cos) + (rotate_half(k) * sin)
199
+ return q_embed, k_embed
200
+
201
+
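# Illustrative sketch (not part of the committed file); it only uses `apply_rotary_pos_emb` above.
# With the default `unsqueeze_dim=1`, the (seq_len, head_dim) rows selected by `position_ids`
# become (bs, 1, seq_len, head_dim) and broadcast across the heads dimension of
# [bs, num_heads, seq_len, head_dim] query/key tensors.
import torch

_bs, _heads, _seq, _hd = 2, 4, 8, 32
_q = torch.randn(_bs, _heads, _seq, _hd)
_k = torch.randn(_bs, _heads, _seq, _hd)
_cos = torch.randn(_seq, _hd)  # shaped like the tables returned by PhiRotaryEmbedding
_sin = torch.randn(_seq, _hd)
_pos = torch.arange(_seq).unsqueeze(0).expand(_bs, -1)
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _cos, _sin, _pos)
# _q_rot.shape == _q.shape and _k_rot.shape == _k.shape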
202
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
203
+ class PhiMLP(nn.Module):
204
+ def __init__(self, config):
205
+ super().__init__()
206
+ self.config = config
207
+ self.activation_fn = ACT2FN[config.hidden_act]
208
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
209
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
210
+
211
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
212
+ hidden_states = self.fc1(hidden_states)
213
+ hidden_states = self.activation_fn(hidden_states)
214
+ hidden_states = self.fc2(hidden_states)
215
+ return hidden_states
216
+
217
+
218
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
219
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
220
+ """
221
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
222
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
223
+ """
224
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
225
+ if n_rep == 1:
226
+ return hidden_states
227
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
228
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
229
+
230
+
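# Illustrative sketch (not part of the committed file); it only uses `repeat_kv` above.
# With grouped-query attention, the `num_key_value_heads` KV heads are tiled `n_rep` times
# so they line up with the `num_attention_heads` query heads.
import torch

_kv = torch.randn(2, 4, 10, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
_expanded = repeat_kv(_kv, n_rep=8)
# _expanded.shape == (2, 32, 10, 64)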
231
+ class PhiAttention(nn.Module):
232
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
233
+
234
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
235
+ super().__init__()
236
+ self.config = config
237
+ self.layer_idx = layer_idx
238
+ if layer_idx is None:
239
+ logger.warning_once(
240
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
241
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
242
+ "when creating this class."
243
+ )
244
+
245
+ self.attention_dropout = config.attention_dropout
246
+ self.hidden_size = config.hidden_size
247
+ self.num_heads = config.num_attention_heads
248
+ self.head_dim = self.hidden_size // self.num_heads
249
+ self.num_key_value_heads = config.num_key_value_heads
250
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
251
+ self.max_position_embeddings = config.max_position_embeddings
252
+ self.rope_theta = config.rope_theta
253
+ self.partial_rotary_factor = config.partial_rotary_factor
254
+ self.is_causal = True
255
+
256
+ if (self.head_dim * self.num_heads) != self.hidden_size:
257
+ raise ValueError(
258
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
259
+ f" and `num_heads`: {self.num_heads})."
260
+ )
261
+
262
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
263
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
264
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
265
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
266
+
267
+ self.qk_layernorm = config.qk_layernorm
268
+ if self.qk_layernorm:
269
+ self.q_layernorm = nn.LayerNorm(
270
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
271
+ )
272
+ self.k_layernorm = nn.LayerNorm(
273
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
274
+ )
275
+
276
+ self._init_rope()
277
+
278
+ def _init_rope(self):
279
+ if self.config.rope_scaling is None:
280
+ self.rotary_emb = PhiRotaryEmbedding(
281
+ int(self.partial_rotary_factor * self.head_dim),
282
+ max_position_embeddings=self.max_position_embeddings,
283
+ base=self.rope_theta,
284
+ )
285
+ else:
286
+ scaling_type = self.config.rope_scaling["type"]
287
+ scaling_factor = self.config.rope_scaling["factor"]
288
+ if scaling_type == "linear":
289
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
290
+ int(self.partial_rotary_factor * self.head_dim),
291
+ max_position_embeddings=self.max_position_embeddings,
292
+ scaling_factor=scaling_factor,
293
+ base=self.rope_theta,
294
+ )
295
+ elif scaling_type == "dynamic":
296
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
297
+ int(self.partial_rotary_factor * self.head_dim),
298
+ max_position_embeddings=self.max_position_embeddings,
299
+ scaling_factor=scaling_factor,
300
+ base=self.rope_theta,
301
+ )
302
+ else:
303
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
304
+
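# Note (illustrative, not part of the committed file): when set, `config.rope_scaling` is
# expected to be a dict such as
#     {"type": "linear", "factor": 2.0}   or   {"type": "dynamic", "factor": 2.0}
# which selects one of the scaled rotary-embedding variants constructed above.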
305
+ def forward(
306
+ self,
307
+ hidden_states: torch.Tensor,
308
+ attention_mask: Optional[torch.Tensor] = None,
309
+ position_ids: Optional[torch.LongTensor] = None,
310
+ past_key_value: Optional[Cache] = None,
311
+ output_attentions: bool = False,
312
+ use_cache: bool = False,
313
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
314
+ bsz, q_len, _ = hidden_states.size()
315
+
316
+ query_states = self.q_proj(hidden_states)
317
+ key_states = self.k_proj(hidden_states)
318
+ value_states = self.v_proj(hidden_states)
319
+
320
+ if self.qk_layernorm:
321
+ query_states = self.q_layernorm(query_states)
322
+ key_states = self.k_layernorm(key_states)
323
+
324
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
325
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
326
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
327
+
328
+ kv_seq_len = key_states.shape[-2]
329
+ if past_key_value is not None:
330
+ if self.layer_idx is None:
331
+ raise ValueError(
332
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
333
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
334
+ "with a layer index."
335
+ )
336
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
337
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
338
+
339
+ # Partial rotary embedding
340
+ query_rot, query_pass = (
341
+ query_states[..., : self.rotary_emb.dim],
342
+ query_states[..., self.rotary_emb.dim :],
343
+ )
344
+ key_rot, key_pass = (
345
+ key_states[..., : self.rotary_emb.dim],
346
+ key_states[..., self.rotary_emb.dim :],
347
+ )
348
+ # [batch_size, num_heads, seq_length, int(head_dim * config.partial_rotary_factor)]
349
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
350
+
351
+ # [batch_size, num_heads, seq_length, head_dim]
352
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
353
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
354
+
355
+ if past_key_value is not None:
356
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
357
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
358
+
359
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
360
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
361
+
362
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
363
+ attn_weights = torch.matmul(
364
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
365
+ ) / math.sqrt(self.head_dim)
366
+
367
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
368
+ raise ValueError(
369
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
370
+ f" {attn_weights.size()}"
371
+ )
372
+
373
+ if attention_mask is not None:
374
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
375
+ raise ValueError(
376
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
377
+ )
378
+ attn_weights = attn_weights + attention_mask
379
+
380
+ # upcast attention to fp32
381
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
382
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
383
+
384
+ attn_output = torch.matmul(attn_weights, value_states)
385
+
386
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
387
+ raise ValueError(
388
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
389
+ f" {attn_output.size()}"
390
+ )
391
+
392
+ attn_output = attn_output.transpose(1, 2).contiguous()
393
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
394
+
395
+ attn_output = self.dense(attn_output)
396
+
397
+ if not output_attentions:
398
+ attn_weights = None
399
+
400
+ return attn_output, attn_weights, past_key_value
401
+
402
+
403
+ class PhiFlashAttention2(PhiAttention):
404
+ """
405
+ Phi flash attention module. This module inherits from `PhiAttention`, as the weights of the module stay
406
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
407
+ flash attention and deal with padding tokens in case the input contains any.
408
+ """
409
+
410
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
411
+ def __init__(self, *args, **kwargs):
412
+ super().__init__(*args, **kwargs)
413
+
414
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
415
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
416
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
417
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
418
+
419
+ def forward(
420
+ self,
421
+ hidden_states: torch.Tensor,
422
+ attention_mask: Optional[torch.LongTensor] = None,
423
+ position_ids: Optional[torch.LongTensor] = None,
424
+ past_key_value: Optional[Cache] = None,
425
+ output_attentions: bool = False,
426
+ use_cache: bool = False,
427
+ **kwargs,
428
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
429
+ # PhiFlashAttention2 attention does not support output_attentions
430
+
431
+ output_attentions = False
432
+
433
+ bsz, q_len, _ = hidden_states.size()
434
+
435
+ query_states = self.q_proj(hidden_states)
436
+ key_states = self.k_proj(hidden_states)
437
+ value_states = self.v_proj(hidden_states)
438
+
439
+ if self.qk_layernorm:
440
+ query_states = self.q_layernorm(query_states)
441
+ key_states = self.k_layernorm(key_states)
442
+
443
+ # Flash attention requires the input to have the shape
444
+ # batch_size x seq_length x num_heads x head_dim;
445
+ # the states are reshaped here and transposed back into that layout further below
446
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
447
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
448
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
449
+
450
+ kv_seq_len = key_states.shape[-2]
451
+ if past_key_value is not None:
452
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
453
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
454
+
455
+ # Partial rotary embedding
456
+ query_rot, query_pass = (
457
+ query_states[..., : self.rotary_emb.dim],
458
+ query_states[..., self.rotary_emb.dim :],
459
+ )
460
+ key_rot, key_pass = (
461
+ key_states[..., : self.rotary_emb.dim],
462
+ key_states[..., self.rotary_emb.dim :],
463
+ )
464
+ # [batch_size, num_heads, seq_length, int(head_dim * config.partial_rotary_factor)]
465
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
466
+
467
+ # [batch_size, num_heads, seq_length, head_dim]
468
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
469
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
470
+
471
+ if past_key_value is not None:
472
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
473
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
474
+
475
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
476
+ # to be able to avoid many of these transpose/reshape/view.
477
+ query_states = query_states.transpose(1, 2)
478
+ key_states = key_states.transpose(1, 2)
479
+ value_states = value_states.transpose(1, 2)
480
+
481
+ attn_dropout = self.attention_dropout if self.training else 0.0
482
+
483
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
484
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
485
+ # cast them back in the correct dtype just to be sure everything works as expected.
486
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
487
+ # in fp32.
488
+
489
+ if query_states.dtype == torch.float32:
490
+ if torch.is_autocast_enabled():
491
+ target_dtype = torch.get_autocast_gpu_dtype()
492
+ # Handle the case where the model is quantized
493
+ elif hasattr(self.config, "_pre_quantization_dtype"):
494
+ target_dtype = self.config._pre_quantization_dtype
495
+ else:
496
+ target_dtype = self.q_proj.weight.dtype
497
+
498
+ logger.warning_once(
499
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
500
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
501
+ f" {target_dtype}."
502
+ )
503
+
504
+ query_states = query_states.to(target_dtype)
505
+ key_states = key_states.to(target_dtype)
506
+ value_states = value_states.to(target_dtype)
507
+
508
+ attn_output = self._flash_attention_forward(
509
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
510
+ )
511
+
512
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
513
+ attn_output = self.dense(attn_output)
514
+
515
+ if not output_attentions:
516
+ attn_weights = None
517
+
518
+ return attn_output, attn_weights, past_key_value
519
+
520
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
521
+ def _flash_attention_forward(
522
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
523
+ ):
524
+ """
525
+ Calls the forward method of Flash Attention: if the input hidden states contain at least one padding token,
526
+ it first unpads the input, then computes the attention scores, and finally pads the attention scores back.
527
+
528
+ Args:
529
+ query_states (`torch.Tensor`):
530
+ Input query states to be passed to Flash Attention API
531
+ key_states (`torch.Tensor`):
532
+ Input key states to be passed to Flash Attention API
533
+ value_states (`torch.Tensor`):
534
+ Input value states to be passed to Flash Attention API
535
+ attention_mask (`torch.Tensor`):
536
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
537
+ position of padding tokens and 1 for the position of non-padding tokens.
538
+ dropout (`float`, *optional*):
539
+ Attention dropout
540
+ softmax_scale (`float`, *optional*):
541
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
542
+ """
543
+ if not self._flash_attn_uses_top_left_mask:
544
+ causal = self.is_causal
545
+ else:
546
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
547
+ causal = self.is_causal and query_length != 1
548
+
549
+ # Contains at least one padding token in the sequence
550
+ if attention_mask is not None:
551
+ batch_size = query_states.shape[0]
552
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
553
+ query_states, key_states, value_states, attention_mask, query_length
554
+ )
555
+
556
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
557
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
558
+
559
+ attn_output_unpad = flash_attn_varlen_func(
560
+ query_states,
561
+ key_states,
562
+ value_states,
563
+ cu_seqlens_q=cu_seqlens_q,
564
+ cu_seqlens_k=cu_seqlens_k,
565
+ max_seqlen_q=max_seqlen_in_batch_q,
566
+ max_seqlen_k=max_seqlen_in_batch_k,
567
+ dropout_p=dropout,
568
+ softmax_scale=softmax_scale,
569
+ causal=causal,
570
+ )
571
+
572
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
573
+ else:
574
+ attn_output = flash_attn_func(
575
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
576
+ )
577
+
578
+ return attn_output
579
+
580
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
581
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
582
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
583
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
584
+
585
+ key_layer = index_first_axis(
586
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
587
+ )
588
+ value_layer = index_first_axis(
589
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
590
+ )
591
+ if query_length == kv_seq_len:
592
+ query_layer = index_first_axis(
593
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
594
+ )
595
+ cu_seqlens_q = cu_seqlens_k
596
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
597
+ indices_q = indices_k
598
+ elif query_length == 1:
599
+ max_seqlen_in_batch_q = 1
600
+ cu_seqlens_q = torch.arange(
601
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
602
+ ) # There is a memcpy here, that is very bad.
603
+ indices_q = cu_seqlens_q[:-1]
604
+ query_layer = query_layer.squeeze(1)
605
+ else:
606
+ # The -q_len: slice assumes left padding.
607
+ attention_mask = attention_mask[:, -query_length:]
608
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
609
+
610
+ return (
611
+ query_layer,
612
+ key_layer,
613
+ value_layer,
614
+ indices_q,
615
+ (cu_seqlens_q, cu_seqlens_k),
616
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
617
+ )
618
+
619
+
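# Note (illustrative, not part of the committed file): the padded flash-attention path above
# reduces to unpad -> flash_attn_varlen_func -> re-pad, assuming flash-attn is installed:
#   entering `_flash_attention_forward`: padded (batch, seq_len, heads, head_dim) tensors
#   after `_upad_input`:                 (total_real_tokens, heads, head_dim) tensors, plus
#                                        cu_seqlens / max_seqlen bookkeeping per sequence
#   `flash_attn_varlen_func(...)`:       (total_query_tokens, num_heads, head_dim)
#   `pad_input(...)`:                    back to (batch, query_length, num_heads, head_dim)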
620
+ PHI_ATTENTION_CLASSES = {
621
+ "eager": PhiAttention,
622
+ "flash_attention_2": PhiFlashAttention2,
623
+ }
624
+
625
+
626
+ class PhiDecoderLayer(nn.Module):
627
+ def __init__(self, config: PhiConfig, layer_idx: int):
628
+ super().__init__()
629
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
630
+ self.mlp = PhiMLP(config)
631
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
632
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
633
+
634
+ def forward(
635
+ self,
636
+ hidden_states: torch.Tensor,
637
+ attention_mask: Optional[torch.Tensor] = None,
638
+ position_ids: Optional[torch.LongTensor] = None,
639
+ output_attentions: Optional[bool] = False,
640
+ use_cache: Optional[bool] = False,
641
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
642
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
643
+ """
644
+ Args:
645
+ hidden_states (`torch.FloatTensor`):
646
+ input to the layer of shape `(batch, seq_len, embed_dim)`
647
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
648
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
649
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
650
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
651
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
652
+ output_attentions (`bool`, *optional*):
653
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
654
+ returned tensors for more detail.
655
+ use_cache (`bool`, *optional*):
656
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
657
+ (see `past_key_values`).
658
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
659
+ """
660
+
661
+ residual = hidden_states
662
+
663
+ hidden_states = self.input_layernorm(hidden_states)
664
+
665
+ # Self Attention
666
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
667
+ hidden_states=hidden_states,
668
+ attention_mask=attention_mask,
669
+ position_ids=position_ids,
670
+ past_key_value=past_key_value,
671
+ output_attentions=output_attentions,
672
+ use_cache=use_cache,
673
+ )
674
+ attn_outputs = self.resid_dropout(attn_outputs)
675
+
676
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
677
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
678
+ outputs = (hidden_states,)
679
+
680
+ if output_attentions:
681
+ outputs += (self_attn_weights,)
682
+
683
+ if use_cache:
684
+ outputs += (present_key_value,)
685
+
686
+ return outputs
687
+
688
+
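# Note (illustrative, not part of the committed file): unlike the sequential pre-norm block
# used in Llama, `PhiDecoderLayer` is a "parallel" block: self-attention and the MLP both
# read the same layer-normed input, and their dropped-out outputs are added to the residual
# in a single step:
#   h_out = residual + Dropout(Attn(LN(h))) + Dropout(MLP(LN(h)))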
689
+ PHI_START_DOCSTRING = r"""
690
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
691
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
692
+ etc.)
693
+
694
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
695
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
696
+ and behavior.
697
+
698
+ Parameters:
699
+ config ([`PhiConfig`]):
700
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
701
+ load the weights associated with the model, only the configuration. Check out the
702
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
703
+ """
704
+
705
+
706
+ @add_start_docstrings(
707
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
708
+ PHI_START_DOCSTRING,
709
+ )
710
+ class PhiPreTrainedModel(PreTrainedModel):
711
+ config_class = PhiConfig
712
+ base_model_prefix = "model"
713
+ supports_gradient_checkpointing = True
714
+ _no_split_modules = ["PhiDecoderLayer"]
715
+ _skip_keys_device_placement = "past_key_values"
716
+ _supports_flash_attn_2 = True
717
+ _supports_cache_class = True
718
+
719
+ def _init_weights(self, module):
720
+ std = self.config.initializer_range
721
+ if isinstance(module, nn.Linear):
722
+ module.weight.data.normal_(mean=0.0, std=std)
723
+ if module.bias is not None:
724
+ module.bias.data.zero_()
725
+ elif isinstance(module, nn.Embedding):
726
+ module.weight.data.normal_(mean=0.0, std=std)
727
+ if module.padding_idx is not None:
728
+ module.weight.data[module.padding_idx].zero_()
729
+
730
+
731
+ PHI_INPUTS_DOCSTRING = r"""
732
+ Args:
733
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
734
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
735
+ it.
736
+
737
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
738
+ [`PreTrainedTokenizer.__call__`] for details.
739
+
740
+ [What are input IDs?](../glossary#input-ids)
741
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
742
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
743
+
744
+ - 1 for tokens that are **not masked**,
745
+ - 0 for tokens that are **masked**.
746
+
747
+ [What are attention masks?](../glossary#attention-mask)
748
+
749
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
750
+ [`PreTrainedTokenizer.__call__`] for details.
751
+
752
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
753
+ `past_key_values`).
754
+
755
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
756
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
757
+ information on the default strategy.
758
+
759
+ - 1 indicates the head is **not masked**,
760
+ - 0 indicates the head is **masked**.
761
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
762
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
763
+ config.n_positions - 1]`.
764
+
765
+ [What are position IDs?](../glossary#position-ids)
766
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
767
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
768
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
769
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
770
+
771
+ Two formats are allowed:
772
+ - a [`~cache_utils.Cache`] instance;
773
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
774
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
775
+ cache format.
776
+
777
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
778
+ legacy cache format will be returned.
779
+
780
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
781
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
782
+ of shape `(batch_size, sequence_length)`.
783
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
784
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
785
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
786
+ model's internal embedding lookup matrix.
787
+ use_cache (`bool`, *optional*):
788
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
789
+ `past_key_values`).
790
+ output_attentions (`bool`, *optional*):
791
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
792
+ tensors for more detail.
793
+ output_hidden_states (`bool`, *optional*):
794
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
795
+ more detail.
796
+ return_dict (`bool`, *optional*):
797
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
798
+ """
799
+
800
+
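# Illustrative sketch (not part of the committed file); `model` is a placeholder for a
# causal-LM wrapper such as `PhiForCausalLM` below, and `sample` is a hypothetical helper.
# Typical incremental decoding with the cache described above: run the full prompt once,
# then feed only the newly generated token together with the returned `past_key_values`.
#   out = model(input_ids=prompt_ids, use_cache=True)
#   past = out.past_key_values              # `Cache` instance or legacy tuple format
#   next_ids = sample(out.logits[:, -1:])   # pick the next token from the last-step logits
#   out = model(input_ids=next_ids, past_key_values=past, use_cache=True)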
801
+ @add_start_docstrings(
802
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
803
+ PHI_START_DOCSTRING,
804
+ )
805
+ class PhiModel(PhiPreTrainedModel):
806
+ """
807
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
808
+
809
+ Args:
810
+ config: PhiConfig
811
+ """
812
+
813
+ def __init__(self, config: PhiConfig):
814
+ super().__init__(config)
815
+ self.padding_idx = config.pad_token_id
816
+ self.vocab_size = config.vocab_size
817
+
818
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
819
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
820
+ self.layers = nn.ModuleList(
821
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
822
+ )
823
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
824
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
825
+
826
+ self.gradient_checkpointing = False
827
+ # Initialize weights and apply final processing
828
+ self.post_init()
829
+
830
+ def get_input_embeddings(self):
831
+ return self.embed_tokens
832
+
833
+ def set_input_embeddings(self, value):
834
+ self.embed_tokens = value
835
+
836
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
837
+ def forward(
838
+ self,
839
+ input_ids: torch.LongTensor = None,
840
+ attention_mask: Optional[torch.Tensor] = None,
841
+ position_ids: Optional[torch.LongTensor] = None,
842
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
843
+ inputs_embeds: Optional[torch.FloatTensor] = None,
844
+ use_cache: Optional[bool] = None,
845
+ output_attentions: Optional[bool] = None,
846
+ output_hidden_states: Optional[bool] = None,
847
+ return_dict: Optional[bool] = None,
848
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
849
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
850
+ output_hidden_states = (
851
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
852
+ )
853
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
854
+
855
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
856
+
857
+ # retrieve input_ids and inputs_embeds
858
+ if input_ids is not None and inputs_embeds is not None:
859
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
860
+ elif input_ids is not None:
861
+ batch_size, seq_length = input_ids.shape[:2]
862
+ elif inputs_embeds is not None:
863
+ batch_size, seq_length = inputs_embeds.shape[:2]
864
+ else:
865
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
866
+
867
+ past_key_values_length = 0
868
+
869
+ if self.gradient_checkpointing and self.training:
870
+ if use_cache:
871
+ logger.warning_once(
872
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
873
+ )
874
+ use_cache = False
875
+
876
+ if use_cache:
877
+ use_legacy_cache = not isinstance(past_key_values, Cache)
878
+ if use_legacy_cache:
879
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
880
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
881
+
882
+ if position_ids is None:
883
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
884
+ position_ids = torch.arange(
885
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
886
+ )
887
+ position_ids = position_ids.unsqueeze(0)
888
+
889
+ if inputs_embeds is None:
890
+ inputs_embeds = self.embed_tokens(input_ids)
891
+
892
+ inputs_embeds = self.embed_dropout(inputs_embeds)
893
+
894
+ # Attention mask.
895
+ if self._use_flash_attention_2:
896
+ # 2d mask is passed through the layers
897
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
898
+ else:
899
+ # 4d mask is passed through the layers
900
+ attention_mask = _prepare_4d_causal_attention_mask(
901
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
902
+ )
903
+
904
+ hidden_states = inputs_embeds
905
+
906
+ # decoder layers
907
+ all_hidden_states = () if output_hidden_states else None
908
+ all_self_attns = () if output_attentions else None
909
+ next_decoder_cache = None
910
+
911
+ for decoder_layer in self.layers:
912
+ if output_hidden_states:
913
+ all_hidden_states += (hidden_states,)
914
+
915
+ if self.gradient_checkpointing and self.training:
916
+ layer_outputs = self._gradient_checkpointing_func(
917
+ decoder_layer.__call__,
918
+ hidden_states,
919
+ attention_mask,
920
+ position_ids,
921
+ past_key_values,
922
+ output_attentions,
923
+ )
924
+ else:
925
+ layer_outputs = decoder_layer(
926
+ hidden_states,
927
+ attention_mask=attention_mask,
928
+ position_ids=position_ids,
929
+ past_key_value=past_key_values,
930
+ output_attentions=output_attentions,
931
+ use_cache=use_cache,
932
+ )
933
+
934
+ hidden_states = layer_outputs[0]
935
+
936
+ if use_cache:
937
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
938
+
939
+ if output_attentions:
940
+ all_self_attns += (layer_outputs[1],)
941
+
942
+ hidden_states = self.final_layernorm(hidden_states)
943
+
944
+ # add hidden states from the last decoder layer
945
+ if output_hidden_states:
946
+ all_hidden_states += (hidden_states,)
947
+
948
+ next_cache = None
949
+ if use_cache:
950
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
951
+ if not return_dict:
952
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
953
+ return BaseModelOutputWithPast(
954
+ last_hidden_state=hidden_states,
955
+ past_key_values=next_cache,
956
+ hidden_states=all_hidden_states,
957
+ attentions=all_self_attns,
958
+ )
959
+
960
+
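# Note (illustrative, not part of the committed file): mask handling in `PhiModel.forward`,
# in brief. Under flash_attention_2 the raw 2D (batch, seq_len) padding mask is passed
# through unchanged (or dropped when it contains no zeros), while the eager path expands it
# into an additive 4D causal mask of shape (batch, 1, seq_len, kv_len) via
# `_prepare_4d_causal_attention_mask`.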
961
+ class PhiForCausalLM(PhiPreTrainedModel):
962
+ _tied_weights_keys = ["lm_head.weight"]
963
+
964
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
965
+ def __init__(self, config):
966
+ super().__init__(config)
967
+ self.model = PhiModel(config)
968
+ self.vocab_size = config.vocab_size
969
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
970
+
971
+ # Initialize weights and apply final processing
972
+ self.post_init()
973
+
974
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
975
+ def get_input_embeddings(self):
976
+ return self.model.embed_tokens
977
+
978
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
979
+ def set_input_embeddings(self, value):
980
+ self.model.embed_tokens = value
981
+
982
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
983
+ def get_output_embeddings(self):
984
+ return self.lm_head
985
+
986
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
987
+ def set_output_embeddings(self, new_embeddings):
988
+ self.lm_head = new_embeddings
989
+
990
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
991
+ def set_decoder(self, decoder):
992
+ self.model = decoder
993
+
994
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
995
+ def get_decoder(self):
996
+ return self.model
997
+
998
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
999
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1000
+ def forward(
1001
+ self,
1002
+ input_ids: torch.LongTensor = None,
1003
+ attention_mask: Optional[torch.Tensor] = None,
1004
+ position_ids: Optional[torch.LongTensor] = None,
1005
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1006
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1007
+ labels: Optional[torch.LongTensor] = None,
1008
+ use_cache: Optional[bool] = None,
1009
+ output_attentions: Optional[bool] = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ return_dict: Optional[bool] = None,
1012
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1013
+ r"""
1014
+ Args:
1015
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1016
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1017
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1018
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1019
+
1020
+ Returns:
1021
+
1022
+ Example:
1023
+
1024
+ ```python
1025
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1026
+
1027
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1028
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1029
+
1030
+ >>> prompt = "This is an example script ."
1031
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1032
+
1033
+ >>> # Generate
1034
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1035
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1036
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1037
+ ```"""
1038
+
1039
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1040
+ output_hidden_states = (
1041
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1042
+ )
1043
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1044
+
1045
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1046
+ outputs = self.model(
1047
+ input_ids=input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ past_key_values=past_key_values,
1051
+ inputs_embeds=inputs_embeds,
1052
+ use_cache=use_cache,
1053
+ output_attentions=output_attentions,
1054
+ output_hidden_states=output_hidden_states,
1055
+ return_dict=return_dict,
1056
+ )
1057
+
1058
+ hidden_states = outputs[0]
1059
+ logits = self.lm_head(hidden_states)
1060
+ logits = logits.float()
1061
+
1062
+ loss = None
1063
+ if labels is not None:
1064
+ # Shift so that tokens < n predict n
1065
+ shift_logits = logits[..., :-1, :].contiguous()
1066
+ shift_labels = labels[..., 1:].contiguous()
1067
+ # Flatten the tokens
1068
+ loss_fct = CrossEntropyLoss()
1069
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1070
+ shift_labels = shift_labels.view(-1)
1071
+ # Enable model parallelism
1072
+ shift_labels = shift_labels.to(shift_logits.device)
1073
+ loss = loss_fct(shift_logits, shift_labels)
1074
+
1075
+ if not return_dict:
1076
+ output = (logits,) + outputs[1:]
1077
+ return (loss,) + output if loss is not None else output
1078
+
1079
+ return CausalLMOutputWithPast(
1080
+ loss=loss,
1081
+ logits=logits,
1082
+ past_key_values=outputs.past_key_values,
1083
+ hidden_states=outputs.hidden_states,
1084
+ attentions=outputs.attentions,
1085
+ )
1086
+
1087
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1088
+ def prepare_inputs_for_generation(
1089
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1090
+ ):
1091
+ if past_key_values is not None:
1092
+ if isinstance(past_key_values, Cache):
1093
+ cache_length = past_key_values.get_seq_length()
1094
+ past_length = past_key_values.seen_tokens
1095
+ max_cache_length = past_key_values.get_max_length()
1096
+ else:
1097
+ cache_length = past_length = past_key_values[0][0].shape[2]
1098
+ max_cache_length = None
1099
+
1100
+ # Keep only the unprocessed tokens:
1101
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1102
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1103
+ # input)
1104
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1105
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1106
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1107
+ # input_ids based on the past_length.
1108
+ elif past_length < input_ids.shape[1]:
1109
+ input_ids = input_ids[:, past_length:]
1110
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1111
+ else:
1112
+ remove_prefix_length = input_ids.shape[1] - 1
1113
+ input_ids = input_ids[:, remove_prefix_length:]
1114
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1115
+ if (
1116
+ max_cache_length is not None
1117
+ and attention_mask is not None
1118
+ and cache_length + input_ids.shape[1] > max_cache_length
1119
+ ):
1120
+ attention_mask = attention_mask[:, -max_cache_length:]
1121
+
1122
+ position_ids = kwargs.get("position_ids", None)
1123
+ if attention_mask is not None and position_ids is None:
1124
+ # create position_ids on the fly for batch generation
1125
+ position_ids = attention_mask.long().cumsum(-1) - 1
1126
+ position_ids.masked_fill_(attention_mask == 0, 1)
1127
+ if past_key_values:
1128
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1129
+
1130
+ if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
1131
+ # generation with static cache
1132
+ seen_tokens = past_key_value.get_seq_length()
1133
+ input_ids = input_ids[:, seen_tokens:]
1134
+ position_ids = position_ids[:, seen_tokens:]
1135
+
1136
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1137
+ if inputs_embeds is not None and past_key_values is None:
1138
+ model_inputs = {"inputs_embeds": inputs_embeds}
1139
+ else:
1140
+ model_inputs = {"input_ids": input_ids}
1141
+
1142
+ model_inputs.update(
1143
+ {
1144
+ "position_ids": position_ids,
1145
+ "past_key_values": past_key_values,
1146
+ "use_cache": kwargs.get("use_cache"),
1147
+ "attention_mask": attention_mask,
1148
+ }
1149
+ )
1150
+ return model_inputs
1151
+
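# Note (illustrative, not part of the committed file): example of the trimming above. With a
# cache already holding 10 tokens and `input_ids` of shape (1, 11) covering the prompt plus
# one new token, the `past_length < input_ids.shape[1]` branch applies and only
# `input_ids[:, 10:]` is forwarded on the next step.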
1152
+ @staticmethod
1153
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1154
+ def _reorder_cache(past_key_values, beam_idx):
1155
+ reordered_past = ()
1156
+ for layer_past in past_key_values:
1157
+ reordered_past += (
1158
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1159
+ )
1160
+ return reordered_past
1161
+
1162
+
1163
+ @add_start_docstrings(
1164
+ """
1165
+ The PhiModel with a sequence classification head on top (linear layer).
1166
+
1167
+ [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1168
+ (e.g. GPT-2) do.
1169
+
1170
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1171
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1172
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1173
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1174
+ each row of the batch).
1175
+ """,
1176
+ PHI_START_DOCSTRING,
1177
+ )
1178
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
1179
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1180
+ def __init__(self, config):
1181
+ super().__init__(config)
1182
+ self.num_labels = config.num_labels
1183
+ self.model = PhiModel(config)
1184
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1185
+
1186
+ # Initialize weights and apply final processing
1187
+ self.post_init()
1188
+
1189
+ def get_input_embeddings(self):
1190
+ return self.model.embed_tokens
1191
+
1192
+ def set_input_embeddings(self, value):
1193
+ self.model.embed_tokens = value
1194
+
1195
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1196
+ def forward(
1197
+ self,
1198
+ input_ids: torch.LongTensor = None,
1199
+ attention_mask: Optional[torch.Tensor] = None,
1200
+ position_ids: Optional[torch.LongTensor] = None,
1201
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1202
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1203
+ labels: Optional[torch.LongTensor] = None,
1204
+ use_cache: Optional[bool] = None,
1205
+ output_attentions: Optional[bool] = None,
1206
+ output_hidden_states: Optional[bool] = None,
1207
+ return_dict: Optional[bool] = None,
1208
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1209
+ r"""
1210
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1211
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1212
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1213
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1214
+ """
1215
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1216
+
1217
+ model_outputs = self.model(
1218
+ input_ids,
1219
+ attention_mask=attention_mask,
1220
+ position_ids=position_ids,
1221
+ past_key_values=past_key_values,
1222
+ inputs_embeds=inputs_embeds,
1223
+ use_cache=use_cache,
1224
+ output_attentions=output_attentions,
1225
+ output_hidden_states=output_hidden_states,
1226
+ return_dict=return_dict,
1227
+ )
1228
+ hidden_states = model_outputs[0]
1229
+ logits = self.score(hidden_states)
1230
+
1231
+ if input_ids is not None:
1232
+ batch_size = input_ids.shape[0]
1233
+ else:
1234
+ batch_size = inputs_embeds.shape[0]
1235
+
1236
+ if self.config.pad_token_id is None and batch_size != 1:
1237
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1238
+ if self.config.pad_token_id is None:
1239
+ sequence_lengths = -1
1240
+ else:
1241
+ if input_ids is not None:
1242
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1243
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1244
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1245
+ sequence_lengths = sequence_lengths.to(logits.device)
1246
+ else:
1247
+ sequence_lengths = -1
1248
+
1249
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1250
+
1251
+ loss = None
1252
+ if labels is not None:
1253
+ labels = labels.to(logits.device)
1254
+ if self.config.problem_type is None:
1255
+ if self.num_labels == 1:
1256
+ self.config.problem_type = "regression"
1257
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1258
+ self.config.problem_type = "single_label_classification"
1259
+ else:
1260
+ self.config.problem_type = "multi_label_classification"
1261
+
1262
+ if self.config.problem_type == "regression":
1263
+ loss_fct = MSELoss()
1264
+ if self.num_labels == 1:
1265
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1266
+ else:
1267
+ loss = loss_fct(pooled_logits, labels)
1268
+ elif self.config.problem_type == "single_label_classification":
1269
+ loss_fct = CrossEntropyLoss()
1270
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1271
+ elif self.config.problem_type == "multi_label_classification":
1272
+ loss_fct = BCEWithLogitsLoss()
1273
+ loss = loss_fct(pooled_logits, labels)
1274
+ if not return_dict:
1275
+ output = (pooled_logits,) + model_outputs[1:]
1276
+ return ((loss,) + output) if loss is not None else output
1277
+
1278
+ return SequenceClassifierOutputWithPast(
1279
+ loss=loss,
1280
+ logits=pooled_logits,
1281
+ past_key_values=model_outputs.past_key_values,
1282
+ hidden_states=model_outputs.hidden_states,
1283
+ attentions=model_outputs.attentions,
1284
+ )
1285
+
1286
+
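# Illustrative sketch (not part of the committed file): how the pooling above locates the
# last real token when `pad_token_id` is set.
import torch

_pad_token_id = 0
_input_ids = torch.tensor([[5, 6, 7, 0, 0],
                           [5, 6, 7, 8, 9]])
_lengths = torch.eq(_input_ids, _pad_token_id).int().argmax(-1) - 1
_lengths = _lengths % _input_ids.shape[-1]
# _lengths -> tensor([2, 4]): index of the last non-padding token in each row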
1287
+ @add_start_docstrings(
1288
+ """
1289
+ PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1290
+ Named-Entity-Recognition (NER) tasks.
1291
+ """,
1292
+ PHI_START_DOCSTRING,
1293
+ )
1294
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
1295
+ class PhiForTokenClassification(PhiPreTrainedModel):
1296
+ def __init__(self, config: PhiConfig):
1297
+ super().__init__(config)
1298
+ self.num_labels = config.num_labels
1299
+
1300
+ self.model = PhiModel(config)
1301
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1302
+ classifier_dropout = config.classifier_dropout
1303
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1304
+ classifier_dropout = config.hidden_dropout
1305
+ else:
1306
+ classifier_dropout = 0.1
1307
+ self.dropout = nn.Dropout(classifier_dropout)
1308
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1309
+
1310
+ # Initialize weights and apply final processing
1311
+ self.post_init()
1312
+
1313
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1314
+ @add_code_sample_docstrings(
1315
+ checkpoint=_CHECKPOINT_FOR_DOC,
1316
+ output_type=TokenClassifierOutput,
1317
+ config_class=_CONFIG_FOR_DOC,
1318
+ )
1319
+ def forward(
1320
+ self,
1321
+ input_ids: Optional[torch.LongTensor] = None,
1322
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1323
+ attention_mask: Optional[torch.Tensor] = None,
1324
+ inputs_embeds: Optional[torch.Tensor] = None,
1325
+ labels: Optional[torch.Tensor] = None,
1326
+ use_cache: Optional[bool] = None,
1327
+ output_attentions: Optional[bool] = None,
1328
+ output_hidden_states: Optional[bool] = None,
1329
+ return_dict: Optional[bool] = None,
1330
+ **deprecated_arguments,
1331
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1332
+ r"""
1333
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1334
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1335
+ config.num_labels - 1]`. Tokens with indices set to `-100` are ignored (masked) by the
1336
+ cross-entropy loss.
1337
+ """
1338
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1339
+
1340
+ model_outputs = self.model(
1341
+ input_ids,
1342
+ past_key_values=past_key_values,
1343
+ attention_mask=attention_mask,
1344
+ inputs_embeds=inputs_embeds,
1345
+ use_cache=use_cache,
1346
+ output_attentions=output_attentions,
1347
+ output_hidden_states=output_hidden_states,
1348
+ return_dict=return_dict,
1349
+ )
1350
+
1351
+ hidden_states = model_outputs[0]
1352
+ hidden_states = self.dropout(hidden_states)
1353
+ logits = self.classifier(hidden_states)
1354
+
1355
+ loss = None
1356
+ if labels is not None:
1357
+ # move labels to correct device to enable model parallelism
1358
+ labels = labels.to(logits.device)
1359
+ batch_size, seq_length = labels.shape
1360
+ loss_fct = CrossEntropyLoss()
1361
+ loss = loss_fct(
1362
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1363
+ )
1364
+
1365
+ if not return_dict:
1366
+ output = (logits,) + model_outputs[2:]
1367
+ return ((loss,) + output) if loss is not None else output
1368
+
1369
+ return TokenClassifierOutput(
1370
+ loss=loss,
1371
+ logits=logits,
1372
+ hidden_states=model_outputs.hidden_states,
1373
+ attentions=model_outputs.attentions,
1374
+ )
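A note on the loss above: `PhiForTokenClassification` flattens the per-token logits and labels before handing them to `CrossEntropyLoss`. A minimal, self-contained sketch of that reshaping (toy shapes, plain PyTorch, no Phi weights involved):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, num_labels = 2, 5, 3

# Stand-ins for dropout(classifier(hidden_states)) and the gold per-token labels.
logits = torch.randn(batch_size, seq_length, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_length))

# Same flattening as in PhiForTokenClassification.forward: every token becomes one
# classification example of shape (num_labels,).
loss_fct = CrossEntropyLoss()
loss = loss_fct(
    logits.view(batch_size * seq_length, num_labels),
    labels.view(batch_size * seq_length),
)
print(loss.shape)  # a 0-dim (scalar) tensor
```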
bunny/model/language_model/phi3/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ from transformers.utils import (
19
+ OptionalDependencyNotAvailable,
20
+ _LazyModule,
21
+ is_sentencepiece_available,
22
+ is_tokenizers_available,
23
+ is_torch_available,
24
+ )
25
+
26
+
27
+ _import_structure = {
28
+ "configuration_phi3": ["PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP", "Phi3Config"],
29
+ }
30
+
31
+ try:
32
+ if not is_torch_available():
33
+ raise OptionalDependencyNotAvailable()
34
+ except OptionalDependencyNotAvailable:
35
+ pass
36
+ else:
37
+ _import_structure["modeling_phi3"] = [
38
+ "PHI3_PRETRAINED_MODEL_ARCHIVE_LIST",
39
+ "Phi3PreTrainedModel",
40
+ "Phi3Model",
41
+ "Phi3ForCausalLM",
42
+ "Phi3ForSequenceClassification",
43
+ "Phi3ForTokenClassification",
44
+ ]
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from .configuration_phi3 import PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP, Phi3Config
49
+
50
+ try:
51
+ if not is_torch_available():
52
+ raise OptionalDependencyNotAvailable()
53
+ except OptionalDependencyNotAvailable:
54
+ pass
55
+ else:
56
+ from .modeling_phi3 import (
57
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST,
58
+ Phi3ForCausalLM,
59
+ Phi3ForSequenceClassification,
60
+ Phi3ForTokenClassification,
61
+ Phi3Model,
62
+ Phi3PreTrainedModel,
63
+ )
64
+
65
+
66
+ else:
67
+ import sys
68
+
69
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
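The `__init__.py` above relies on transformers' `_LazyModule`, so `configuration_phi3` and `modeling_phi3` are only imported when one of the exported names is actually touched. A hedged usage sketch (the import path assumes the Bunny repository layout added in this diff):

```python
# Nothing heavy is imported yet: the package module is a _LazyModule proxy.
from bunny.model.language_model import phi3

# Accessing an exported attribute triggers the real import of configuration_phi3.
config = phi3.Phi3Config(num_hidden_layers=2)

# modeling_phi3 (and therefore torch) is only pulled in if is_torch_available() was True.
model_cls = phi3.Phi3ForCausalLM
```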
bunny/model/language_model/phi3/configuration_phi3.py ADDED
@@ -0,0 +1,213 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi-3 model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
27
+ "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class Phi3Config(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
34
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
+ defaults will yield a similar configuration to that of the
36
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32064):
43
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`Phi3Model`].
45
+ hidden_size (`int`, *optional*, defaults to 3072):
46
+ Dimension of the hidden representations.
47
+ intermediate_size (`int`, *optional*, defaults to 8192):
48
+ Dimension of the MLP representations.
49
+ num_hidden_layers (`int`, *optional*, defaults to 32):
50
+ Number of hidden layers in the Transformer decoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 32):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ num_key_value_heads (`int`, *optional*):
54
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
57
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
+ by mean-pooling all the original heads within that group. For more details check out [this
59
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
+ `num_attention_heads`.
61
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
62
+ Dropout probability for mlp outputs.
63
+ embd_pdrop (`float`, *optional*, defaults to 0.0):
64
+ The dropout ratio for the embeddings.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio after computing the attention scores.
67
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
68
+ The non-linear activation function (function or string) in the decoder.
69
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
70
+ The maximum sequence length that this model might ever be used with.
71
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
72
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
73
+ original RoPE embeddings when using long scaling.
74
+ initializer_range (`float`, *optional*, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
77
+ The epsilon value used for the RMSNorm.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
80
+ relevant if `config.is_decoder=True`.
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether or not to tie the input and output word embeddings.
83
+ rope_theta (`float`, *optional*, defaults to 10000.0):
84
+ The base period of the RoPE embeddings.
85
+ rope_scaling (`dict`, *optional*):
86
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
87
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
88
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
89
+ divided by the number of attention heads divided by 2.
90
+ bos_token_id (`int`, *optional*, defaults to 1):
91
+ The id of the "beginning-of-sequence" token.
92
+ eos_token_id (`int`, *optional*, defaults to 32000):
93
+ The id of the "end-of-sequence" token.
94
+ pad_token_id (`int`, *optional*, defaults to 32000):
95
+ The id of the padding token.
96
+ sliding_window (`int`, *optional*):
97
+ Sliding window attention window size. If `None`, no sliding window is applied.
98
+
99
+ Example:
100
+
101
+ ```python
102
+ >>> from transformers import Phi3Model, Phi3Config
103
+
104
+ >>> # Initializing a Phi-3 style configuration
105
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
106
+
107
+ >>> # Initializing a model from the configuration
108
+ >>> model = Phi3Model(configuration)
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "phi3"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=32064,
120
+ hidden_size=3072,
121
+ intermediate_size=8192,
122
+ num_hidden_layers=32,
123
+ num_attention_heads=32,
124
+ num_key_value_heads=None,
125
+ resid_pdrop=0.0,
126
+ embd_pdrop=0.0,
127
+ attention_dropout=0.0,
128
+ hidden_act="silu",
129
+ max_position_embeddings=4096,
130
+ original_max_position_embeddings=4096,
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-5,
133
+ use_cache=True,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ bos_token_id=1,
138
+ eos_token_id=32000,
139
+ pad_token_id=32000,
140
+ sliding_window=None,
141
+ **kwargs,
142
+ ):
143
+ self.vocab_size = vocab_size
144
+ self.hidden_size = hidden_size
145
+ self.intermediate_size = intermediate_size
146
+ self.num_hidden_layers = num_hidden_layers
147
+ self.num_attention_heads = num_attention_heads
148
+
149
+ if num_key_value_heads is None:
150
+ num_key_value_heads = num_attention_heads
151
+
152
+ self.num_key_value_heads = num_key_value_heads
153
+ self.resid_pdrop = resid_pdrop
154
+ self.embd_pdrop = embd_pdrop
155
+ self.attention_dropout = attention_dropout
156
+ self.hidden_act = hidden_act
157
+ self.max_position_embeddings = max_position_embeddings
158
+ self.original_max_position_embeddings = original_max_position_embeddings
159
+ self.initializer_range = initializer_range
160
+ self.rms_norm_eps = rms_norm_eps
161
+ self.use_cache = use_cache
162
+ self.rope_theta = rope_theta
163
+ self.rope_scaling = rope_scaling
164
+ self._rope_scaling_validation()
165
+ self.sliding_window = sliding_window
166
+
167
+ super().__init__(
168
+ bos_token_id=bos_token_id,
169
+ eos_token_id=eos_token_id,
170
+ pad_token_id=pad_token_id,
171
+ tie_word_embeddings=tie_word_embeddings,
172
+ **kwargs,
173
+ )
174
+
175
+ def _rope_scaling_validation(self):
176
+ """
177
+ Validate the `rope_scaling` configuration.
178
+ """
179
+ if self.rope_scaling is None:
180
+ return
181
+
182
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
183
+ raise ValueError(
184
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
185
+ f"got {self.rope_scaling}"
186
+ )
187
+ rope_scaling_type = self.rope_scaling.get("type", None)
188
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
189
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
190
+ if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
191
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
192
+ if not (
193
+ isinstance(rope_scaling_short_factor, list)
194
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
195
+ ):
196
+ raise ValueError(
197
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
198
+ )
199
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
200
+ raise ValueError(
201
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
202
+ )
203
+ if not (
204
+ isinstance(rope_scaling_long_factor, list)
205
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
206
+ ):
207
+ raise ValueError(
208
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
209
+ )
210
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
211
+ raise ValueError(
212
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
213
+ )
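`_rope_scaling_validation` pins `short_factor` and `long_factor` to length `hidden_size // num_attention_heads // 2`, i.e. half the per-head dimension (48 with the defaults of 3072 hidden size and 32 heads). A minimal sketch, assuming the config is importable through the package added in this diff:

```python
from bunny.model.language_model.phi3 import Phi3Config

head_half = 3072 // 32 // 2  # 48 rotary frequency slots per head

config = Phi3Config(
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_scaling={
        "type": "su",
        "short_factor": [1.0] * head_half,
        "long_factor": [1.0] * head_half,
    },
)
# Wrong list lengths, extra keys, or an unknown type (e.g. "linear") raise ValueError
# inside _rope_scaling_validation().
```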
bunny/model/language_model/phi3/modeling_phi3.py ADDED
@@ -0,0 +1,1597 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi-3 model."""
17
+
18
+ import inspect
19
+ import math
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_2_available,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from .configuration_phi3 import Phi3Config
49
+
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
54
+
55
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
60
+ _CONFIG_FOR_DOC = "Phi3Config"
61
+
62
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "microsoft/Phi-3-mini-4k-instruct",
64
+ "microsoft/Phi-3-mini-128k-instruct",
65
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
66
+ ]
67
+
68
+
69
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
70
+ class Phi3RMSNorm(nn.Module):
71
+ def __init__(self, hidden_size, eps=1e-6):
72
+ """
73
+ Phi3RMSNorm is equivalent to T5LayerNorm
74
+ """
75
+ super().__init__()
76
+ self.weight = nn.Parameter(torch.ones(hidden_size))
77
+ self.variance_epsilon = eps
78
+
79
+ def forward(self, hidden_states):
80
+ input_dtype = hidden_states.dtype
81
+ hidden_states = hidden_states.to(torch.float32)
82
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
83
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
84
+ return self.weight * hidden_states.to(input_dtype)
85
+
86
+
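Functionally, `Phi3RMSNorm` rescales each hidden vector by the reciprocal root-mean-square of its elements (computed in float32) and multiplies by a learned per-channel weight. A quick numerical sanity check, assuming the class above is in scope:

```python
import torch

hidden = torch.randn(2, 4, 3072)
norm = Phi3RMSNorm(3072, eps=1e-5)

# Reference: x / sqrt(mean(x^2) + eps), then the learned weight (ones at init).
ref = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + 1e-5)
assert torch.allclose(norm(hidden), norm.weight * ref)
```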
87
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
88
+ def _get_unpad_data(attention_mask):
89
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
90
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
91
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
92
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
93
+ return (
94
+ indices,
95
+ cu_seqlens,
96
+ max_seqlen_in_batch,
97
+ )
98
+
99
+
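`_get_unpad_data` converts a padding mask into the flat token indices and cumulative sequence lengths that the varlen flash-attention kernels expect. A small worked example, assuming the helper above is in scope:

```python
import torch

# Two sequences of lengths 3 and 1 (0 marks padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])

indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
# indices    -> tensor([0, 1, 2, 4])  positions of real tokens in the flattened batch
# cu_seqlens -> tensor([0, 3, 4])     prefix sums of the per-sequence lengths
# max_seqlen -> 3
```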
100
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
101
+ class Phi3RotaryEmbedding(nn.Module):
102
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
103
+ super().__init__()
104
+
105
+ self.dim = dim
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.base = base
108
+ self.register_buffer("inv_freq", None, persistent=False)
109
+
110
+ @torch.no_grad()
111
+ def forward(self, x, position_ids, seq_len=None):
112
+ # x: [bs, num_attention_heads, seq_len, head_size]
113
+ if self.inv_freq is None:
114
+ self.inv_freq = 1.0 / (
115
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
116
+ )
117
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
118
+ position_ids_expanded = position_ids[:, None, :].float()
119
+ # Force float32 since bfloat16 loses precision on long contexts
120
+ # See https://github.com/huggingface/transformers/pull/29285
121
+ device_type = x.device.type
122
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
123
+ with torch.autocast(device_type=device_type, enabled=False):
124
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
125
+ emb = torch.cat((freqs, freqs), dim=-1)
126
+ cos = emb.cos()
127
+ sin = emb.sin()
128
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
129
+
130
+
131
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
132
+ def __init__(self, dim, config, device=None):
133
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
134
+
135
+ self.short_factor = config.rope_scaling["short_factor"]
136
+ self.long_factor = config.rope_scaling["long_factor"]
137
+ self.original_max_position_embeddings = config.original_max_position_embeddings
138
+
139
+ @torch.no_grad()
140
+ def forward(self, x, position_ids, seq_len=None):
141
+ seq_len = torch.max(position_ids) + 1
142
+ if seq_len > self.original_max_position_embeddings:
143
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
144
+ else:
145
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
146
+
147
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
148
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
149
+
150
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
151
+ position_ids_expanded = position_ids[:, None, :].float()
152
+
153
+ # Force float32 since bfloat16 loses precision on long contexts
154
+ # See https://github.com/huggingface/transformers/pull/29285
155
+ device_type = x.device.type
156
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
157
+ with torch.autocast(device_type=device_type, enabled=False):
158
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
159
+ emb = torch.cat((freqs, freqs), dim=-1)
160
+
161
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
162
+ if scale <= 1.0:
163
+ scaling_factor = 1.0
164
+ else:
165
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
166
+
167
+ cos = emb.cos() * scaling_factor
168
+ sin = emb.sin() * scaling_factor
169
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
170
+
171
+
172
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
173
+ def __init__(self, dim, config, device=None):
174
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
175
+
176
+ self.short_factor = config.rope_scaling["short_factor"]
177
+ self.long_factor = config.rope_scaling["long_factor"]
178
+ self.original_max_position_embeddings = config.original_max_position_embeddings
179
+
180
+ @torch.no_grad()
181
+ def forward(self, x, position_ids, seq_len=None):
182
+ seq_len = torch.max(position_ids) + 1
183
+ if seq_len > self.original_max_position_embeddings:
184
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
185
+ else:
186
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
187
+
188
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
189
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
190
+
191
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
192
+ position_ids_expanded = position_ids[:, None, :].float()
193
+
194
+ # Force float32 since bfloat16 loses precision on long contexts
195
+ # See https://github.com/huggingface/transformers/pull/29285
196
+ device_type = x.device.type
197
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
198
+ with torch.autocast(device_type=device_type, enabled=False):
199
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
200
+ emb = torch.cat((freqs, freqs), dim=-1)
201
+
202
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
203
+ if scale <= 1.0:
204
+ scaling_factor = 1.0
205
+ else:
206
+ scaling_factor = 0.1 * math.log(scale) + 1.0
207
+
208
+ cos = emb.cos() * scaling_factor
209
+ sin = emb.sin() * scaling_factor
210
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
211
+
212
+
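The `su` and `yarn` variants differ mainly in how they turn the context-extension ratio into an attention scaling factor. For a 4k model stretched to 128k positions (scale = 32), a quick check of the two formulas used above:

```python
import math

original_max = 4096
extended_max = 131072
scale = extended_max / original_max  # 32.0

su_factor = math.sqrt(1 + math.log(scale) / math.log(original_max))  # ~1.19
yarn_factor = 0.1 * math.log(scale) + 1.0                            # ~1.35

print(round(su_factor, 3), round(yarn_factor, 3))
```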
213
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
214
+ def rotate_half(x):
215
+ """Rotates half the hidden dims of the input."""
216
+ x1 = x[..., : x.shape[-1] // 2]
217
+ x2 = x[..., x.shape[-1] // 2 :]
218
+ return torch.cat((-x2, x1), dim=-1)
219
+
220
+
221
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
222
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
223
+ """Applies Rotary Position Embedding to the query and key tensors.
224
+
225
+ Args:
226
+ q (`torch.Tensor`): The query tensor.
227
+ k (`torch.Tensor`): The key tensor.
228
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
229
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
230
+ position_ids (`torch.Tensor`, *optional*):
231
+ Deprecated and unused.
232
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
233
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
234
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
235
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
236
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
237
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
238
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
239
+ Returns:
240
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
241
+ """
242
+ cos = cos.unsqueeze(unsqueeze_dim)
243
+ sin = sin.unsqueeze(unsqueeze_dim)
244
+ q_embed = (q * cos) + (rotate_half(q) * sin)
245
+ k_embed = (k * cos) + (rotate_half(k) * sin)
246
+ return q_embed, k_embed
247
+
248
+
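`apply_rotary_pos_emb` broadcasts the cached `cos`/`sin` tables over the head dimension and mixes each vector with its rotated half. A shape-level sketch with toy tensors (the `cos`/`sin` here are stand-ins rather than real rotary tables), assuming both helpers above are in scope:

```python
import torch

bsz, num_heads, seq_len, head_dim = 1, 4, 6, 8

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)

# In the model these come from the rotary embedding module; here we fake tables of
# the matching (bsz, seq_len, head_dim) shape.
angles = torch.randn(bsz, seq_len, head_dim)
cos, sin = angles.cos(), angles.sin()

# unsqueeze_dim=1 (the default) broadcasts the tables over the head dimension.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```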
249
+ class Phi3MLP(nn.Module):
250
+ def __init__(self, config):
251
+ super().__init__()
252
+
253
+ self.config = config
254
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
255
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
256
+
257
+ self.activation_fn = ACT2FN[config.hidden_act]
258
+
259
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
260
+ up_states = self.gate_up_proj(hidden_states)
261
+
262
+ gate, up_states = up_states.chunk(2, dim=-1)
263
+ up_states = up_states * self.activation_fn(gate)
264
+
265
+ return self.down_proj(up_states)
266
+
267
+
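Unlike the separate `gate_proj`/`up_proj` of Llama-style MLPs, `Phi3MLP` fuses both into a single `gate_up_proj` matmul and splits the result with `chunk(2, dim=-1)`. A standalone sketch of that split with hypothetical small sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16
x = torch.randn(2, 3, hidden_size)

gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

up_states = gate_up_proj(x)            # (2, 3, 2 * intermediate_size)
gate, up = up_states.chunk(2, dim=-1)  # two (2, 3, intermediate_size) halves
out = down_proj(up * F.silu(gate))     # SwiGLU-style gating, back to hidden_size
assert out.shape == x.shape
```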
268
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
269
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
270
+ """
271
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
272
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
273
+ """
274
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
275
+ if n_rep == 1:
276
+ return hidden_states
277
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
278
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
279
+
280
+
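`repeat_kv` is what makes grouped-query attention line up: each key/value head is tiled `n_rep` times so the later matmuls see as many KV heads as query heads. A quick shape check, assuming the function above is in scope:

```python
import torch

batch, num_kv_heads, seq_len, head_dim = 1, 8, 10, 96
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = repeat_kv(kv, n_rep=4)
assert expanded.shape == (batch, 32, seq_len, head_dim)
# Expanded head i is a copy of original KV head i // 4.
assert torch.equal(expanded[:, 5], kv[:, 1])
```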
281
+ class Phi3Attention(nn.Module):
282
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
283
+
284
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
285
+ super().__init__()
286
+ self.config = config
287
+ self.layer_idx = layer_idx
288
+ if layer_idx is None:
289
+ logger.warning_once(
290
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
291
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
292
+ "when creating this class."
293
+ )
294
+
295
+ self.attention_dropout = config.attention_dropout
296
+ self.hidden_size = config.hidden_size
297
+ self.num_heads = config.num_attention_heads
298
+ self.head_dim = self.hidden_size // self.num_heads
299
+ self.num_key_value_heads = config.num_key_value_heads
300
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
301
+ self.max_position_embeddings = config.max_position_embeddings
302
+ self.original_max_position_embeddings = config.original_max_position_embeddings
303
+ self.rope_theta = config.rope_theta
304
+ self.rope_scaling = config.rope_scaling
305
+ self.is_causal = True
306
+
307
+ if (self.head_dim * self.num_heads) != self.hidden_size:
308
+ raise ValueError(
309
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
310
+ f" and `num_heads`: {self.num_heads})."
311
+ )
312
+
313
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
314
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
315
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
316
+ self._init_rope()
317
+
318
+ def _init_rope(self):
319
+ if self.rope_scaling is None:
320
+ self.rotary_emb = Phi3RotaryEmbedding(
321
+ self.head_dim,
322
+ max_position_embeddings=self.max_position_embeddings,
323
+ base=self.rope_theta,
324
+ )
325
+ else:
326
+ scaling_type = self.config.rope_scaling["type"]
327
+ if scaling_type == "su":
328
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
329
+ elif scaling_type == "yarn":
330
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
331
+ else:
332
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states: torch.Tensor,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ position_ids: Optional[torch.LongTensor] = None,
339
+ past_key_value: Optional[Cache] = None,
340
+ output_attentions: bool = False,
341
+ use_cache: bool = False,
342
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
343
+ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
344
+
345
+ bsz, q_len, _ = hidden_states.size()
346
+
347
+ qkv = self.qkv_proj(hidden_states)
348
+ query_pos = self.num_heads * self.head_dim
349
+ query_states = qkv[..., :query_pos]
350
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
351
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
352
+
353
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
354
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
355
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
356
+
357
+ kv_seq_len = key_states.shape[-2]
358
+ if past_key_value is not None:
359
+ if self.layer_idx is None:
360
+ raise ValueError(
361
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
362
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
363
+ "with a layer index."
364
+ )
365
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
366
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
367
+
368
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
369
+
370
+ if past_key_value is not None:
371
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
372
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
373
+
374
+ # repeat k/v heads if n_kv_heads < n_heads
375
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
376
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
377
+
378
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
+
380
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
383
+ f" {attn_weights.size()}"
384
+ )
385
+
386
+ if attention_mask is not None:
387
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
388
+ raise ValueError(
389
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
390
+ )
391
+ attn_weights = attn_weights + attention_mask
392
+
393
+ # upcast attention to fp32
394
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
395
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
396
+
397
+ attn_output = torch.matmul(attn_weights, value_states)
398
+
399
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
400
+ raise ValueError(
401
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
402
+ f" {attn_output.size()}"
403
+ )
404
+
405
+ attn_output = attn_output.transpose(1, 2).contiguous()
406
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
407
+
408
+ attn_output = self.o_proj(attn_output)
409
+
410
+ if not output_attentions:
411
+ attn_weights = None
412
+
413
+ return attn_output, attn_weights, past_key_value
414
+
415
+
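All three attention variants share the same fused `qkv_proj`: its output is laid out as `[q | k | v]`, with the query block sized by the full head count and the key/value blocks by `num_key_value_heads`. A sketch of the slicing arithmetic with illustrative Phi-3-mini-like sizes:

```python
import torch

hidden_size, num_heads, num_kv_heads = 3072, 32, 32
head_dim = hidden_size // num_heads                     # 96
op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim

qkv = torch.randn(2, 5, op_size)                        # stand-in for qkv_proj(hidden_states)
query_pos = num_heads * head_dim

q = qkv[..., :query_pos]
k = qkv[..., query_pos : query_pos + num_kv_heads * head_dim]
v = qkv[..., query_pos + num_kv_heads * head_dim :]

q = q.view(2, 5, num_heads, head_dim).transpose(1, 2)   # (bsz, heads, seq, head_dim)
k = k.view(2, 5, num_kv_heads, head_dim).transpose(1, 2)
v = v.view(2, 5, num_kv_heads, head_dim).transpose(1, 2)
assert q.shape == (2, 32, 5, 96) and k.shape == v.shape == (2, 32, 5, 96)
```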
416
+ class Phi3FlashAttention2(Phi3Attention):
417
+ """
418
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
419
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
420
+ flash attention and deal with padding tokens in case the input contains any of them.
421
+ """
422
+
423
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
424
+ def __init__(self, *args, **kwargs):
425
+ super().__init__(*args, **kwargs)
426
+
427
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
428
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
429
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
430
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
431
+
432
+ def forward(
433
+ self,
434
+ hidden_states: torch.Tensor,
435
+ attention_mask: Optional[torch.LongTensor] = None,
436
+ position_ids: Optional[torch.LongTensor] = None,
437
+ past_key_value: Optional[Cache] = None,
438
+ output_attentions: bool = False,
439
+ use_cache: bool = False,
440
+ **kwargs,
441
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
442
+ # Phi3FlashAttention2 attention does not support output_attentions
443
+
444
+ if not _flash_supports_window_size:
445
+ logger.warning_once(
446
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
447
+ )
448
+ raise ValueError("The current flash attention version does not support sliding window attention.")
449
+
450
+ output_attentions = False
451
+
452
+ if "padding_mask" in kwargs:
453
+ warnings.warn(
454
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
455
+ )
456
+
457
+ # overwrite attention_mask with padding_mask
458
+ attention_mask = kwargs.pop("padding_mask")
459
+
460
+ bsz, q_len, _ = hidden_states.size()
461
+
462
+ qkv = self.qkv_proj(hidden_states)
463
+ query_pos = self.num_heads * self.head_dim
464
+ query_states = qkv[..., :query_pos]
465
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
466
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
467
+
468
+ # Flash attention requires the input to have the shape
469
+ # batch_size x seq_length x num_heads x head_dim
470
+ # therefore we just need to keep the original shape
471
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
472
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
473
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
474
+
475
+ kv_seq_len = key_states.shape[-2]
476
+ if past_key_value is not None:
477
+ if self.layer_idx is None:
478
+ raise ValueError(
479
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
480
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
481
+ "with a layer index."
482
+ )
483
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
484
+
485
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
486
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
487
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
488
+
489
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
490
+
491
+ use_sliding_windows = (
492
+ _flash_supports_window_size
493
+ and getattr(self.config, "sliding_window", None) is not None
494
+ and kv_seq_len > self.config.sliding_window
495
+ )
496
+
497
+ if past_key_value is not None:
498
+ # Activate cache slicing only if the config has a `sliding_window` attribute
499
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
500
+ if (
501
+ getattr(self.config, "sliding_window", None) is not None
502
+ and kv_seq_len > self.config.sliding_window
503
+ and cache_has_contents
504
+ ):
505
+ slicing_tokens = 1 - self.config.sliding_window
506
+
507
+ past_key = past_key_value[self.layer_idx][0]
508
+ past_value = past_key_value[self.layer_idx][1]
509
+
510
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
511
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
512
+
513
+ if past_key.shape[-2] != self.config.sliding_window - 1:
514
+ raise ValueError(
515
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
516
+ f" {past_key.shape}"
517
+ )
518
+
519
+ if attention_mask is not None:
520
+ attention_mask = attention_mask[:, slicing_tokens:]
521
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
522
+
523
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
524
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
525
+
526
+ # repeat k/v heads if n_kv_heads < n_heads
527
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
528
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
529
+
530
+ attn_dropout = self.attention_dropout if self.training else 0.0
531
+
532
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
533
+ # so the input hidden states may get silently cast to float32. Hence, we need to
534
+ # cast them back to the correct dtype just to be sure everything works as expected.
535
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
536
+ # in fp32.
537
+
538
+ if query_states.dtype == torch.float32:
539
+ if torch.is_autocast_enabled():
540
+ target_dtype = torch.get_autocast_gpu_dtype()
541
+ # Handle the case where the model is quantized
542
+ elif hasattr(self.config, "_pre_quantization_dtype"):
543
+ target_dtype = self.config._pre_quantization_dtype
544
+ else:
545
+ target_dtype = self.qkv_proj.weight.dtype
546
+
547
+ logger.warning_once(
548
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
549
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
550
+ f" {target_dtype}."
551
+ )
552
+
553
+ query_states = query_states.to(target_dtype)
554
+ key_states = key_states.to(target_dtype)
555
+ value_states = value_states.to(target_dtype)
556
+
557
+ # Reshape to the expected shape for Flash Attention
558
+ query_states = query_states.transpose(1, 2)
559
+ key_states = key_states.transpose(1, 2)
560
+ value_states = value_states.transpose(1, 2)
561
+
562
+ attn_output = self._flash_attention_forward(
563
+ query_states,
564
+ key_states,
565
+ value_states,
566
+ attention_mask,
567
+ q_len,
568
+ dropout=attn_dropout,
569
+ use_sliding_windows=use_sliding_windows,
570
+ )
571
+
572
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
573
+ attn_output = self.o_proj(attn_output)
574
+
575
+ if not output_attentions:
576
+ attn_weights = None
577
+
578
+ return attn_output, attn_weights, past_key_value
579
+
580
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
581
+ def _flash_attention_forward(
582
+ self,
583
+ query_states,
584
+ key_states,
585
+ value_states,
586
+ attention_mask,
587
+ query_length,
588
+ dropout=0.0,
589
+ softmax_scale=None,
590
+ use_sliding_windows=False,
591
+ ):
592
+ """
593
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
594
+ first unpads the input, then computes the attention scores and pads the final attention scores.
595
+
596
+ Args:
597
+ query_states (`torch.Tensor`):
598
+ Input query states to be passed to Flash Attention API
599
+ key_states (`torch.Tensor`):
600
+ Input key states to be passed to Flash Attention API
601
+ value_states (`torch.Tensor`):
602
+ Input value states to be passed to Flash Attention API
603
+ attention_mask (`torch.Tensor`):
604
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
605
+ position of padding tokens and 1 for the position of non-padding tokens.
606
+ dropout (`float`):
607
+ Attention dropout
608
+ softmax_scale (`float`, *optional*):
609
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
610
+ use_sliding_windows (`bool`, *optional*):
611
+ Whether to activate sliding window attention.
612
+ """
613
+ if not self._flash_attn_uses_top_left_mask:
614
+ causal = self.is_causal
615
+ else:
616
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
617
+ causal = self.is_causal and query_length != 1
618
+
619
+ # Contains at least one padding token in the sequence
620
+ if attention_mask is not None:
621
+ batch_size = query_states.shape[0]
622
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
623
+ query_states, key_states, value_states, attention_mask, query_length
624
+ )
625
+
626
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
627
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
628
+
629
+ if not use_sliding_windows:
630
+ attn_output_unpad = flash_attn_varlen_func(
631
+ query_states,
632
+ key_states,
633
+ value_states,
634
+ cu_seqlens_q=cu_seqlens_q,
635
+ cu_seqlens_k=cu_seqlens_k,
636
+ max_seqlen_q=max_seqlen_in_batch_q,
637
+ max_seqlen_k=max_seqlen_in_batch_k,
638
+ dropout_p=dropout,
639
+ softmax_scale=softmax_scale,
640
+ causal=causal,
641
+ )
642
+ else:
643
+ attn_output_unpad = flash_attn_varlen_func(
644
+ query_states,
645
+ key_states,
646
+ value_states,
647
+ cu_seqlens_q=cu_seqlens_q,
648
+ cu_seqlens_k=cu_seqlens_k,
649
+ max_seqlen_q=max_seqlen_in_batch_q,
650
+ max_seqlen_k=max_seqlen_in_batch_k,
651
+ dropout_p=dropout,
652
+ softmax_scale=softmax_scale,
653
+ causal=causal,
654
+ window_size=(self.config.sliding_window, self.config.sliding_window),
655
+ )
656
+
657
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
658
+ else:
659
+ if not use_sliding_windows:
660
+ attn_output = flash_attn_func(
661
+ query_states,
662
+ key_states,
663
+ value_states,
664
+ dropout,
665
+ softmax_scale=softmax_scale,
666
+ causal=causal,
667
+ )
668
+ else:
669
+ attn_output = flash_attn_func(
670
+ query_states,
671
+ key_states,
672
+ value_states,
673
+ dropout,
674
+ softmax_scale=softmax_scale,
675
+ causal=causal,
676
+ window_size=(self.config.sliding_window, self.config.sliding_window),
677
+ )
678
+
679
+ return attn_output
680
+
681
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
682
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
683
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
684
+
685
+ # On the first iteration we need to properly re-create the padding mask
686
+ # by slicing it at the proper place
687
+ if kv_seq_len != attention_mask.shape[-1]:
688
+ attention_mask_num_tokens = attention_mask.shape[-1]
689
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
690
+
691
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
692
+
693
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
694
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
695
+
696
+ if query_length == kv_seq_len:
697
+ query_layer = index_first_axis(
698
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
699
+ )
700
+ cu_seqlens_q = cu_seqlens_k
701
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
702
+ indices_q = indices_k
703
+ elif query_length == 1:
704
+ max_seqlen_in_batch_q = 1
705
+ cu_seqlens_q = torch.arange(
706
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
707
+ ) # There is a memcpy here, that is very bad.
708
+ indices_q = cu_seqlens_q[:-1]
709
+ query_layer = query_layer.squeeze(1)
710
+ else:
711
+ # The -q_len: slice assumes left padding.
712
+ attention_mask = attention_mask[:, -query_length:]
713
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
714
+
715
+ return (
716
+ query_layer,
717
+ key_layer,
718
+ value_layer,
719
+ indices_q,
720
+ (cu_seqlens_q, cu_seqlens_k),
721
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
722
+ )
723
+
724
+
725
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
726
+ # TODO @Arthur no longer copied from LLama after static cache
727
+ class Phi3SdpaAttention(Phi3Attention):
728
+ """
729
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
730
+ `Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
731
+ SDPA API.
732
+ """
733
+
734
+ # Adapted from Phi3Attention.forward
735
+ def forward(
736
+ self,
737
+ hidden_states: torch.Tensor,
738
+ attention_mask: Optional[torch.Tensor] = None,
739
+ position_ids: Optional[torch.LongTensor] = None,
740
+ past_key_value: Optional[Cache] = None,
741
+ output_attentions: bool = False,
742
+ use_cache: bool = False,
743
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
744
+ if output_attentions:
745
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
746
+ logger.warning_once(
747
+ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
748
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
749
+ )
750
+ return super().forward(
751
+ hidden_states=hidden_states,
752
+ attention_mask=attention_mask,
753
+ position_ids=position_ids,
754
+ past_key_value=past_key_value,
755
+ output_attentions=output_attentions,
756
+ use_cache=use_cache,
757
+ )
758
+
759
+ bsz, q_len, _ = hidden_states.size()
760
+
761
+ qkv = self.qkv_proj(hidden_states)
762
+ query_pos = self.num_heads * self.head_dim
763
+ query_states = qkv[..., :query_pos]
764
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
765
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
766
+
767
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
768
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
769
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
770
+
771
+ kv_seq_len = key_states.shape[-2]
772
+ if past_key_value is not None:
773
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
774
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
775
+
776
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
777
+
778
+ if past_key_value is not None:
779
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
780
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
781
+
782
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
783
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
784
+
785
+ if attention_mask is not None:
786
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
787
+ raise ValueError(
788
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
789
+ )
790
+
791
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
792
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
793
+ if query_states.device.type == "cuda" and attention_mask is not None:
794
+ query_states = query_states.contiguous()
795
+ key_states = key_states.contiguous()
796
+ value_states = value_states.contiguous()
797
+
798
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
799
+ query_states,
800
+ key_states,
801
+ value_states,
802
+ attn_mask=attention_mask,
803
+ dropout_p=self.attention_dropout if self.training else 0.0,
804
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
805
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
806
+ )
807
+
808
+ attn_output = attn_output.transpose(1, 2).contiguous()
809
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
810
+
811
+ attn_output = self.o_proj(attn_output)
812
+
813
+ return attn_output, None, past_key_value
814
+
815
+
816
+ PHI3_ATTENTION_CLASSES = {
817
+ "eager": Phi3Attention,
818
+ "flash_attention_2": Phi3FlashAttention2,
819
+ "sdpa": Phi3SdpaAttention,
820
+ }
821
+
822
+
823
+ class Phi3DecoderLayer(nn.Module):
824
+ def __init__(self, config: Phi3Config, layer_idx: int):
825
+ super().__init__()
826
+
827
+ self.config = config
828
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
829
+
830
+ self.mlp = Phi3MLP(config)
831
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
832
+
833
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
834
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
835
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
836
+
837
+ def forward(
838
+ self,
839
+ hidden_states: torch.Tensor,
840
+ attention_mask: Optional[torch.Tensor] = None,
841
+ position_ids: Optional[torch.LongTensor] = None,
842
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
843
+ output_attentions: Optional[bool] = False,
844
+ use_cache: Optional[bool] = False,
845
+ **kwargs,
846
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
847
+ if "padding_mask" in kwargs:
848
+ warnings.warn(
849
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
850
+ )
851
+ """
852
+ Args:
853
+ hidden_states (`torch.FloatTensor`):
854
+ input to the layer of shape `(batch, seq_len, embed_dim)`
855
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
856
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
857
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
858
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
859
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
860
+ output_attentions (`bool`, *optional*):
861
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
862
+ returned tensors for more detail.
863
+ use_cache (`bool`, *optional*):
864
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
865
+ (see `past_key_values`).
866
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
867
+ """
868
+
869
+ residual = hidden_states
870
+
871
+ hidden_states = self.input_layernorm(hidden_states)
872
+
873
+ # Self Attention
874
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
875
+ hidden_states=hidden_states,
876
+ attention_mask=attention_mask,
877
+ position_ids=position_ids,
878
+ past_key_value=past_key_value,
879
+ output_attentions=output_attentions,
880
+ use_cache=use_cache,
881
+ )
882
+
883
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
884
+
885
+ residual = hidden_states
886
+ hidden_states = self.post_attention_layernorm(hidden_states)
887
+ hidden_states = self.mlp(hidden_states)
888
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
889
+
890
+ outputs = (hidden_states,)
891
+
892
+ if output_attentions:
893
+ outputs += (self_attn_weights,)
894
+
895
+ if use_cache:
896
+ outputs += (present_key_value,)
897
+
898
+ return outputs
899
+
900
+
901
+ PHI3_START_DOCSTRING = r"""
902
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
903
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
904
+ etc.)
905
+
906
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
907
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
908
+ and behavior.
909
+
910
+ Parameters:
911
+ config ([`Phi3Config`]):
912
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
913
+ load the weights associated with the model, only the configuration. Check out the
914
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
915
+ """
916
+
917
+
918
+ @add_start_docstrings(
919
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
920
+ PHI3_START_DOCSTRING,
921
+ )
922
+ class Phi3PreTrainedModel(PreTrainedModel):
923
+ config_class = Phi3Config
924
+ base_model_prefix = "model"
925
+ supports_gradient_checkpointing = True
926
+ _no_split_modules = ["Phi3DecoderLayer"]
927
+ _skip_keys_device_placement = "past_key_values"
928
+ _supports_flash_attn_2 = True
929
+ _supports_sdpa = False
930
+ _supports_cache_class = True
931
+
932
+ _version = "0.0.5"
933
+
934
+ def _init_weights(self, module):
935
+ std = self.config.initializer_range
936
+ if isinstance(module, nn.Linear):
937
+ module.weight.data.normal_(mean=0.0, std=std)
938
+ if module.bias is not None:
939
+ module.bias.data.zero_()
940
+ elif isinstance(module, nn.Embedding):
941
+ module.weight.data.normal_(mean=0.0, std=std)
942
+ if module.padding_idx is not None:
943
+ module.weight.data[module.padding_idx].zero_()
944
+
945
+
946
+ PHI3_INPUTS_DOCSTRING = r"""
947
+ Args:
948
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
949
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
950
+ it.
951
+
952
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
953
+ [`PreTrainedTokenizer.__call__`] for details.
954
+
955
+ [What are input IDs?](../glossary#input-ids)
956
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
957
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
958
+
959
+ - 1 for tokens that are **not masked**,
960
+ - 0 for tokens that are **masked**.
961
+
962
+ [What are attention masks?](../glossary#attention-mask)
963
+
964
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
965
+ [`PreTrainedTokenizer.__call__`] for details.
966
+
967
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
968
+ `past_key_values`).
969
+
970
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
971
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
972
+ information on the default strategy.
973
+
974
+ - 1 indicates the head is **not masked**,
975
+ - 0 indicates the head is **masked**.
976
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
977
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
978
+ config.n_positions - 1]`.
979
+
980
+ [What are position IDs?](../glossary#position-ids)
981
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
982
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
983
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
984
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
985
+
986
+ Two formats are allowed:
987
+ - a [`~cache_utils.Cache`] instance;
988
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
989
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
990
+ cache format.
991
+
992
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
993
+ legacy cache format will be returned.
994
+
995
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
996
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
997
+ of shape `(batch_size, sequence_length)`.
998
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
999
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1000
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1001
+ model's internal embedding lookup matrix.
1002
+ use_cache (`bool`, *optional*):
1003
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1004
+ `past_key_values`).
1005
+ output_attentions (`bool`, *optional*):
1006
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1007
+ tensors for more detail.
1008
+ output_hidden_states (`bool`, *optional*):
1009
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1010
+ more detail.
1011
+ return_dict (`bool`, *optional*):
1012
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1013
+ """
1014
+
1015
+
1016
+ @add_start_docstrings(
1017
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
1018
+ PHI3_START_DOCSTRING,
1019
+ )
1020
+ class Phi3Model(Phi3PreTrainedModel):
1021
+ """
1022
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1023
+
1024
+ Args:
1025
+ config: Phi3Config
1026
+ """
1027
+
1028
+ def __init__(self, config: Phi3Config):
1029
+ super().__init__(config)
1030
+ self.padding_idx = config.pad_token_id
1031
+ self.vocab_size = config.vocab_size
1032
+
1033
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1034
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1035
+ self.layers = nn.ModuleList(
1036
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1037
+ )
1038
+ self._attn_implementation = config._attn_implementation
1039
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1040
+
1041
+ self.gradient_checkpointing = False
1042
+ # Initialize weights and apply final processing
1043
+ self.post_init()
1044
+
1045
+ def get_input_embeddings(self):
1046
+ return self.embed_tokens
1047
+
1048
+ def set_input_embeddings(self, value):
1049
+ self.embed_tokens = value
1050
+
1051
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1052
+ def forward(
1053
+ self,
1054
+ input_ids: torch.LongTensor = None,
1055
+ attention_mask: Optional[torch.Tensor] = None,
1056
+ position_ids: Optional[torch.LongTensor] = None,
1057
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1058
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1059
+ use_cache: Optional[bool] = None,
1060
+ output_attentions: Optional[bool] = None,
1061
+ output_hidden_states: Optional[bool] = None,
1062
+ return_dict: Optional[bool] = None,
1063
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1064
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1065
+ output_hidden_states = (
1066
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1067
+ )
1068
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1069
+
1070
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1071
+
1072
+ # retrieve input_ids and inputs_embeds
1073
+ if input_ids is not None and inputs_embeds is not None:
1074
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1075
+ elif input_ids is not None:
1076
+ batch_size, seq_length = input_ids.shape[:2]
1077
+ elif inputs_embeds is not None:
1078
+ batch_size, seq_length = inputs_embeds.shape[:2]
1079
+ else:
1080
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1081
+
1082
+ past_key_values_length = 0
1083
+
1084
+ if self.gradient_checkpointing and self.training:
1085
+ if use_cache:
1086
+ logger.warning_once(
1087
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1088
+ )
1089
+ use_cache = False
1090
+
1091
+ if use_cache:
1092
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1093
+ if use_legacy_cache:
1094
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1095
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1096
+
1097
+ if position_ids is None:
1098
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1099
+ position_ids = torch.arange(
1100
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1101
+ )
1102
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1103
+ else:
1104
+ position_ids = position_ids.view(-1, seq_length).long()
1105
+
1106
+ if inputs_embeds is None:
1107
+ inputs_embeds = self.embed_tokens(input_ids)
1108
+
1109
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1110
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1111
+ if is_padding_right:
1112
+ raise ValueError(
1113
+ "You are attempting to perform batched generation with padding_side='right'"
1114
+ " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
1115
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1116
+ )
1117
+
1118
+ if self._attn_implementation == "flash_attention_2":
1119
+ # 2d mask is passed through the layers
1120
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1121
+ else:
1122
+ # 4d mask is passed through the layers
1123
+ attention_mask = _prepare_4d_causal_attention_mask(
1124
+ attention_mask,
1125
+ (batch_size, seq_length),
1126
+ inputs_embeds,
1127
+ past_key_values_length,
1128
+ sliding_window=self.config.sliding_window,
1129
+ )
1130
+
1131
+ hidden_states = inputs_embeds
1132
+
1133
+ # decoder layers
1134
+ all_hidden_states = () if output_hidden_states else None
1135
+ all_self_attns = () if output_attentions else None
1136
+ next_decoder_cache = None
1137
+
1138
+ for decoder_layer in self.layers:
1139
+ if output_hidden_states:
1140
+ all_hidden_states += (hidden_states,)
1141
+
1142
+ if self.gradient_checkpointing and self.training:
1143
+ layer_outputs = self._gradient_checkpointing_func(
1144
+ decoder_layer.__call__,
1145
+ hidden_states,
1146
+ attention_mask,
1147
+ position_ids,
1148
+ past_key_values,
1149
+ output_attentions,
1150
+ use_cache,
1151
+ )
1152
+ else:
1153
+ layer_outputs = decoder_layer(
1154
+ hidden_states,
1155
+ attention_mask=attention_mask,
1156
+ position_ids=position_ids,
1157
+ past_key_value=past_key_values,
1158
+ output_attentions=output_attentions,
1159
+ use_cache=use_cache,
1160
+ )
1161
+
1162
+ hidden_states = layer_outputs[0]
1163
+
1164
+ if use_cache:
1165
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1166
+
1167
+ if output_attentions:
1168
+ all_self_attns += (layer_outputs[1],)
1169
+
1170
+ hidden_states = self.norm(hidden_states)
1171
+
1172
+ # add hidden states from the last decoder layer
1173
+ if output_hidden_states:
1174
+ all_hidden_states += (hidden_states,)
1175
+
1176
+ next_cache = None
1177
+ if use_cache:
1178
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1179
+ if not return_dict:
1180
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1181
+ return BaseModelOutputWithPast(
1182
+ last_hidden_state=hidden_states,
1183
+ past_key_values=next_cache,
1184
+ hidden_states=all_hidden_states,
1185
+ attentions=all_self_attns,
1186
+ )
1187
+
1188
+
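A hedged sketch of the cache handling in `Phi3Model.forward` above: a legacy tuple-of-tuples cache is wrapped in a `DynamicCache` and converted back to the legacy format on return, and both forms carry the same per-layer key/value tensors. The shapes below are arbitrary toy values.

```python
import torch
from transformers.cache_utils import DynamicCache

# Toy legacy cache: 2 layers, each a (key, value) pair of shape
# (batch, num_kv_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8))
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 6 tokens already cached
round_trip = cache.to_legacy_cache()
assert torch.equal(round_trip[0][0], legacy[0][0])
```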
1189
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1190
+ _tied_weights_keys = ["lm_head.weight"]
1191
+
1192
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1193
+ def __init__(self, config):
1194
+ super().__init__(config)
1195
+ self.model = Phi3Model(config)
1196
+ self.vocab_size = config.vocab_size
1197
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1198
+
1199
+ # Initialize weights and apply final processing
1200
+ self.post_init()
1201
+
1202
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1203
+ def get_input_embeddings(self):
1204
+ return self.model.embed_tokens
1205
+
1206
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1207
+ def set_input_embeddings(self, value):
1208
+ self.model.embed_tokens = value
1209
+
1210
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1211
+ def get_output_embeddings(self):
1212
+ return self.lm_head
1213
+
1214
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1215
+ def set_output_embeddings(self, new_embeddings):
1216
+ self.lm_head = new_embeddings
1217
+
1218
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1219
+ def set_decoder(self, decoder):
1220
+ self.model = decoder
1221
+
1222
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1223
+ def get_decoder(self):
1224
+ return self.model
1225
+
1226
+ # Ignore copy
1227
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1228
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1229
+ def forward(
1230
+ self,
1231
+ input_ids: torch.LongTensor = None,
1232
+ attention_mask: Optional[torch.Tensor] = None,
1233
+ position_ids: Optional[torch.LongTensor] = None,
1234
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1235
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1236
+ labels: Optional[torch.LongTensor] = None,
1237
+ use_cache: Optional[bool] = None,
1238
+ output_attentions: Optional[bool] = None,
1239
+ output_hidden_states: Optional[bool] = None,
1240
+ return_dict: Optional[bool] = None,
1241
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1242
+ r"""
1243
+ Args:
1244
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1245
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1246
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1247
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1248
+
1249
+ Returns:
1250
+
1251
+ Example:
1252
+
1253
+ ```python
1254
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1255
+
1256
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1257
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1258
+
1259
+ >>> prompt = "This is an example script ."
1260
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1261
+
1262
+ >>> # Generate
1263
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1264
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1265
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1266
+ ```"""
1267
+
1268
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1269
+ output_hidden_states = (
1270
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1271
+ )
1272
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1273
+
1274
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1275
+ outputs = self.model(
1276
+ input_ids=input_ids,
1277
+ attention_mask=attention_mask,
1278
+ position_ids=position_ids,
1279
+ past_key_values=past_key_values,
1280
+ inputs_embeds=inputs_embeds,
1281
+ use_cache=use_cache,
1282
+ output_attentions=output_attentions,
1283
+ output_hidden_states=output_hidden_states,
1284
+ return_dict=return_dict,
1285
+ )
1286
+
1287
+ hidden_states = outputs[0]
1288
+ logits = self.lm_head(hidden_states)
1289
+ logits = logits.float()
1290
+
1291
+ loss = None
1292
+ if labels is not None:
1293
+ # Shift so that tokens < n predict n
1294
+ shift_logits = logits[..., :-1, :].contiguous()
1295
+ shift_labels = labels[..., 1:].contiguous()
1296
+ # Flatten the tokens
1297
+ loss_fct = CrossEntropyLoss()
1298
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1299
+ shift_labels = shift_labels.view(-1)
1300
+ # Enable model parallelism
1301
+ shift_labels = shift_labels.to(shift_logits.device)
1302
+ loss = loss_fct(shift_logits, shift_labels)
1303
+
1304
+ if not return_dict:
1305
+ output = (logits,) + outputs[1:]
1306
+ return (loss,) + output if loss is not None else output
1307
+
1308
+ return CausalLMOutputWithPast(
1309
+ loss=loss,
1310
+ logits=logits,
1311
+ past_key_values=outputs.past_key_values,
1312
+ hidden_states=outputs.hidden_states,
1313
+ attentions=outputs.attentions,
1314
+ )
1315
+
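A small sketch of the label shifting used in the loss above: position t in the logits is scored against token t+1, so the last logit and the first label are dropped. The values are arbitrary.

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)    # (batch, seq_len, vocab)
labels = torch.tensor([[3, 1, 4, 1, 5]])  # (batch, seq_len)

shift_logits = logits[..., :-1, :].contiguous()  # drop the last position
shift_labels = labels[..., 1:].contiguous()      # drop the first token
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
```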
1316
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1317
+ def prepare_inputs_for_generation(
1318
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1319
+ ):
1320
+ if past_key_values is not None:
1321
+ if isinstance(past_key_values, Cache):
1322
+ cache_length = past_key_values.get_seq_length()
1323
+ past_length = past_key_values.seen_tokens
1324
+ max_cache_length = past_key_values.get_max_length()
1325
+ else:
1326
+ cache_length = past_length = past_key_values[0][0].shape[2]
1327
+ max_cache_length = None
1328
+
1329
+ # Keep only the unprocessed tokens:
1330
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1331
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1332
+ # input)
1333
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1334
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1335
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1336
+ # input_ids based on the past_length.
1337
+ elif past_length < input_ids.shape[1]:
1338
+ input_ids = input_ids[:, past_length:]
1339
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1340
+ else:
1341
+ remove_prefix_length = input_ids.shape[1] - 1
1342
+ input_ids = input_ids[:, remove_prefix_length:]
1343
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1344
+ if (
1345
+ max_cache_length is not None
1346
+ and attention_mask is not None
1347
+ and cache_length + input_ids.shape[1] > max_cache_length
1348
+ ):
1349
+ attention_mask = attention_mask[:, -max_cache_length:]
1350
+
1351
+ position_ids = kwargs.get("position_ids", None)
1352
+ if attention_mask is not None and position_ids is None:
1353
+ # create position_ids on the fly for batch generation
1354
+ position_ids = attention_mask.long().cumsum(-1) - 1
1355
+ position_ids.masked_fill_(attention_mask == 0, 1)
1356
+ if past_key_values:
1357
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1358
+
1359
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1360
+ if inputs_embeds is not None and past_key_values is None:
1361
+ model_inputs = {"inputs_embeds": inputs_embeds}
1362
+ else:
1363
+ model_inputs = {"input_ids": input_ids}
1364
+
1365
+ model_inputs.update(
1366
+ {
1367
+ "position_ids": position_ids,
1368
+ "past_key_values": past_key_values,
1369
+ "use_cache": kwargs.get("use_cache"),
1370
+ "attention_mask": attention_mask,
1371
+ }
1372
+ )
1373
+ return model_inputs
1374
+
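A hedged sketch of the on-the-fly `position_ids` construction above for a left-padded batch: positions count real tokens only, and padding slots receive a dummy value of 1.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```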
1375
+ @staticmethod
1376
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1377
+ def _reorder_cache(past_key_values, beam_idx):
1378
+ reordered_past = ()
1379
+ for layer_past in past_key_values:
1380
+ reordered_past += (
1381
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1382
+ )
1383
+ return reordered_past
1384
+
1385
+
1386
+ @add_start_docstrings(
1387
+ """
1388
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1389
+
1390
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1391
+ (e.g. GPT-2) do.
1392
+
1393
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1394
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1395
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1396
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1397
+ each row of the batch).
1398
+ """,
1399
+ PHI3_START_DOCSTRING,
1400
+ )
1401
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1402
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1403
+ def __init__(self, config):
1404
+ super().__init__(config)
1405
+ self.num_labels = config.num_labels
1406
+ self.model = Phi3Model(config)
1407
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1408
+
1409
+ # Initialize weights and apply final processing
1410
+ self.post_init()
1411
+
1412
+ def get_input_embeddings(self):
1413
+ return self.model.embed_tokens
1414
+
1415
+ def set_input_embeddings(self, value):
1416
+ self.model.embed_tokens = value
1417
+
1418
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1419
+ def forward(
1420
+ self,
1421
+ input_ids: torch.LongTensor = None,
1422
+ attention_mask: Optional[torch.Tensor] = None,
1423
+ position_ids: Optional[torch.LongTensor] = None,
1424
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1425
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1426
+ labels: Optional[torch.LongTensor] = None,
1427
+ use_cache: Optional[bool] = None,
1428
+ output_attentions: Optional[bool] = None,
1429
+ output_hidden_states: Optional[bool] = None,
1430
+ return_dict: Optional[bool] = None,
1431
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1432
+ r"""
1433
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1434
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1435
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1436
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1437
+ """
1438
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1439
+
1440
+ model_outputs = self.model(
1441
+ input_ids,
1442
+ attention_mask=attention_mask,
1443
+ position_ids=position_ids,
1444
+ past_key_values=past_key_values,
1445
+ inputs_embeds=inputs_embeds,
1446
+ use_cache=use_cache,
1447
+ output_attentions=output_attentions,
1448
+ output_hidden_states=output_hidden_states,
1449
+ return_dict=return_dict,
1450
+ )
1451
+ hidden_states = model_outputs[0]
1452
+ logits = self.score(hidden_states)
1453
+
1454
+ if input_ids is not None:
1455
+ batch_size = input_ids.shape[0]
1456
+ else:
1457
+ batch_size = inputs_embeds.shape[0]
1458
+
1459
+ if self.config.pad_token_id is None and batch_size != 1:
1460
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1461
+ if self.config.pad_token_id is None:
1462
+ sequence_lengths = -1
1463
+ else:
1464
+ if input_ids is not None:
1465
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1466
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1467
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1468
+ sequence_lengths = sequence_lengths.to(logits.device)
1469
+ else:
1470
+ sequence_lengths = -1
1471
+
1472
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1473
+
1474
+ loss = None
1475
+ if labels is not None:
1476
+ labels = labels.to(logits.device)
1477
+ if self.config.problem_type is None:
1478
+ if self.num_labels == 1:
1479
+ self.config.problem_type = "regression"
1480
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1481
+ self.config.problem_type = "single_label_classification"
1482
+ else:
1483
+ self.config.problem_type = "multi_label_classification"
1484
+
1485
+ if self.config.problem_type == "regression":
1486
+ loss_fct = MSELoss()
1487
+ if self.num_labels == 1:
1488
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1489
+ else:
1490
+ loss = loss_fct(pooled_logits, labels)
1491
+ elif self.config.problem_type == "single_label_classification":
1492
+ loss_fct = CrossEntropyLoss()
1493
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1494
+ elif self.config.problem_type == "multi_label_classification":
1495
+ loss_fct = BCEWithLogitsLoss()
1496
+ loss = loss_fct(pooled_logits, labels)
1497
+ if not return_dict:
1498
+ output = (pooled_logits,) + model_outputs[1:]
1499
+ return ((loss,) + output) if loss is not None else output
1500
+
1501
+ return SequenceClassifierOutputWithPast(
1502
+ loss=loss,
1503
+ logits=pooled_logits,
1504
+ past_key_values=model_outputs.past_key_values,
1505
+ hidden_states=model_outputs.hidden_states,
1506
+ attentions=model_outputs.attentions,
1507
+ )
1508
+
1509
+
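A short sketch of the pooling index used above: the first occurrence of the pad token minus one gives the last real token per row, and the modulo keeps the expression ONNX-friendly when a row contains no padding. The pad id and token ids below are made up.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [4, 2, 6, 8, 3]])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4]) -- index of the last non-pad token per row
```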
1510
+ @add_start_docstrings(
1511
+ """
1512
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1513
+ Named-Entity-Recognition (NER) tasks.
1514
+ """,
1515
+ PHI3_START_DOCSTRING,
1516
+ )
1517
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1518
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1519
+ def __init__(self, config: Phi3Config):
1520
+ super().__init__(config)
1521
+ self.num_labels = config.num_labels
1522
+
1523
+ self.model = Phi3Model(config)
1524
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1525
+ classifier_dropout = config.classifier_dropout
1526
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1527
+ classifier_dropout = config.hidden_dropout
1528
+ else:
1529
+ classifier_dropout = 0.1
1530
+ self.dropout = nn.Dropout(classifier_dropout)
1531
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1532
+
1533
+ # Initialize weights and apply final processing
1534
+ self.post_init()
1535
+
1536
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1537
+ @add_code_sample_docstrings(
1538
+ checkpoint=_CHECKPOINT_FOR_DOC,
1539
+ output_type=TokenClassifierOutput,
1540
+ config_class=_CONFIG_FOR_DOC,
1541
+ )
1542
+ def forward(
1543
+ self,
1544
+ input_ids: Optional[torch.LongTensor] = None,
1545
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1546
+ attention_mask: Optional[torch.Tensor] = None,
1547
+ inputs_embeds: Optional[torch.Tensor] = None,
1548
+ labels: Optional[torch.Tensor] = None,
1549
+ use_cache: Optional[bool] = None,
1550
+ output_attentions: Optional[bool] = None,
1551
+ output_hidden_states: Optional[bool] = None,
1552
+ return_dict: Optional[bool] = None,
1553
+ **deprecated_arguments,
1554
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1555
+ r"""
1556
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1557
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1558
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1559
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1560
+ """
1561
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1562
+
1563
+ model_outputs = self.model(
1564
+ input_ids,
1565
+ past_key_values=past_key_values,
1566
+ attention_mask=attention_mask,
1567
+ inputs_embeds=inputs_embeds,
1568
+ use_cache=use_cache,
1569
+ output_attentions=output_attentions,
1570
+ output_hidden_states=output_hidden_states,
1571
+ return_dict=return_dict,
1572
+ )
1573
+
1574
+ hidden_states = model_outputs[0]
1575
+ hidden_states = self.dropout(hidden_states)
1576
+ logits = self.classifier(hidden_states)
1577
+
1578
+ loss = None
1579
+ if labels is not None:
1580
+ # move labels to correct device to enable model parallelism
1581
+ labels = labels.to(logits.device)
1582
+ batch_size, seq_length = labels.shape
1583
+ loss_fct = CrossEntropyLoss()
1584
+ loss = loss_fct(
1585
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1586
+ )
1587
+
1588
+ if not return_dict:
1589
+ output = (logits,) + model_outputs[2:]
1590
+ return ((loss,) + output) if loss is not None else output
1591
+
1592
+ return TokenClassifierOutput(
1593
+ loss=loss,
1594
+ logits=logits,
1595
+ hidden_states=model_outputs.hidden_states,
1596
+ attentions=model_outputs.attentions,
1597
+ )
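A minimal, hedged usage sketch for the classes above, assuming the package `__init__.py` (file 30 of this change) re-exports `Phi3Config` and `Phi3ForCausalLM`; the hyperparameters are tiny toy values rather than the released Phi-3 sizes, and the model is randomly initialised.

```python
import torch

# Assumed import path from this repository's layout.
from bunny.model.language_model.phi3 import Phi3Config, Phi3ForCausalLM

config = Phi3Config(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = Phi3ForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # torch.Size([1, 8, 128])
```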
bunny/model/language_model/qwen2/__init__.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_tokenizers_available,
20
+ is_torch_available,
21
+ )
22
+
23
+
24
+ _import_structure = {
25
+ "configuration_qwen2": ["QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Qwen2Config"],
26
+ "tokenization_qwen2": ["Qwen2Tokenizer"],
27
+ }
28
+
29
+ try:
30
+ if not is_tokenizers_available():
31
+ raise OptionalDependencyNotAvailable()
32
+ except OptionalDependencyNotAvailable:
33
+ pass
34
+ else:
35
+ _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
36
+
37
+ try:
38
+ if not is_torch_available():
39
+ raise OptionalDependencyNotAvailable()
40
+ except OptionalDependencyNotAvailable:
41
+ pass
42
+ else:
43
+ _import_structure["modeling_qwen2"] = [
44
+ "Qwen2ForCausalLM",
45
+ "Qwen2Model",
46
+ "Qwen2PreTrainedModel",
47
+ "Qwen2ForSequenceClassification",
48
+ ]
49
+
50
+
51
+ if TYPE_CHECKING:
52
+ from .configuration_qwen2 import QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP, Qwen2Config
53
+ from .tokenization_qwen2 import Qwen2Tokenizer
54
+
55
+ try:
56
+ if not is_tokenizers_available():
57
+ raise OptionalDependencyNotAvailable()
58
+ except OptionalDependencyNotAvailable:
59
+ pass
60
+ else:
61
+ from .tokenization_qwen2_fast import Qwen2TokenizerFast
62
+
63
+ try:
64
+ if not is_torch_available():
65
+ raise OptionalDependencyNotAvailable()
66
+ except OptionalDependencyNotAvailable:
67
+ pass
68
+ else:
69
+ from .modeling_qwen2 import (
70
+ Qwen2ForCausalLM,
71
+ Qwen2ForSequenceClassification,
72
+ Qwen2Model,
73
+ Qwen2PreTrainedModel,
74
+ )
75
+
76
+
77
+ else:
78
+ import sys
79
+
80
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
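A hedged illustration of what the `_LazyModule` registration above provides: importing the package is cheap, and the torch-heavy `modeling_qwen2` submodule is only loaded when one of the registered names is first accessed. The dotted path is assumed from this repository's layout.

```python
import importlib
import sys

pkg = importlib.import_module("bunny.model.language_model.qwen2")

# The entry in sys.modules is the lightweight _LazyModule proxy, not the full module graph.
print(type(sys.modules["bunny.model.language_model.qwen2"]).__name__)  # _LazyModule

# Attribute access triggers the import declared in _import_structure above.
Qwen2Config = pkg.Qwen2Config
print(Qwen2Config.model_type)  # qwen2
```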
bunny/model/language_model/qwen2/configuration_qwen2.py ADDED
@@ -0,0 +1,144 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Qwen2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
+ }
26
+
27
+
28
+ class Qwen2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of
33
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 151936):
41
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`Qwen2Model`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 22016):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*, defaults to 32):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
61
+ The maximum sequence length that this model might ever be used with.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
+ Whether the model's input and output word embeddings should be tied.
71
+ rope_theta (`float`, *optional*, defaults to 10000.0):
72
+ The base period of the RoPE embeddings.
73
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
74
+ Whether to use sliding window attention.
75
+ sliding_window (`int`, *optional*, defaults to 4096):
76
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
+ max_window_layers (`int`, *optional*, defaults to 28):
78
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
+ attention_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout ratio for the attention probabilities.
81
+
82
+ ```python
83
+ >>> from transformers import Qwen2Model, Qwen2Config
84
+
85
+ >>> # Initializing a Qwen2 style configuration
86
+ >>> configuration = Qwen2Config()
87
+
88
+ >>> # Initializing a model from the Qwen2-7B style configuration
89
+ >>> model = Qwen2Model(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+
95
+ model_type = "qwen2"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+
98
+ def __init__(
99
+ self,
100
+ vocab_size=151936,
101
+ hidden_size=4096,
102
+ intermediate_size=22016,
103
+ num_hidden_layers=32,
104
+ num_attention_heads=32,
105
+ num_key_value_heads=32,
106
+ hidden_act="silu",
107
+ max_position_embeddings=32768,
108
+ initializer_range=0.02,
109
+ rms_norm_eps=1e-6,
110
+ use_cache=True,
111
+ tie_word_embeddings=False,
112
+ rope_theta=10000.0,
113
+ use_sliding_window=False,
114
+ sliding_window=4096,
115
+ max_window_layers=28,
116
+ attention_dropout=0.0,
117
+ **kwargs,
118
+ ):
119
+ self.vocab_size = vocab_size
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.hidden_size = hidden_size
122
+ self.intermediate_size = intermediate_size
123
+ self.num_hidden_layers = num_hidden_layers
124
+ self.num_attention_heads = num_attention_heads
125
+ self.use_sliding_window = use_sliding_window
126
+ self.sliding_window = sliding_window
127
+ self.max_window_layers = max_window_layers
128
+
129
+ # for backward compatibility
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+
133
+ self.num_key_value_heads = num_key_value_heads
134
+ self.hidden_act = hidden_act
135
+ self.initializer_range = initializer_range
136
+ self.rms_norm_eps = rms_norm_eps
137
+ self.use_cache = use_cache
138
+ self.rope_theta = rope_theta
139
+ self.attention_dropout = attention_dropout
140
+
141
+ super().__init__(
142
+ tie_word_embeddings=tie_word_embeddings,
143
+ **kwargs,
144
+ )
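A small, hedged sketch of instantiating the configuration above with non-default values; the numbers are arbitrary and only illustrate how GQA is expressed through `num_key_value_heads`.

```python
from bunny.model.language_model.qwen2 import Qwen2Config  # assumed import path

config = Qwen2Config(
    vocab_size=32000,
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=4,
    num_attention_heads=16,
    num_key_value_heads=4,  # 16 query heads share 4 key/value heads (GQA)
)
print(config.model_type)               # qwen2
print(config.num_key_value_heads)      # 4
print(config.max_position_embeddings)  # 32768 (default)
```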
bunny/model/language_model/qwen2/modeling_qwen2.py ADDED
@@ -0,0 +1,1403 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Qwen2 model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_qwen2 import Qwen2Config
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
+ _CONFIG_FOR_DOC = "Qwen2Config"
60
+
61
+ QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "Qwen/Qwen2-7B-beta",
63
+ # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
64
+ ]
65
+
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
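A quick sketch of `_get_unpad_data` on a toy left-padded mask: it returns the flat indices of real tokens, the cumulative sequence lengths expected by Flash Attention's varlen kernels, and the longest unpadded length.

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [0, 0, 1, 1]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)  # function defined above
print(indices)     # tensor([1, 2, 3, 6, 7]) -- flat positions of real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -- cumulative lengths per sequence
print(max_seqlen)  # 3
```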
80
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
81
+ class Qwen2RMSNorm(nn.Module):
82
+ def __init__(self, hidden_size, eps=1e-6):
83
+ """
84
+ Qwen2RMSNorm is equivalent to T5LayerNorm
85
+ """
86
+ super().__init__()
87
+ self.weight = nn.Parameter(torch.ones(hidden_size))
88
+ self.variance_epsilon = eps
89
+
90
+ def forward(self, hidden_states):
91
+ input_dtype = hidden_states.dtype
92
+ hidden_states = hidden_states.to(torch.float32)
93
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
94
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
95
+ return self.weight * hidden_states.to(input_dtype)
96
+
97
+
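A hedged numerical check that `Qwen2RMSNorm` above matches the plain RMSNorm formula `weight * x / sqrt(mean(x^2) + eps)`, with the variance computed in float32.

```python
import torch

hidden_size, eps = 8, 1e-6
norm = Qwen2RMSNorm(hidden_size, eps=eps)  # class defined above; weight starts at ones
x = torch.randn(2, 3, hidden_size)

expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(norm(x), expected, atol=1e-6)
```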
98
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
99
+ class Qwen2RotaryEmbedding(nn.Module):
100
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
101
+ super().__init__()
102
+
103
+ self.dim = dim
104
+ self.max_position_embeddings = max_position_embeddings
105
+ self.base = base
106
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
107
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
108
+
109
+ # Build here to make `torch.jit.trace` work.
110
+ self._set_cos_sin_cache(
111
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
112
+ )
113
+
114
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
115
+ self.max_seq_len_cached = seq_len
116
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
117
+
118
+ freqs = torch.outer(t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1)
121
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
122
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
123
+
124
+ def forward(self, x, seq_len=None):
125
+ # x: [bs, num_attention_heads, seq_len, head_size]
126
+ if seq_len > self.max_seq_len_cached:
127
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
128
+
129
+ return (
130
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
131
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
132
+ )
133
+
134
+
135
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
136
+ def rotate_half(x):
137
+ """Rotates half the hidden dims of the input."""
138
+ x1 = x[..., : x.shape[-1] // 2]
139
+ x2 = x[..., x.shape[-1] // 2 :]
140
+ return torch.cat((-x2, x1), dim=-1)
141
+
142
+
143
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
144
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
145
+ """Applies Rotary Position Embedding to the query and key tensors.
146
+
147
+ Args:
148
+ q (`torch.Tensor`): The query tensor.
149
+ k (`torch.Tensor`): The key tensor.
150
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
151
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
152
+ position_ids (`torch.Tensor`):
153
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
154
+ used to pass offsetted position ids when working with a KV-cache.
155
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
156
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
157
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
158
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
159
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
160
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
161
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
162
+ Returns:
163
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
164
+ """
165
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
166
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
167
+ q_embed = (q * cos) + (rotate_half(q) * sin)
168
+ k_embed = (k * cos) + (rotate_half(k) * sin)
169
+ return q_embed, k_embed
170
+
171
+
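A hedged sketch combining the helpers above: the rotary module caches cos/sin tables, and `apply_rotary_pos_emb` rotates each (x_i, x_{i+d/2}) pair by a position-dependent angle, which leaves per-head vector norms unchanged. Shapes are toy values.

```python
import torch

bsz, n_heads, seq_len, head_dim = 1, 2, 5, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rope = Qwen2RotaryEmbedding(head_dim, max_position_embeddings=32)  # class defined above
cos, sin = rope(q, seq_len=seq_len)                # each: (seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
# RoPE is a rotation, so it preserves the norm of every head vector.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```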
172
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
173
+ class Qwen2MLP(nn.Module):
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.config = config
177
+ self.hidden_size = config.hidden_size
178
+ self.intermediate_size = config.intermediate_size
179
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
180
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
181
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
182
+ self.act_fn = ACT2FN[config.hidden_act]
183
+
184
+ def forward(self, x):
185
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
186
+
187
+
188
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
189
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
190
+ """
191
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
192
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
193
+ """
194
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
195
+ if n_rep == 1:
196
+ return hidden_states
197
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
198
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
199
+
200
+
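A short sketch of `repeat_kv` for grouped-query attention: 4 key/value heads are tiled to match 16 query heads, so each group of query heads attends over an identical KV copy.

```python
import torch

batch, kv_heads, seq_len, head_dim, n_rep = 2, 4, 6, 8, 4
kv = torch.randn(batch, kv_heads, seq_len, head_dim)

expanded = repeat_kv(kv, n_rep)  # function defined above
assert expanded.shape == (batch, kv_heads * n_rep, seq_len, head_dim)
# Heads 0..n_rep-1 are copies of the first KV head.
assert torch.equal(expanded[:, 0], expanded[:, n_rep - 1])
```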
201
+ class Qwen2Attention(nn.Module):
202
+ """
203
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
204
+ and "Generating Long Sequences with Sparse Transformers".
205
+ """
206
+
207
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
208
+ super().__init__()
209
+ self.config = config
210
+ self.layer_idx = layer_idx
211
+ if layer_idx is None:
212
+ logger.warning_once(
213
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
214
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
215
+ "when creating this class."
216
+ )
217
+
218
+ self.hidden_size = config.hidden_size
219
+ self.num_heads = config.num_attention_heads
220
+ self.head_dim = self.hidden_size // self.num_heads
221
+ self.num_key_value_heads = config.num_key_value_heads
222
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
223
+ self.max_position_embeddings = config.max_position_embeddings
224
+ self.rope_theta = config.rope_theta
225
+ self.is_causal = True
226
+ self.attention_dropout = config.attention_dropout
227
+
228
+ if (self.head_dim * self.num_heads) != self.hidden_size:
229
+ raise ValueError(
230
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
231
+ f" and `num_heads`: {self.num_heads})."
232
+ )
233
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
234
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
235
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
236
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
237
+
238
+ self.rotary_emb = Qwen2RotaryEmbedding(
239
+ self.head_dim,
240
+ max_position_embeddings=self.max_position_embeddings,
241
+ base=self.rope_theta,
242
+ )
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.Tensor,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ position_ids: Optional[torch.LongTensor] = None,
249
+ past_key_value: Optional[Cache] = None,
250
+ output_attentions: bool = False,
251
+ use_cache: bool = False,
252
+ **kwargs,
253
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
254
+ if "padding_mask" in kwargs:
255
+ warnings.warn(
256
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
257
+ )
258
+ bsz, q_len, _ = hidden_states.size()
259
+
260
+ query_states = self.q_proj(hidden_states)
261
+ key_states = self.k_proj(hidden_states)
262
+ value_states = self.v_proj(hidden_states)
263
+
264
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
265
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
266
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
267
+
268
+ kv_seq_len = key_states.shape[-2]
269
+ if past_key_value is not None:
270
+ if self.layer_idx is None:
271
+ raise ValueError(
272
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
273
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
274
+ "with a layer index."
275
+ )
276
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
277
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
278
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
279
+
280
+ if past_key_value is not None:
281
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
282
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
283
+
284
+ # repeat k/v heads if n_kv_heads < n_heads
285
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
286
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
287
+
288
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
289
+
290
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
291
+ raise ValueError(
292
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
293
+ f" {attn_weights.size()}"
294
+ )
295
+
296
+ if attention_mask is not None:
297
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
298
+ raise ValueError(
299
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
300
+ )
301
+
302
+ attn_weights = attn_weights + attention_mask
303
+
304
+ # upcast attention to fp32
305
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
306
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
307
+ attn_output = torch.matmul(attn_weights, value_states)
308
+
309
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
310
+ raise ValueError(
311
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
312
+ f" {attn_output.size()}"
313
+ )
314
+
315
+ attn_output = attn_output.transpose(1, 2).contiguous()
316
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
317
+
318
+ attn_output = self.o_proj(attn_output)
319
+
320
+ if not output_attentions:
321
+ attn_weights = None
322
+
323
+ return attn_output, attn_weights, past_key_value
324
+
325
+
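For reference, a minimal sketch of the grouped-query attention shape handling performed by the eager path above, with small assumed sizes; `repeat_interleave` stands in for the `repeat_kv` helper:

```python
# Minimal sketch with assumed sizes; repeat_interleave plays the role of repeat_kv,
# expanding each key/value head to cover num_heads // num_key_value_heads query heads.
import math
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 4, 8, 2, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
v = torch.randn(bsz, num_kv_heads, q_len, head_dim)

k = k.repeat_interleave(num_heads // num_kv_heads, dim=1)  # (1, 8, 4, 16)
v = v.repeat_interleave(num_heads // num_kv_heads, dim=1)  # (1, 8, 4, 16)

attn_weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim), dim=-1)
attn_output = (attn_weights @ v).transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 4, 128])
```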
326
+ class Qwen2FlashAttention2(Qwen2Attention):
327
+ """
328
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
329
+ as the weights of the module stay untouched. The only required change would be on the forward pass
330
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
331
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
332
+ config.max_window_layers layers.
333
+ """
334
+
335
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
336
+ def __init__(self, *args, **kwargs):
337
+ super().__init__(*args, **kwargs)
338
+
339
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
340
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
341
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
342
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
343
+
344
+ def forward(
345
+ self,
346
+ hidden_states: torch.Tensor,
347
+ attention_mask: Optional[torch.Tensor] = None,
348
+ position_ids: Optional[torch.LongTensor] = None,
349
+ past_key_value: Optional[Cache] = None,
350
+ output_attentions: bool = False,
351
+ use_cache: bool = False,
352
+ **kwargs,
353
+ ):
354
+ if "padding_mask" in kwargs:
355
+ warnings.warn(
356
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
357
+ )
358
+
359
+ # overwrite attention_mask with padding_mask
360
+ attention_mask = kwargs.pop("padding_mask")
361
+ bsz, q_len, _ = hidden_states.size()
362
+
363
+ query_states = self.q_proj(hidden_states)
364
+ key_states = self.k_proj(hidden_states)
365
+ value_states = self.v_proj(hidden_states)
366
+
367
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
368
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
369
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
370
+
371
+ kv_seq_len = key_states.shape[-2]
372
+ if past_key_value is not None:
373
+ if self.layer_idx is None:
374
+ raise ValueError(
375
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
376
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
377
+ "with a layer index."
378
+ )
379
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
380
+
381
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
382
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
383
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
384
+
385
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
386
+
387
+ use_sliding_windows = (
388
+ _flash_supports_window_size
389
+ and getattr(self.config, "sliding_window", None) is not None
390
+ and kv_seq_len > self.config.sliding_window
391
+ and self.config.use_sliding_window
392
+ )
393
+
394
+ if not _flash_supports_window_size:
395
+ logger.warning_once(
396
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
397
+ " make sure to upgrade flash-attn library."
398
+ )
399
+
400
+ if past_key_value is not None:
401
+ # Activate slicing cache only if the config has a `sliding_window` attribute
402
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
403
+ if (
404
+ getattr(self.config, "sliding_window", None) is not None
405
+ and kv_seq_len > self.config.sliding_window
406
+ and cache_has_contents
407
+ ):
408
+ slicing_tokens = 1 - self.config.sliding_window
409
+
410
+ past_key = past_key_value[self.layer_idx][0]
411
+ past_value = past_key_value[self.layer_idx][1]
412
+
413
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
414
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
415
+
416
+ if past_key.shape[-2] != self.config.sliding_window - 1:
417
+ raise ValueError(
418
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
419
+ f" {past_key.shape}"
420
+ )
421
+
422
+ if attention_mask is not None:
423
+ attention_mask = attention_mask[:, slicing_tokens:]
424
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
425
+
426
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
427
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
428
+
429
+ # repeat k/v heads if n_kv_heads < n_heads
430
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
431
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
432
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
433
+
434
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
435
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
436
+ # cast them back to float16 just to be sure everything works as expected.
437
+ input_dtype = query_states.dtype
438
+ if input_dtype == torch.float32:
439
+ if torch.is_autocast_enabled():
440
+ target_dtype = torch.get_autocast_gpu_dtype()
441
+ # Handle the case where the model is quantized
442
+ elif hasattr(self.config, "_pre_quantization_dtype"):
443
+ target_dtype = self.config._pre_quantization_dtype
444
+ else:
445
+ target_dtype = self.q_proj.weight.dtype
446
+
447
+ logger.warning_once(
448
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
449
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
450
+ f" {target_dtype}."
451
+ )
452
+
453
+ query_states = query_states.to(target_dtype)
454
+ key_states = key_states.to(target_dtype)
455
+ value_states = value_states.to(target_dtype)
456
+
457
+ # Reshape to the expected shape for Flash Attention
458
+ query_states = query_states.transpose(1, 2)
459
+ key_states = key_states.transpose(1, 2)
460
+ value_states = value_states.transpose(1, 2)
461
+
462
+ attn_output = self._flash_attention_forward(
463
+ query_states,
464
+ key_states,
465
+ value_states,
466
+ attention_mask,
467
+ q_len,
468
+ dropout=dropout_rate,
469
+ use_sliding_windows=use_sliding_windows,
470
+ )
471
+
472
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
473
+ attn_output = self.o_proj(attn_output)
474
+
475
+ if not output_attentions:
476
+ attn_weights = None
477
+
478
+ return attn_output, attn_weights, past_key_value
479
+
480
+ def _flash_attention_forward(
481
+ self,
482
+ query_states,
483
+ key_states,
484
+ value_states,
485
+ attention_mask,
486
+ query_length,
487
+ dropout=0.0,
488
+ softmax_scale=None,
489
+ use_sliding_windows=False,
490
+ ):
491
+ """
492
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
493
+ it first unpads the input, then computes the attention scores, and pads the final attention scores.
494
+
495
+ Args:
496
+ query_states (`torch.Tensor`):
497
+ Input query states to be passed to Flash Attention API
498
+ key_states (`torch.Tensor`):
499
+ Input key states to be passed to Flash Attention API
500
+ value_states (`torch.Tensor`):
501
+ Input value states to be passed to Flash Attention API
502
+ attention_mask (`torch.Tensor`):
503
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
504
+ position of padding tokens and 1 for the position of non-padding tokens.
505
+ dropout (`float`):
506
+ Attention dropout
507
+ softmax_scale (`float`, *optional*):
508
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
509
+ use_sliding_windows (`bool`, *optional*):
510
+ Whether to activate sliding window attention.
511
+ """
512
+ if not self._flash_attn_uses_top_left_mask:
513
+ causal = self.is_causal
514
+ else:
515
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
516
+ causal = self.is_causal and query_length != 1
517
+
518
+ # Decide whether to use SWA or not by layer index.
519
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
520
+ use_sliding_windows = False
521
+
522
+ # Contains at least one padding token in the sequence
523
+ if attention_mask is not None:
524
+ batch_size = query_states.shape[0]
525
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
526
+ query_states, key_states, value_states, attention_mask, query_length
527
+ )
528
+
529
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
530
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
531
+
532
+ if not use_sliding_windows:
533
+ attn_output_unpad = flash_attn_varlen_func(
534
+ query_states,
535
+ key_states,
536
+ value_states,
537
+ cu_seqlens_q=cu_seqlens_q,
538
+ cu_seqlens_k=cu_seqlens_k,
539
+ max_seqlen_q=max_seqlen_in_batch_q,
540
+ max_seqlen_k=max_seqlen_in_batch_k,
541
+ dropout_p=dropout,
542
+ softmax_scale=softmax_scale,
543
+ causal=causal,
544
+ )
545
+ else:
546
+ attn_output_unpad = flash_attn_varlen_func(
547
+ query_states,
548
+ key_states,
549
+ value_states,
550
+ cu_seqlens_q=cu_seqlens_q,
551
+ cu_seqlens_k=cu_seqlens_k,
552
+ max_seqlen_q=max_seqlen_in_batch_q,
553
+ max_seqlen_k=max_seqlen_in_batch_k,
554
+ dropout_p=dropout,
555
+ softmax_scale=softmax_scale,
556
+ causal=causal,
557
+ window_size=(self.config.sliding_window, self.config.sliding_window),
558
+ )
559
+
560
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
561
+ else:
562
+ if not use_sliding_windows:
563
+ attn_output = flash_attn_func(
564
+ query_states,
565
+ key_states,
566
+ value_states,
567
+ dropout,
568
+ softmax_scale=softmax_scale,
569
+ causal=causal,
570
+ )
571
+ else:
572
+ attn_output = flash_attn_func(
573
+ query_states,
574
+ key_states,
575
+ value_states,
576
+ dropout,
577
+ softmax_scale=softmax_scale,
578
+ causal=causal,
579
+ window_size=(self.config.sliding_window, self.config.sliding_window),
580
+ )
581
+
582
+ return attn_output
583
+
584
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
585
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
586
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
587
+
588
+ # On the first iteration we need to properly re-create the padding mask
589
+ # by slicing it at the proper place
590
+ if kv_seq_len != attention_mask.shape[-1]:
591
+ attention_mask_num_tokens = attention_mask.shape[-1]
592
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
593
+
594
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
595
+
596
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
597
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
598
+
599
+ if query_length == kv_seq_len:
600
+ query_layer = index_first_axis(
601
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
602
+ )
603
+ cu_seqlens_q = cu_seqlens_k
604
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
605
+ indices_q = indices_k
606
+ elif query_length == 1:
607
+ max_seqlen_in_batch_q = 1
608
+ cu_seqlens_q = torch.arange(
609
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
610
+ ) # There is a memcpy here, that is very bad.
611
+ indices_q = cu_seqlens_q[:-1]
612
+ query_layer = query_layer.squeeze(1)
613
+ else:
614
+ # The -q_len: slice assumes left padding.
615
+ attention_mask = attention_mask[:, -query_length:]
616
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
617
+
618
+ return (
619
+ query_layer,
620
+ key_layer,
621
+ value_layer,
622
+ indices_q,
623
+ (cu_seqlens_q, cu_seqlens_k),
624
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
625
+ )
626
+
627
+
628
+ # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
629
+ class Qwen2SdpaAttention(Qwen2Attention):
630
+ """
631
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
632
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
633
+ SDPA API.
634
+ """
635
+
636
+ # Adapted from Qwen2Attention.forward
637
+ def forward(
638
+ self,
639
+ hidden_states: torch.Tensor,
640
+ attention_mask: Optional[torch.Tensor] = None,
641
+ position_ids: Optional[torch.LongTensor] = None,
642
+ past_key_value: Optional[Cache] = None,
643
+ output_attentions: bool = False,
644
+ use_cache: bool = False,
645
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
646
+ if output_attentions:
647
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
648
+ logger.warning_once(
649
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
650
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
651
+ )
652
+ return super().forward(
653
+ hidden_states=hidden_states,
654
+ attention_mask=attention_mask,
655
+ position_ids=position_ids,
656
+ past_key_value=past_key_value,
657
+ output_attentions=output_attentions,
658
+ use_cache=use_cache,
659
+ )
660
+
661
+ bsz, q_len, _ = hidden_states.size()
662
+
663
+ query_states = self.q_proj(hidden_states)
664
+ key_states = self.k_proj(hidden_states)
665
+ value_states = self.v_proj(hidden_states)
666
+
667
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
668
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
669
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
670
+
671
+ kv_seq_len = key_states.shape[-2]
672
+ if past_key_value is not None:
673
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
674
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
675
+
676
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
677
+
678
+ if past_key_value is not None:
679
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
680
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
681
+
682
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
683
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
684
+
685
+ if attention_mask is not None:
686
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
687
+ raise ValueError(
688
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
689
+ )
690
+
691
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
692
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
693
+ if query_states.device.type == "cuda" and attention_mask is not None:
694
+ query_states = query_states.contiguous()
695
+ key_states = key_states.contiguous()
696
+ value_states = value_states.contiguous()
697
+
698
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
699
+ query_states,
700
+ key_states,
701
+ value_states,
702
+ attn_mask=attention_mask,
703
+ dropout_p=self.attention_dropout if self.training else 0.0,
704
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
705
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
706
+ )
707
+
708
+ attn_output = attn_output.transpose(1, 2).contiguous()
709
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
710
+
711
+ attn_output = self.o_proj(attn_output)
712
+
713
+ return attn_output, None, past_key_value
714
+
715
+
716
+ QWEN2_ATTENTION_CLASSES = {
717
+ "eager": Qwen2Attention,
718
+ "flash_attention_2": Qwen2FlashAttention2,
719
+ "sdpa": Qwen2SdpaAttention,
720
+ }
721
+
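A hedged usage sketch: this dispatch table is indexed with `config._attn_implementation`, which `from_pretrained` populates from the `attn_implementation` argument. The checkpoint name below is only an example:

```python
# Assumed checkpoint name, shown only to illustrate attention-backend selection.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-0.5B",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
```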
722
+
723
+ class Qwen2DecoderLayer(nn.Module):
724
+ def __init__(self, config: Qwen2Config, layer_idx: int):
725
+ super().__init__()
726
+ self.hidden_size = config.hidden_size
727
+
728
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
729
+ logger.warning_once(
730
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
731
+ "unexpected results may be encountered."
732
+ )
733
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
734
+
735
+ self.mlp = Qwen2MLP(config)
736
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
737
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
738
+
739
+ def forward(
740
+ self,
741
+ hidden_states: torch.Tensor,
742
+ attention_mask: Optional[torch.Tensor] = None,
743
+ position_ids: Optional[torch.LongTensor] = None,
744
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
745
+ output_attentions: Optional[bool] = False,
746
+ use_cache: Optional[bool] = False,
747
+ **kwargs,
748
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
749
+ if "padding_mask" in kwargs:
750
+ warnings.warn(
751
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
752
+ "Please make sure use `attention_mask` instead.`"
753
+ )
754
+ """
755
+ Args:
756
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
757
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
758
+ `(batch, sequence_length)` where padding elements are indicated by 0.
759
+ output_attentions (`bool`, *optional*):
760
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
761
+ returned tensors for more detail.
762
+ use_cache (`bool`, *optional*):
763
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
764
+ (see `past_key_values`).
765
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
766
+ """
767
+
768
+ residual = hidden_states
769
+
770
+ hidden_states = self.input_layernorm(hidden_states)
771
+
772
+ # Self Attention
773
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
774
+ hidden_states=hidden_states,
775
+ attention_mask=attention_mask,
776
+ position_ids=position_ids,
777
+ past_key_value=past_key_value,
778
+ output_attentions=output_attentions,
779
+ use_cache=use_cache,
780
+ )
781
+ hidden_states = residual + hidden_states
782
+
783
+ # Fully Connected
784
+ residual = hidden_states
785
+ hidden_states = self.post_attention_layernorm(hidden_states)
786
+ hidden_states = self.mlp(hidden_states)
787
+ hidden_states = residual + hidden_states
788
+
789
+ outputs = (hidden_states,)
790
+
791
+ if output_attentions:
792
+ outputs += (self_attn_weights,)
793
+
794
+ if use_cache:
795
+ outputs += (present_key_value,)
796
+
797
+ return outputs
798
+
799
+
800
+ QWEN2_START_DOCSTRING = r"""
801
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
802
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
803
+ etc.)
804
+
805
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
806
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
807
+ and behavior.
808
+
809
+ Parameters:
810
+ config ([`Qwen2Config`]):
811
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
812
+ load the weights associated with the model, only the configuration. Check out the
813
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
814
+ """
815
+
816
+
817
+ @add_start_docstrings(
818
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
819
+ QWEN2_START_DOCSTRING,
820
+ )
821
+ class Qwen2PreTrainedModel(PreTrainedModel):
822
+ config_class = Qwen2Config
823
+ base_model_prefix = "model"
824
+ supports_gradient_checkpointing = True
825
+ _no_split_modules = ["Qwen2DecoderLayer"]
826
+ _skip_keys_device_placement = "past_key_values"
827
+ _supports_flash_attn_2 = True
828
+ _supports_sdpa = True
829
+ _supports_cache_class = True
830
+
831
+ def _init_weights(self, module):
832
+ std = self.config.initializer_range
833
+ if isinstance(module, nn.Linear):
834
+ module.weight.data.normal_(mean=0.0, std=std)
835
+ if module.bias is not None:
836
+ module.bias.data.zero_()
837
+ elif isinstance(module, nn.Embedding):
838
+ module.weight.data.normal_(mean=0.0, std=std)
839
+ if module.padding_idx is not None:
840
+ module.weight.data[module.padding_idx].zero_()
841
+
842
+
843
+ QWEN2_INPUTS_DOCSTRING = r"""
844
+ Args:
845
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
846
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
847
+ it.
848
+
849
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
850
+ [`PreTrainedTokenizer.__call__`] for details.
851
+
852
+ [What are input IDs?](../glossary#input-ids)
853
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
854
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
855
+
856
+ - 1 for tokens that are **not masked**,
857
+ - 0 for tokens that are **masked**.
858
+
859
+ [What are attention masks?](../glossary#attention-mask)
860
+
861
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
862
+ [`PreTrainedTokenizer.__call__`] for details.
863
+
864
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
865
+ `past_key_values`).
866
+
867
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
868
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
869
+ information on the default strategy.
870
+
871
+ - 1 indicates the head is **not masked**,
872
+ - 0 indicates the head is **masked**.
873
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
874
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
875
+ config.n_positions - 1]`.
876
+
877
+ [What are position IDs?](../glossary#position-ids)
878
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
879
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
880
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
881
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
882
+
883
+ Two formats are allowed:
884
+ - a [`~cache_utils.Cache`] instance;
885
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
886
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
887
+ cache format.
888
+
889
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
890
+ legacy cache format will be returned.
891
+
892
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
893
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
894
+ of shape `(batch_size, sequence_length)`.
895
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
896
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
897
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
898
+ model's internal embedding lookup matrix.
899
+ use_cache (`bool`, *optional*):
900
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
901
+ `past_key_values`).
902
+ output_attentions (`bool`, *optional*):
903
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
904
+ tensors for more detail.
905
+ output_hidden_states (`bool`, *optional*):
906
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
907
+ more detail.
908
+ return_dict (`bool`, *optional*):
909
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
910
+ """
911
+
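A minimal sketch of incremental decoding with the cache formats described above; the checkpoint name is assumed and shown only for illustration:

```python
# Hedged sketch: one prefill step followed by one cached decode step.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B")

inputs = tokenizer("Hello", return_tensors="pt")
out = model(**inputs, use_cache=True)                 # past_key_values is returned
next_token = out.logits[:, -1:].argmax(-1)            # greedy choice of the next token
out = model(
    input_ids=next_token,
    past_key_values=out.past_key_values,              # Cache instance or legacy tuple
    use_cache=True,
)
```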
912
+
913
+ @add_start_docstrings(
914
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
915
+ QWEN2_START_DOCSTRING,
916
+ )
917
+ class Qwen2Model(Qwen2PreTrainedModel):
918
+ """
919
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
920
+
921
+ Args:
922
+ config: Qwen2Config
923
+ """
924
+
925
+ def __init__(self, config: Qwen2Config):
926
+ super().__init__(config)
927
+ self.padding_idx = config.pad_token_id
928
+ self.vocab_size = config.vocab_size
929
+
930
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
931
+ self.layers = nn.ModuleList(
932
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
933
+ )
934
+ self._attn_implementation = config._attn_implementation
935
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
936
+
937
+ self.gradient_checkpointing = False
938
+ # Initialize weights and apply final processing
939
+ self.post_init()
940
+
941
+ def get_input_embeddings(self):
942
+ return self.embed_tokens
943
+
944
+ def set_input_embeddings(self, value):
945
+ self.embed_tokens = value
946
+
947
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
948
+ def forward(
949
+ self,
950
+ input_ids: torch.LongTensor = None,
951
+ attention_mask: Optional[torch.Tensor] = None,
952
+ position_ids: Optional[torch.LongTensor] = None,
953
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
954
+ inputs_embeds: Optional[torch.FloatTensor] = None,
955
+ use_cache: Optional[bool] = None,
956
+ output_attentions: Optional[bool] = None,
957
+ output_hidden_states: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
960
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
961
+ output_hidden_states = (
962
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
963
+ )
964
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
965
+
966
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
967
+
968
+ # retrieve input_ids and inputs_embeds
969
+ if input_ids is not None and inputs_embeds is not None:
970
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
971
+ elif input_ids is not None:
972
+ batch_size, seq_length = input_ids.shape
973
+ elif inputs_embeds is not None:
974
+ batch_size, seq_length, _ = inputs_embeds.shape
975
+ else:
976
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
977
+
978
+ if self.gradient_checkpointing and self.training:
979
+ if use_cache:
980
+ logger.warning_once(
981
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
982
+ )
983
+ use_cache = False
984
+
985
+ past_key_values_length = 0
986
+
987
+ if use_cache:
988
+ use_legacy_cache = not isinstance(past_key_values, Cache)
989
+ if use_legacy_cache:
990
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
991
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
992
+
993
+ if position_ids is None:
994
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
995
+ position_ids = torch.arange(
996
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
997
+ )
998
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
999
+ else:
1000
+ position_ids = position_ids.view(-1, seq_length).long()
1001
+
1002
+ if inputs_embeds is None:
1003
+ inputs_embeds = self.embed_tokens(input_ids)
1004
+
1005
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1006
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1007
+ if is_padding_right:
1008
+ raise ValueError(
1009
+ "You are attempting to perform batched generation with padding_side='right'"
1010
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1011
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1012
+ )
1013
+
1014
+ if self._attn_implementation == "flash_attention_2":
1015
+ # 2d mask is passed through the layers
1016
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1017
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1018
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1019
+ # the manual implementation that requires a 4D causal mask in all cases.
1020
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1021
+ attention_mask,
1022
+ (batch_size, seq_length),
1023
+ inputs_embeds,
1024
+ past_key_values_length,
1025
+ )
1026
+ else:
1027
+ # 4d mask is passed through the layers
1028
+ attention_mask = _prepare_4d_causal_attention_mask(
1029
+ attention_mask,
1030
+ (batch_size, seq_length),
1031
+ inputs_embeds,
1032
+ past_key_values_length,
1033
+ sliding_window=self.config.sliding_window,
1034
+ )
1035
+
1036
+ hidden_states = inputs_embeds
1037
+
1038
+ # decoder layers
1039
+ all_hidden_states = () if output_hidden_states else None
1040
+ all_self_attns = () if output_attentions else None
1041
+ next_decoder_cache = None
1042
+
1043
+ for decoder_layer in self.layers:
1044
+ if output_hidden_states:
1045
+ all_hidden_states += (hidden_states,)
1046
+
1047
+ if self.gradient_checkpointing and self.training:
1048
+ layer_outputs = self._gradient_checkpointing_func(
1049
+ decoder_layer.__call__,
1050
+ hidden_states,
1051
+ attention_mask,
1052
+ position_ids,
1053
+ past_key_values,
1054
+ output_attentions,
1055
+ use_cache,
1056
+ )
1057
+ else:
1058
+ layer_outputs = decoder_layer(
1059
+ hidden_states,
1060
+ attention_mask=attention_mask,
1061
+ position_ids=position_ids,
1062
+ past_key_value=past_key_values,
1063
+ output_attentions=output_attentions,
1064
+ use_cache=use_cache,
1065
+ )
1066
+
1067
+ hidden_states = layer_outputs[0]
1068
+
1069
+ if use_cache:
1070
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1071
+
1072
+ if output_attentions:
1073
+ all_self_attns += (layer_outputs[1],)
1074
+
1075
+ hidden_states = self.norm(hidden_states)
1076
+
1077
+ # add hidden states from the last decoder layer
1078
+ if output_hidden_states:
1079
+ all_hidden_states += (hidden_states,)
1080
+
1081
+ next_cache = None
1082
+ if use_cache:
1083
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1084
+
1085
+ if not return_dict:
1086
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1087
+ return BaseModelOutputWithPast(
1088
+ last_hidden_state=hidden_states,
1089
+ past_key_values=next_cache,
1090
+ hidden_states=all_hidden_states,
1091
+ attentions=all_self_attns,
1092
+ )
1093
+
1094
+
1095
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1096
+ _tied_weights_keys = ["lm_head.weight"]
1097
+
1098
+ def __init__(self, config):
1099
+ super().__init__(config)
1100
+ self.model = Qwen2Model(config)
1101
+ self.vocab_size = config.vocab_size
1102
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1103
+
1104
+ # Initialize weights and apply final processing
1105
+ self.post_init()
1106
+
1107
+ def get_input_embeddings(self):
1108
+ return self.model.embed_tokens
1109
+
1110
+ def set_input_embeddings(self, value):
1111
+ self.model.embed_tokens = value
1112
+
1113
+ def get_output_embeddings(self):
1114
+ return self.lm_head
1115
+
1116
+ def set_output_embeddings(self, new_embeddings):
1117
+ self.lm_head = new_embeddings
1118
+
1119
+ def set_decoder(self, decoder):
1120
+ self.model = decoder
1121
+
1122
+ def get_decoder(self):
1123
+ return self.model
1124
+
1125
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1126
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1127
+ def forward(
1128
+ self,
1129
+ input_ids: torch.LongTensor = None,
1130
+ attention_mask: Optional[torch.Tensor] = None,
1131
+ position_ids: Optional[torch.LongTensor] = None,
1132
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1133
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1134
+ labels: Optional[torch.LongTensor] = None,
1135
+ use_cache: Optional[bool] = None,
1136
+ output_attentions: Optional[bool] = None,
1137
+ output_hidden_states: Optional[bool] = None,
1138
+ return_dict: Optional[bool] = None,
1139
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1140
+ r"""
1141
+ Args:
1142
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1143
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1144
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1145
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1146
+
1147
+ Returns:
1148
+
1149
+ Example:
1150
+
1151
+ ```python
1152
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1153
+
1154
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1155
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1156
+
1157
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1158
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1159
+
1160
+ >>> # Generate
1161
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1162
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1163
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1164
+ ```"""
1165
+
1166
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1167
+ output_hidden_states = (
1168
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1169
+ )
1170
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1171
+
1172
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1173
+ outputs = self.model(
1174
+ input_ids=input_ids,
1175
+ attention_mask=attention_mask,
1176
+ position_ids=position_ids,
1177
+ past_key_values=past_key_values,
1178
+ inputs_embeds=inputs_embeds,
1179
+ use_cache=use_cache,
1180
+ output_attentions=output_attentions,
1181
+ output_hidden_states=output_hidden_states,
1182
+ return_dict=return_dict,
1183
+ )
1184
+
1185
+ hidden_states = outputs[0]
1186
+ logits = self.lm_head(hidden_states)
1187
+ logits = logits.float()
1188
+
1189
+ loss = None
1190
+ if labels is not None:
1191
+ # Shift so that tokens < n predict n
1192
+ shift_logits = logits[..., :-1, :].contiguous()
1193
+ shift_labels = labels[..., 1:].contiguous()
1194
+ # Flatten the tokens
1195
+ loss_fct = CrossEntropyLoss()
1196
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1197
+ shift_labels = shift_labels.view(-1)
1198
+ # Enable model parallelism
1199
+ shift_labels = shift_labels.to(shift_logits.device)
1200
+ loss = loss_fct(shift_logits, shift_labels)
1201
+
1202
+ if not return_dict:
1203
+ output = (logits,) + outputs[1:]
1204
+ return (loss,) + output if loss is not None else output
1205
+
1206
+ return CausalLMOutputWithPast(
1207
+ loss=loss,
1208
+ logits=logits,
1209
+ past_key_values=outputs.past_key_values,
1210
+ hidden_states=outputs.hidden_states,
1211
+ attentions=outputs.attentions,
1212
+ )
1213
+
1214
+ def prepare_inputs_for_generation(
1215
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1216
+ ):
1217
+ # Omit tokens covered by past_key_values
1218
+ if past_key_values is not None:
1219
+ if isinstance(past_key_values, Cache):
1220
+ cache_length = past_key_values.get_seq_length()
1221
+ past_length = past_key_values.seen_tokens
1222
+ max_cache_length = past_key_values.get_max_length()
1223
+ else:
1224
+ cache_length = past_length = past_key_values[0][0].shape[2]
1225
+ max_cache_length = None
1226
+
1227
+ # Keep only the unprocessed tokens:
1228
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1229
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1230
+ # input)
1231
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1232
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1233
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1234
+ # input_ids based on the past_length.
1235
+ elif past_length < input_ids.shape[1]:
1236
+ input_ids = input_ids[:, past_length:]
1237
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1238
+ else:
1239
+ remove_prefix_length = input_ids.shape[1] - 1
1240
+ input_ids = input_ids[:, remove_prefix_length:]
1241
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1242
+ if (
1243
+ max_cache_length is not None
1244
+ and attention_mask is not None
1245
+ and cache_length + input_ids.shape[1] > max_cache_length
1246
+ ):
1247
+ attention_mask = attention_mask[:, -max_cache_length:]
1248
+
1249
+ position_ids = kwargs.get("position_ids", None)
1250
+ if attention_mask is not None and position_ids is None:
1251
+ # create position_ids on the fly for batch generation
1252
+ position_ids = attention_mask.long().cumsum(-1) - 1
1253
+ position_ids.masked_fill_(attention_mask == 0, 1)
1254
+ if past_key_values:
1255
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1256
+
1257
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1258
+ if inputs_embeds is not None and past_key_values is None:
1259
+ model_inputs = {"inputs_embeds": inputs_embeds}
1260
+ else:
1261
+ model_inputs = {"input_ids": input_ids}
1262
+
1263
+ model_inputs.update(
1264
+ {
1265
+ "position_ids": position_ids,
1266
+ "past_key_values": past_key_values,
1267
+ "use_cache": kwargs.get("use_cache"),
1268
+ "attention_mask": attention_mask,
1269
+ }
1270
+ )
1271
+ return model_inputs
1272
+
1273
+ @staticmethod
1274
+ def _reorder_cache(past_key_values, beam_idx):
1275
+ reordered_past = ()
1276
+ for layer_past in past_key_values:
1277
+ reordered_past += (
1278
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1279
+ )
1280
+ return reordered_past
1281
+
1282
+
1283
+ @add_start_docstrings(
1284
+ """
1285
+ The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1286
+
1287
+ [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1288
+ (e.g. GPT-2) do.
1289
+
1290
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1291
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1292
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1293
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1294
+ each row of the batch).
1295
+ """,
1296
+ QWEN2_START_DOCSTRING,
1297
+ )
1298
+ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1299
+ def __init__(self, config):
1300
+ super().__init__(config)
1301
+ self.num_labels = config.num_labels
1302
+ self.model = Qwen2Model(config)
1303
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1304
+
1305
+ # Initialize weights and apply final processing
1306
+ self.post_init()
1307
+
1308
+ def get_input_embeddings(self):
1309
+ return self.model.embed_tokens
1310
+
1311
+ def set_input_embeddings(self, value):
1312
+ self.model.embed_tokens = value
1313
+
1314
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1315
+ def forward(
1316
+ self,
1317
+ input_ids: torch.LongTensor = None,
1318
+ attention_mask: Optional[torch.Tensor] = None,
1319
+ position_ids: Optional[torch.LongTensor] = None,
1320
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1321
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ use_cache: Optional[bool] = None,
1324
+ output_attentions: Optional[bool] = None,
1325
+ output_hidden_states: Optional[bool] = None,
1326
+ return_dict: Optional[bool] = None,
1327
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1328
+ r"""
1329
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1330
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1331
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1332
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1333
+ """
1334
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1335
+
1336
+ transformer_outputs = self.model(
1337
+ input_ids,
1338
+ attention_mask=attention_mask,
1339
+ position_ids=position_ids,
1340
+ past_key_values=past_key_values,
1341
+ inputs_embeds=inputs_embeds,
1342
+ use_cache=use_cache,
1343
+ output_attentions=output_attentions,
1344
+ output_hidden_states=output_hidden_states,
1345
+ return_dict=return_dict,
1346
+ )
1347
+ hidden_states = transformer_outputs[0]
1348
+ logits = self.score(hidden_states)
1349
+
1350
+ if input_ids is not None:
1351
+ batch_size = input_ids.shape[0]
1352
+ else:
1353
+ batch_size = inputs_embeds.shape[0]
1354
+
1355
+ if self.config.pad_token_id is None and batch_size != 1:
1356
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1357
+ if self.config.pad_token_id is None:
1358
+ sequence_lengths = -1
1359
+ else:
1360
+ if input_ids is not None:
1361
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1362
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1363
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1364
+ sequence_lengths = sequence_lengths.to(logits.device)
1365
+ else:
1366
+ sequence_lengths = -1
1367
+
1368
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1369
+
1370
+ loss = None
1371
+ if labels is not None:
1372
+ labels = labels.to(logits.device)
1373
+ if self.config.problem_type is None:
1374
+ if self.num_labels == 1:
1375
+ self.config.problem_type = "regression"
1376
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1377
+ self.config.problem_type = "single_label_classification"
1378
+ else:
1379
+ self.config.problem_type = "multi_label_classification"
1380
+
1381
+ if self.config.problem_type == "regression":
1382
+ loss_fct = MSELoss()
1383
+ if self.num_labels == 1:
1384
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1385
+ else:
1386
+ loss = loss_fct(pooled_logits, labels)
1387
+ elif self.config.problem_type == "single_label_classification":
1388
+ loss_fct = CrossEntropyLoss()
1389
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1390
+ elif self.config.problem_type == "multi_label_classification":
1391
+ loss_fct = BCEWithLogitsLoss()
1392
+ loss = loss_fct(pooled_logits, labels)
1393
+ if not return_dict:
1394
+ output = (pooled_logits,) + transformer_outputs[1:]
1395
+ return ((loss,) + output) if loss is not None else output
1396
+
1397
+ return SequenceClassifierOutputWithPast(
1398
+ loss=loss,
1399
+ logits=pooled_logits,
1400
+ past_key_values=transformer_outputs.past_key_values,
1401
+ hidden_states=transformer_outputs.hidden_states,
1402
+ attentions=transformer_outputs.attentions,
1403
+ )
bunny/model/language_model/qwen2/tokenization_qwen2.py ADDED
@@ -0,0 +1,345 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ import json
18
+ import os
19
+ import unicodedata
20
+ from functools import lru_cache
21
+ from typing import Optional, Tuple
22
+
23
+ import regex as re
24
+
25
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "vocab.json",
33
+ "merges_file": "merges.txt",
34
+ }
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {
37
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
38
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44
+
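A hedged illustration of what the pre-tokenization pattern above splits out (the exact splits come from the `regex` module; the output below is indicative):

```python
import regex as re

PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
pat = re.compile(PRETOKENIZE_REGEX)
print(re.findall(pat, "Qwen2's tokenizer splits text!"))
# e.g. ['Qwen', '2', "'s", ' tokenizer', ' splits', ' text', '!']
```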
45
+
46
+ @lru_cache()
47
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
48
+ def bytes_to_unicode():
49
+ """
50
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
51
+ characters that the bpe code barfs on.
52
+
53
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
54
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
55
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
56
+ tables between utf-8 bytes and unicode strings.
57
+ """
58
+ bs = (
59
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
60
+ )
61
+ cs = bs[:]
62
+ n = 0
63
+ for b in range(2**8):
64
+ if b not in bs:
65
+ bs.append(b)
66
+ cs.append(2**8 + n)
67
+ n += 1
68
+ cs = [chr(n) for n in cs]
69
+ return dict(zip(bs, cs))
70
+
71
+
72
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
73
+ def get_pairs(word):
74
+ """
75
+ Return set of symbol pairs in a word.
76
+
77
+ Word is represented as tuple of symbols (symbols being variable-length strings).
78
+ """
79
+ pairs = set()
80
+ prev_char = word[0]
81
+ for char in word[1:]:
82
+ pairs.add((prev_char, char))
83
+ prev_char = char
84
+ return pairs
85
+
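A small example of `get_pairs`, which `bpe()` below uses to rank candidate merges (assumes the function defined above is in scope; set ordering may vary):

```python
word = ("l", "o", "w", "er")
print(get_pairs(word))  # {('l', 'o'), ('o', 'w'), ('w', 'er')}
```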
86
+
87
+ class Qwen2Tokenizer(PreTrainedTokenizer):
88
+ """
89
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
90
+
91
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
93
+ be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
93
+
94
+ ```python
95
+ >>> from transformers import Qwen2Tokenizer
96
+
97
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
98
+ >>> tokenizer("Hello world")["input_ids"]
99
+ [9707, 1879]
100
+
101
+ >>> tokenizer(" Hello world")["input_ids"]
102
+ [21927, 1879]
103
+ ```
104
+ This is expected.
105
+
106
+ You should not use GPT2Tokenizer as a drop-in replacement, because the pretokenization rules differ.
107
+
108
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
109
+ this superclass for more information regarding those methods.
110
+
111
+ Args:
112
+ vocab_file (`str`):
113
+ Path to the vocabulary file.
114
+ merges_file (`str`):
115
+ Path to the merges file.
116
+ errors (`str`, *optional*, defaults to `"replace"`):
117
+ Paradigm to follow when decoding bytes to UTF-8. See
118
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
119
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
120
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
121
+ token instead.
122
+ bos_token (`str`, *optional*):
123
+ The beginning of sequence token. Not applicable for this tokenizer.
124
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
+ The end of sequence token.
126
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
127
+ The token used for padding, for example when batching sequences of different lengths.
128
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
129
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
130
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
131
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
132
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
133
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
134
+ ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<',
135
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
136
+ """
137
+
138
+ vocab_files_names = VOCAB_FILES_NAMES
139
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
141
+ model_input_names = ["input_ids", "attention_mask"]
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_file,
146
+ merges_file,
147
+ errors="replace",
148
+ unk_token="<|endoftext|>",
149
+ bos_token=None,
150
+ eos_token="<|endoftext|>",
151
+ pad_token="<|endoftext|>",
152
+ clean_up_tokenization_spaces=False,
153
+ split_special_tokens=False,
154
+ **kwargs,
155
+ ):
156
+ # Qwen vocab does not contain control tokens; added tokens need to be special
157
+ bos_token = (
158
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
159
+ if isinstance(bos_token, str)
160
+ else bos_token
161
+ )
162
+ eos_token = (
163
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
164
+ if isinstance(eos_token, str)
165
+ else eos_token
166
+ )
167
+ unk_token = (
168
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
169
+ if isinstance(unk_token, str)
170
+ else unk_token
171
+ )
172
+ pad_token = (
173
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
174
+ if isinstance(pad_token, str)
175
+ else pad_token
176
+ )
177
+
178
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
179
+ self.encoder = json.load(vocab_handle)
180
+ self.decoder = {v: k for k, v in self.encoder.items()}
181
+ self.errors = errors # how to handle errors in decoding
182
+ self.byte_encoder = bytes_to_unicode()
183
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
184
+ bpe_merges = []
185
+ with open(merges_file, encoding="utf-8") as merges_handle:
186
+ for line in merges_handle:
187
+ line = line.strip()
188
+ if not line or line.startswith("#"):
189
+ continue
190
+ bpe_merges.append(tuple(line.split()))
191
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
192
+ # NOTE: the cache can grow without bound and will get really large for long running processes
193
+ # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
194
+ # not a memory leak but appears as one.
195
+ # GPT2Tokenizer has the same problem, so let's be consistent.
196
+ self.cache = {}
197
+
198
+ self.pat = re.compile(PRETOKENIZE_REGEX)
199
+
200
+ if kwargs.get("add_prefix_space", False):
201
+ logger.warning_once(
202
+ f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
203
+ )
204
+
205
+ super().__init__(
206
+ errors=errors,
207
+ bos_token=bos_token,
208
+ eos_token=eos_token,
209
+ pad_token=pad_token,
210
+ unk_token=unk_token,
211
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
212
+ split_special_tokens=split_special_tokens,
213
+ **kwargs,
214
+ )
215
+
216
+ @property
217
+ def vocab_size(self) -> int:
218
+ return len(self.encoder)
219
+
220
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
221
+ def get_vocab(self):
222
+ return dict(self.encoder, **self.added_tokens_encoder)
223
+
224
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
225
+ def bpe(self, token):
226
+ if token in self.cache:
227
+ return self.cache[token]
228
+ word = tuple(token)
229
+ pairs = get_pairs(word)
230
+
231
+ if not pairs:
232
+ return token
233
+
234
+ while True:
235
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
+ if bigram not in self.bpe_ranks:
237
+ break
238
+ first, second = bigram
239
+ new_word = []
240
+ i = 0
241
+ while i < len(word):
242
+ try:
243
+ j = word.index(first, i)
244
+ except ValueError:
245
+ new_word.extend(word[i:])
246
+ break
247
+ else:
248
+ new_word.extend(word[i:j])
249
+ i = j
250
+
251
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
+ new_word.append(first + second)
253
+ i += 2
254
+ else:
255
+ new_word.append(word[i])
256
+ i += 1
257
+ new_word = tuple(new_word)
258
+ word = new_word
259
+ if len(word) == 1:
260
+ break
261
+ else:
262
+ pairs = get_pairs(word)
263
+ word = " ".join(word)
264
+ self.cache[token] = word
265
+ return word
266
+
267
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
268
+ def _tokenize(self, text):
269
+ """Tokenize a string."""
270
+ bpe_tokens = []
271
+ for token in re.findall(self.pat, text):
272
+ token = "".join(
273
+ self.byte_encoder[b] for b in token.encode("utf-8")
274
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
275
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
276
+ return bpe_tokens
277
+
278
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
279
+ def _convert_token_to_id(self, token):
280
+ """Converts a token (str) in an id using the vocab."""
281
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
282
+
283
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
284
+ def _convert_id_to_token(self, index):
285
+ """Converts an index (integer) in a token (str) using the vocab."""
286
+ return self.decoder.get(index)
287
+
288
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
289
+ def convert_tokens_to_string(self, tokens):
290
+ """Converts a sequence of tokens (string) in a single string."""
291
+ text = "".join(tokens)
292
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
293
+ return text
294
+
295
+ def decode(
296
+ self,
297
+ token_ids,
298
+ skip_special_tokens: bool = False,
299
+ clean_up_tokenization_spaces: Optional[bool] = False,
300
+ spaces_between_special_tokens: bool = False,
301
+ **kwargs,
302
+ ) -> str:
303
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
304
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
305
+ return super().decode(
306
+ token_ids,
307
+ skip_special_tokens=skip_special_tokens,
308
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
309
+ spaces_between_special_tokens=spaces_between_special_tokens,
310
+ **kwargs,
311
+ )
312
+
313
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
314
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
+ if not os.path.isdir(save_directory):
316
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
317
+ return
318
+ vocab_file = os.path.join(
319
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
320
+ )
321
+ merge_file = os.path.join(
322
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
323
+ )
324
+
325
+ with open(vocab_file, "w", encoding="utf-8") as f:
326
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
327
+
328
+ index = 0
329
+ with open(merge_file, "w", encoding="utf-8") as writer:
330
+ writer.write("#version: 0.2\n")
331
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
332
+ if index != token_index:
333
+ logger.warning(
334
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
335
+ " Please check that the tokenizer is not corrupted!"
336
+ )
337
+ index = token_index
338
+ writer.write(" ".join(bpe_tokens) + "\n")
339
+ index += 1
340
+
341
+ return vocab_file, merge_file
342
+
343
+ def prepare_for_tokenization(self, text, **kwargs):
344
+ text = unicodedata.normalize("NFC", text)
345
+ return (text, kwargs)
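For reference, the `split_special_tokens` behavior documented in the class docstring can be exercised directly on the slow tokenizer. This is a minimal sketch, assuming local copies of the Qwen2 BPE files named `vocab.json` and `merges.txt` (both names are placeholders):

```python
from bunny.model.language_model.qwen2.tokenization_qwen2 import Qwen2Tokenizer

# Default: special tokens are kept whole.
tok = Qwen2Tokenizer(vocab_file="vocab.json", merges_file="merges.txt")
print(tok.tokenize("<|endoftext|>"))        # ['<|endoftext|>']

# With split_special_tokens=True they are broken into ordinary BPE pieces.
tok_split = Qwen2Tokenizer(
    vocab_file="vocab.json",
    merges_file="merges.txt",
    split_special_tokens=True,
)
print(tok_split.tokenize("<|endoftext|>"))  # e.g. ['<', '|', 'endo', 'ft', 'ext', '|', '>']
```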
bunny/model/language_model/qwen2/tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from transformers.tokenization_utils import AddedToken
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from transformers.utils import logging
22
+ from .tokenization_qwen2 import Qwen2Tokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {
28
+ "vocab_file": "vocab.json",
29
+ "merges_file": "merges.txt",
30
+ "tokenizer_file": "tokenizer.json",
31
+ }
32
+
33
+ PRETRAINED_VOCAB_FILES_MAP = {
34
+ "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
35
+ "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
36
+ "tokenizer_file": {
37
+ "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
38
+ },
39
+ }
40
+
41
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
+
43
+
44
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
45
+ """
46
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
47
+ Byte-Pair-Encoding.
48
+
49
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
50
+ be encoded differently depending on whether or not it is at the beginning of the sentence (without a space):
51
+
52
+ ```python
53
+ >>> from transformers import Qwen2TokenizerFast
54
+
55
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
56
+ >>> tokenizer("Hello world")["input_ids"]
57
+ [9707, 1879]
58
+
59
+ >>> tokenizer(" Hello world")["input_ids"]
60
+ [21927, 1879]
61
+ ```
62
+ This is expected.
63
+
64
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
65
+ refer to this superclass for more information regarding those methods.
66
+
67
+ Args:
68
+ vocab_file (`str`, *optional*):
69
+ Path to the vocabulary file.
70
+ merges_file (`str`, *optional*):
71
+ Path to the merges file.
72
+ tokenizer_file (`str`, *optional*):
73
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
74
+ contains everything needed to load the tokenizer.
75
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
77
+ token instead. Not applicable to this tokenizer.
78
+ bos_token (`str`, *optional*):
79
+ The beginning of sequence token. Not applicable for this tokenizer.
80
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
81
+ The end of sequence token.
82
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
83
+ The token used for padding, for example when batching sequences of different lengths.
84
+ """
85
+
86
+ vocab_files_names = VOCAB_FILES_NAMES
87
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
88
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
89
+ model_input_names = ["input_ids", "attention_mask"]
90
+ slow_tokenizer_class = Qwen2Tokenizer
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_file=None,
95
+ merges_file=None,
96
+ tokenizer_file=None,
97
+ unk_token="<|endoftext|>",
98
+ bos_token=None,
99
+ eos_token="<|endoftext|>",
100
+ pad_token="<|endoftext|>",
101
+ **kwargs,
102
+ ):
103
+ # We need to at least pass vocab_file and merges_file to base class
104
+ # in case a slow tokenizer needs to be initialized; the others can be
105
+ # configured through files.
106
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
107
+
108
+ bos_token = (
109
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
110
+ if isinstance(bos_token, str)
111
+ else bos_token
112
+ )
113
+ eos_token = (
114
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
115
+ if isinstance(eos_token, str)
116
+ else eos_token
117
+ )
118
+ unk_token = (
119
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
120
+ if isinstance(unk_token, str)
121
+ else unk_token
122
+ )
123
+ pad_token = (
124
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
125
+ if isinstance(pad_token, str)
126
+ else pad_token
127
+ )
128
+
129
+ super().__init__(
130
+ vocab_file,
131
+ merges_file,
132
+ tokenizer_file=tokenizer_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ **kwargs,
138
+ )
139
+
140
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
141
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
142
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
143
+ return tuple(files)
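As a quick illustration of the `save_vocabulary` override above, this is a minimal sketch (assuming the `Qwen/Qwen-tokenizer` hub id referenced in the docstring is reachable and that network access is available) that exports the BPE files the slow `Qwen2Tokenizer` needs:

```python
import os
from bunny.model.language_model.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast

tok = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")

out_dir = "./qwen2-tokenizer-export"   # placeholder output directory
os.makedirs(out_dir, exist_ok=True)
# Delegates to the underlying `tokenizers` model and returns the written file paths
# (vocab.json and merges.txt), which are exactly what Qwen2Tokenizer expects.
print(tok.save_vocabulary(out_dir))
```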
bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ StableLM Epoch model configuration"""
15
+ from transformers import PretrainedConfig
16
+ from transformers.utils import logging
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class StableLMEpochConfig(PretrainedConfig):
23
+ r"""
24
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
+ documentation from [`PretrainedConfig`] for more information.
26
+
27
+ Args:
28
+ vocab_size (`int`, *optional*, defaults to 50_304):
29
+ Vocabulary size of the StableLM model. Defines the number of different tokens that
30
+ can be represented by the `inputs_ids` passed when calling [`StableLMEpochModel`].
31
+ intermediate_size (`int`, *optional*, defaults to 6912):
32
+ Dimension of the MLP representations.
33
+ hidden_size (`int`, *optional*, defaults to 2560):
34
+ Dimension of the decoder layers and the pooler layer.
35
+ num_hidden_layers (`int`, *optional*, defaults to 32):
36
+ Number of hidden layers in the Transformer decoder.
37
+ num_attention_heads (`int`, *optional*, defaults to 32):
38
+ Number of attention heads for each attention layer in the Transformer encoder.
39
+ num_key_value_heads (`int`, *optional*):
40
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
41
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
42
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
43
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
44
+ by mean-pooling all the original heads within that group. For more details, check out [this
45
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
46
+ `num_attention_heads`.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
48
+ The non-linear activation function (function or string).
49
+ rope_pct (`float`, *optional*, defaults to 0.25):
50
+ Percentage of hidden dimensions to allocate to rotary embeddings.
51
+ rope_theta (`float`, *optional*, defaults to 10000.0):
52
+ The base period of the RoPE embeddings.
53
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
54
+ The maximum sequence length that this model might ever be used with.
55
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
56
+ initializer_range (`float`, *optional*, defaults to 0.02):
57
+ The standard deviation of the truncated_normal_initializer for initializing
58
+ all weight matrices.
59
+ norm_eps (`float`, *optional*, defaults to 1.0e-5):
60
+ The epsilon used by the normalization layers.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last key/values attentions
63
+ (not used by all models). Only relevant if `config.is_decoder=True`.
64
+ use_qkv_bias (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should use bias for qkv layers.
66
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
67
+ Whether to tie the input and output word embeddings.
68
+ """
69
+ model_type = "stablelm_epoch"
70
+ keys_to_ignore_at_inference = ["past_key_values"]
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_size=50_304,
75
+ intermediate_size=6912,
76
+ hidden_size=2560,
77
+ num_hidden_layers=32,
78
+ num_attention_heads=32,
79
+ num_key_value_heads=32,
80
+ hidden_act="silu",
81
+ rope_pct=0.25,
82
+ rope_theta=10_000,
83
+ max_position_embeddings=4096,
84
+ initializer_range=0.02,
85
+ norm_eps=1.0e-5,
86
+ use_cache=True,
87
+ use_qkv_bias=True,
88
+ bos_token_id=0,
89
+ eos_token_id=2,
90
+ tie_word_embeddings=False,
91
+ **kwargs,
92
+ ):
93
+ self.vocab_size = vocab_size
94
+ self.max_position_embeddings = max_position_embeddings
95
+ self.intermediate_size = intermediate_size
96
+ self.hidden_size = hidden_size
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.num_key_value_heads = num_key_value_heads
100
+ self.hidden_act = hidden_act
101
+ self.rope_pct = rope_pct
102
+ self.rope_theta = rope_theta
103
+ self.initializer_range = initializer_range
104
+ self.norm_eps = norm_eps
105
+ self.use_cache = use_cache
106
+ self.use_qkv_bias = use_qkv_bias
107
+ self.tie_word_embeddings = tie_word_embeddings
108
+ super().__init__(
109
+ bos_token_id=bos_token_id,
110
+ eos_token_id=eos_token_id,
111
+ tie_word_embeddings=tie_word_embeddings,
112
+ **kwargs,
113
+ )
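A minimal sketch of constructing the config. The head counts below are illustrative rather than those of any released checkpoint; setting `num_key_value_heads` below `num_attention_heads` is what turns on grouped-query attention in the model code:

```python
from bunny.model.language_model.stable_lm.configuration_stablelm_epoch import (
    StableLMEpochConfig,
)

# 32 query heads sharing 8 key/value heads (GQA); all other fields keep their defaults.
cfg = StableLMEpochConfig(num_attention_heads=32, num_key_value_heads=8)
print(cfg.hidden_size, cfg.rope_pct, cfg.norm_eps)  # 2560 0.25 1e-05
```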
bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py ADDED
@@ -0,0 +1,917 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # This code is based on the following work:
17
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
18
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
19
+ """ PyTorch StableLM Epoch model. """
20
+ from typing import Optional, Tuple, Union
21
+ import math
22
+ import warnings
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import CrossEntropyLoss
29
+
30
+ from transformers.cache_utils import Cache
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
37
+
38
+ from .configuration_stablelm_epoch import StableLMEpochConfig
39
+
40
+ try:
41
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
42
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
43
+ except Exception:  # flash-attn is optional; leave the symbols as None when it is not installed
44
+ flash_attn_func, flash_attn_varlen_func = None, None
45
+ index_first_axis, pad_input, unpad_input = None, None, None
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
52
+ def _get_unpad_data(attention_mask):
53
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
54
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
55
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
56
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
57
+ return (
58
+ indices,
59
+ cu_seqlens,
60
+ max_seqlen_in_batch,
61
+ )
62
+
63
+
64
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
65
+ def _make_causal_mask(
66
+ input_ids_shape: torch.Size,
67
+ dtype: torch.dtype,
68
+ device: torch.device,
69
+ past_key_values_length: int = 0,
70
+ ):
71
+ """Make causal mask used for bi-directional self-attention."""
72
+ batch_size, tgt_len = input_ids_shape
73
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
74
+ mask_cond = torch.arange(mask.size(-1), device=device)
75
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
76
+ mask = mask.to(dtype)
77
+ if past_key_values_length > 0:
78
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
79
+ return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
80
+
81
+
82
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
83
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
+ """Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
85
+ batch_size, src_len = mask.size()
86
+ tgt_len = tgt_len if tgt_len is not None else src_len
87
+
88
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
89
+ inverted_mask = 1.0 - expanded_mask
90
+
91
+ return inverted_mask.masked_fill(
92
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
93
+ )
94
+
95
+
96
+ class RotaryEmbedding(nn.Module):
97
+ def __init__(
98
+ self,
99
+ dim: int,
100
+ max_position_embeddings: int,
101
+ base: int = 10_000,
102
+ device: Optional[torch.device] = None,
103
+ ):
104
+ super().__init__()
105
+
106
+ self.dim = dim
107
+ self.max_position_embeddings = max_position_embeddings
108
+ self.base = base
109
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
110
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
111
+
112
+ # Build here to make `torch.jit.trace` work.
113
+ self._set_cos_sin_cache(
114
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
115
+ )
116
+
117
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
118
+ self.max_seq_len_cached = seq_len
119
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
120
+
121
+ # Don't do einsum, it converts fp32 to fp16 under AMP
122
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
123
+ freqs = torch.outer(t, self.inv_freq)
124
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
125
+ emb = torch.cat((freqs, freqs), dim=-1)
126
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
127
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
128
+
129
+ def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
130
+ # x: [batch_size, num_heads, seq_len, head_size]
131
+ if seq_len > self.max_seq_len_cached:
132
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
133
+ return (
134
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
135
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
136
+ )
137
+
138
+
139
+ def rotate_half(x: torch.Tensor):
140
+ """Rotates half the hidden dims of the input."""
141
+ x1, x2 = torch.chunk(x, 2, dim=-1)
142
+ return torch.cat((-x2, x1), dim=-1)
143
+
144
+
145
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
146
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
147
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
148
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
149
+ cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
150
+ sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
151
+ q_embed = (q * cos) + (rotate_half(q) * sin)
152
+ k_embed = (k * cos) + (rotate_half(k) * sin)
153
+ return q_embed, k_embed
154
+
155
+
156
+ class MLP(nn.Module):
157
+ def __init__(self, config: StableLMEpochConfig):
158
+ super().__init__()
159
+ self.config = config
160
+ self.hidden_size = config.hidden_size
161
+ self.intermediate_size = config.intermediate_size
162
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
163
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
164
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
165
+ self.act_fn = nn.SiLU()
166
+
167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
168
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
169
+
170
+
171
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
172
+ """
173
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
174
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
175
+ """
176
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
177
+ if n_rep == 1:
178
+ return hidden_states
179
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
180
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
181
+
182
+
183
+ class Attention(nn.Module):
184
+ def __init__(self, config: StableLMEpochConfig):
185
+ super().__init__()
186
+ self.config = config
187
+ self.hidden_size = config.hidden_size
188
+ self.num_heads = config.num_attention_heads
189
+ self.head_dim = self.hidden_size // self.num_heads
190
+ self.num_key_value_heads = config.num_key_value_heads
191
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
192
+ self.max_position_embeddings = config.max_position_embeddings
193
+ self.is_causal = True
194
+
195
+ if (self.head_dim * self.num_heads) != self.hidden_size:
196
+ raise ValueError(
197
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
198
+ f" and `num_heads`: {self.num_heads})."
199
+ )
200
+
201
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
202
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
203
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
204
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
205
+
206
+ self._init_rope()
207
+
208
+ def _init_rope(self):
209
+ self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
210
+ self.rotary_emb = RotaryEmbedding(
211
+ self.rotary_ndims,
212
+ max_position_embeddings=self.config.max_position_embeddings,
213
+ base=self.config.rope_theta,
214
+ )
215
+
216
+ def forward(
217
+ self,
218
+ hidden_states: torch.FloatTensor,
219
+ attention_mask: torch.FloatTensor,
220
+ position_ids: torch.LongTensor,
221
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
222
+ output_attentions: Optional[bool] = False,
223
+ use_cache: Optional[bool] = False,
224
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
225
+ bsz, q_len, _ = hidden_states.size()
226
+
227
+ query_states = self.q_proj(hidden_states)
228
+ key_states = self.k_proj(hidden_states)
229
+ value_states = self.v_proj(hidden_states)
230
+
231
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
232
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
233
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
234
+
235
+ query_rot = query_states[..., : self.rotary_ndims]
236
+ query_pass = query_states[..., self.rotary_ndims :]
237
+ key_rot = key_states[..., : self.rotary_ndims]
238
+ key_pass = key_states[..., self.rotary_ndims :]
239
+
240
+ kv_seq_len = key_states.shape[-2]
241
+ if past_key_value is not None:
242
+ kv_seq_len += past_key_value[0].shape[-2]
243
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
244
+ query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
245
+
246
+ # [batch_size, num_heads, seq_len, head_dim]
247
+ query_states = torch.cat((query_states, query_pass), dim=-1)
248
+ key_states = torch.cat((key_states, key_pass), dim=-1)
249
+
250
+ if past_key_value is not None:
251
+ # Reuse k, v, self_attention
252
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
253
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
254
+
255
+ past_key_value = (key_states, value_states) if use_cache else None
256
+
257
+ # Repeat k/v heads if n_kv_heads < n_heads
258
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
259
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
260
+
261
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
262
+
263
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
264
+ raise ValueError(
265
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
266
+ f" {attn_weights.size()}"
267
+ )
268
+
269
+ if attention_mask is not None:
270
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
271
+ raise ValueError(
272
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
273
+ )
274
+ attn_weights = attn_weights + attention_mask
275
+
276
+ # Upcast attention to fp32
277
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
278
+ attn_output = torch.matmul(attn_weights, value_states)
279
+
280
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
281
+ raise ValueError(
282
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
283
+ f" {attn_output.size()}"
284
+ )
285
+
286
+ # Merge heads
287
+ attn_output = attn_output.transpose(1, 2).contiguous()
288
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
289
+
290
+ # Final linear projection
291
+ attn_output = self.o_proj(attn_output)
292
+
293
+ if not output_attentions:
294
+ attn_weights = None
295
+
296
+ return attn_output, attn_weights, past_key_value
297
+
298
+
299
+ class FlashAttention2(Attention):
300
+ """
301
+ Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
302
+ """
303
+
304
+ def __init__(self, *args, **kwargs):
305
+ super().__init__(*args, **kwargs)
306
+
307
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
308
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
309
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
310
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
311
+
312
+ def forward(
313
+ self,
314
+ hidden_states: torch.Tensor,
315
+ attention_mask: Optional[torch.LongTensor] = None,
316
+ position_ids: Optional[torch.LongTensor] = None,
317
+ past_key_value: Optional[Cache] = None,
318
+ output_attentions: bool = False,
319
+ use_cache: bool = False,
320
+ **kwargs,
321
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
322
+ # FlashAttention2 attention does not support output_attentions
323
+ if "padding_mask" in kwargs:
324
+ warnings.warn(
325
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
326
+ )
327
+
328
+ # overwrite attention_mask with padding_mask
329
+ attention_mask = kwargs.pop("padding_mask")
330
+
331
+ output_attentions = False
332
+
333
+ bsz, q_len, _ = hidden_states.size()
334
+
335
+ query_states = self.q_proj(hidden_states)
336
+ key_states = self.k_proj(hidden_states)
337
+ value_states = self.v_proj(hidden_states)
338
+
339
+ # Flash attention requires the input to have the shape
340
+ # batch_size x seq_length x num_heads x head_dim
341
+ # therefore we just need to keep the original shape
342
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
343
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
344
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
345
+
346
+ query_rot = query_states[..., : self.rotary_ndims]
347
+ query_pass = query_states[..., self.rotary_ndims :]
348
+ key_rot = key_states[..., : self.rotary_ndims]
349
+ key_pass = key_states[..., self.rotary_ndims :]
350
+
351
+ kv_seq_len = key_states.shape[-2]
352
+ if past_key_value is not None:
353
+ kv_seq_len += past_key_value[0].shape[-2]
354
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
355
+ query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
356
+
357
+ # [batch_size, num_heads, seq_len, head_dim]
358
+ query_states = torch.cat((query_states, query_pass), dim=-1)
359
+ key_states = torch.cat((key_states, key_pass), dim=-1)
360
+
361
+ if past_key_value is not None:
362
+ # Reuse k, v, self_attention
363
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
364
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
365
+
366
+ past_key_value = (key_states, value_states) if use_cache else None
367
+
368
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
369
+ # to be able to avoid many of these transpose/reshape/view.
370
+ query_states = query_states.transpose(1, 2)
371
+ key_states = key_states.transpose(1, 2)
372
+ value_states = value_states.transpose(1, 2)
373
+
374
+ dropout_rate = getattr(self, "attention_dropout", 0.0) if self.training else 0.0  # Attention never sets attention_dropout, so fall back to 0.0
375
+
376
+ attn_output = self._flash_attention_forward(
377
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
378
+ )
379
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
380
+ attn_output = self.o_proj(attn_output)
381
+
382
+ if not output_attentions:
383
+ attn_weights = None
384
+
385
+ return attn_output, attn_weights, past_key_value
386
+
387
+ def _flash_attention_forward(
388
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
389
+ ):
390
+ """
391
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
392
+ it first unpads the input, then computes the attention scores and pads the final attention output.
393
+
394
+ Args:
395
+ query_states (`torch.Tensor`):
396
+ Input query states to be passed to Flash Attention API
397
+ key_states (`torch.Tensor`):
398
+ Input key states to be passed to Flash Attention API
399
+ value_states (`torch.Tensor`):
400
+ Input value states to be passed to Flash Attention API
401
+ attention_mask (`torch.Tensor`):
402
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
403
+ position of padding tokens and 1 for the position of non-padding tokens.
404
+ dropout (`float`, *optional*):
405
+ Attention dropout probability.
406
+ softmax_scale (`float`, *optional*):
407
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
408
+ """
409
+ if not self._flash_attn_uses_top_left_mask:
410
+ causal = self.is_causal
411
+ else:
412
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
413
+ causal = self.is_causal and query_length != 1
414
+
415
+ # Contains at least one padding token in the sequence
416
+ if attention_mask is not None:
417
+ batch_size = query_states.shape[0]
418
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
419
+ query_states, key_states, value_states, attention_mask, query_length
420
+ )
421
+
422
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
423
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
424
+
425
+ attn_output_unpad = flash_attn_varlen_func(
426
+ query_states,
427
+ key_states,
428
+ value_states,
429
+ cu_seqlens_q=cu_seqlens_q,
430
+ cu_seqlens_k=cu_seqlens_k,
431
+ max_seqlen_q=max_seqlen_in_batch_q,
432
+ max_seqlen_k=max_seqlen_in_batch_k,
433
+ dropout_p=dropout,
434
+ softmax_scale=softmax_scale,
435
+ causal=causal,
436
+ )
437
+
438
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
439
+ else:
440
+ attn_output = flash_attn_func(
441
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
442
+ )
443
+
444
+ return attn_output
445
+
446
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
447
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
448
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
449
+
450
+ key_layer = index_first_axis(
451
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
452
+ )
453
+ value_layer = index_first_axis(
454
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
455
+ )
456
+ if query_length == kv_seq_len:
457
+ query_layer = index_first_axis(
458
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
459
+ )
460
+ cu_seqlens_q = cu_seqlens_k
461
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
462
+ indices_q = indices_k
463
+ elif query_length == 1:
464
+ max_seqlen_in_batch_q = 1
465
+ cu_seqlens_q = torch.arange(
466
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
467
+ ) # There is a memcpy here, that is very bad.
468
+ indices_q = cu_seqlens_q[:-1]
469
+ query_layer = query_layer.squeeze(1)
470
+ else:
471
+ # The -q_len: slice assumes left padding.
472
+ attention_mask = attention_mask[:, -query_length:]
473
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
474
+
475
+ return (
476
+ query_layer,
477
+ key_layer,
478
+ value_layer,
479
+ indices_q,
480
+ (cu_seqlens_q, cu_seqlens_k),
481
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
482
+ )
483
+
484
+
485
+ ATTENTION_CLASSES = {
486
+ "eager": Attention,
487
+ "flash_attention_2": FlashAttention2,
488
+ }
489
+
490
+
491
+ class DecoderLayer(nn.Module):
492
+ def __init__(self, config: StableLMEpochConfig):
493
+ super().__init__()
494
+ self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
495
+ self.mlp = MLP(config)
496
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
497
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: Optional[torch.FloatTensor],
502
+ attention_mask: Optional[torch.FloatTensor] = None,
503
+ position_ids: Optional[torch.LongTensor] = None,
504
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
505
+ output_attentions: Optional[bool] = False,
506
+ use_cache: Optional[bool] = False,
507
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
508
+ residual = hidden_states
509
+
510
+ hidden_states = self.input_layernorm(hidden_states)
511
+
512
+ # Self Attention
513
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
514
+ hidden_states=hidden_states,
515
+ attention_mask=attention_mask,
516
+ position_ids=position_ids,
517
+ past_key_value=past_key_value,
518
+ output_attentions=output_attentions,
519
+ use_cache=use_cache,
520
+ )
521
+ hidden_states = residual + hidden_states
522
+
523
+ # Fully Connected
524
+ residual = hidden_states
525
+ hidden_states = self.post_attention_layernorm(hidden_states)
526
+ hidden_states = self.mlp(hidden_states)
527
+ hidden_states = residual + hidden_states
528
+
529
+ outputs = (hidden_states,)
530
+
531
+ if output_attentions:
532
+ outputs += (self_attn_weights,)
533
+
534
+ if use_cache:
535
+ outputs += (present_key_value,)
536
+
537
+ return outputs
538
+
539
+
540
+ class StableLMEpochPreTrainedModel(PreTrainedModel):
541
+ """An abstract class to handle weights initialization and a simple interface
542
+ for downloading and loading pretrained models.
543
+ """
544
+
545
+ config_class = StableLMEpochConfig
546
+ base_model_prefix = "transformer"
547
+ supports_gradient_checkpointing = True
548
+ _no_split_modules = ["DecoderLayer"]
549
+ _skip_keys_device_placement = "past_key_values"
550
+ _supports_flash_attn_2 = True
551
+
552
+ def _init_weights(self, module: nn.Module):
553
+ """Initialize the weights"""
554
+ if isinstance(module, nn.Linear):
555
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
+ if module.bias is not None:
557
+ module.bias.data.zero_()
558
+ elif isinstance(module, nn.Embedding):
559
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
560
+ if module.padding_idx is not None:
561
+ module.weight.data[module.padding_idx].zero_()
562
+ elif isinstance(module, nn.LayerNorm):
563
+ module.bias.data.zero_()
564
+ module.weight.data.fill_(1.0)
565
+
566
+ def _set_gradient_checkpointing(self, module: nn.Module, value=False):
567
+ if isinstance(module, StableLMEpochModel):
568
+ module.gradient_checkpointing = value
569
+
570
+
571
+ class StableLMEpochModel(StableLMEpochPreTrainedModel):
572
+ def __init__(self, config: StableLMEpochConfig):
573
+ super().__init__(config)
574
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
575
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
576
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
577
+
578
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
579
+ self.gradient_checkpointing = False
580
+ # Initialize weights and apply final processing
581
+ self.post_init()
582
+
583
+ def get_input_embeddings(self):
584
+ return self.embed_tokens
585
+
586
+ def set_input_embeddings(self, value: nn.Module):
587
+ self.embed_tokens = value
588
+
589
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
590
+ def _prepare_decoder_attention_mask(
591
+ self,
592
+ attention_mask: torch.Tensor,
593
+ input_shape: torch.Size,
594
+ inputs_embeds: torch.Tensor,
595
+ past_key_values_length: int,
596
+ ):
597
+ # Create causal mask
598
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
599
+ combined_attention_mask = None
600
+ if input_shape[-1] > 1:
601
+ combined_attention_mask = _make_causal_mask(
602
+ input_shape,
603
+ inputs_embeds.dtype,
604
+ device=inputs_embeds.device,
605
+ past_key_values_length=past_key_values_length,
606
+ )
607
+
608
+ if attention_mask is not None:
609
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
610
+ expanded_attn_mask = _expand_mask(
611
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
612
+ ).to(inputs_embeds.device)
613
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
614
+
615
+ return combined_attention_mask
616
+
617
+ def forward(
618
+ self,
619
+ input_ids: Optional[torch.LongTensor] = None,
620
+ attention_mask: Optional[torch.FloatTensor] = None,
621
+ position_ids: Optional[torch.LongTensor] = None,
622
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
623
+ inputs_embeds: Optional[torch.FloatTensor] = None,
624
+ use_cache: Optional[bool] = None,
625
+ output_attentions: Optional[bool] = None,
626
+ output_hidden_states: Optional[bool] = None,
627
+ return_dict: Optional[bool] = None,
628
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
629
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
631
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
632
+
633
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
+
635
+ # Retrieve input_ids and inputs_embeds
636
+ if input_ids is not None and inputs_embeds is not None:
637
+ raise ValueError(
638
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
639
+ )
640
+ elif input_ids is not None:
641
+ batch_size, seq_length = input_ids.shape
642
+ elif inputs_embeds is not None:
643
+ batch_size, seq_length, _ = inputs_embeds.shape
644
+ else:
645
+ raise ValueError(
646
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
647
+ )
648
+
649
+ seq_length_with_past = seq_length
650
+ past_key_values_length = 0
651
+
652
+ if position_ids is None:
653
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
654
+ position_ids = torch.arange(
655
+ past_key_values_length,
656
+ seq_length + past_key_values_length,
657
+ dtype=torch.long,
658
+ device=device,
659
+ )
660
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
661
+ else:
662
+ position_ids = position_ids.view(-1, seq_length).long()
663
+
664
+ if inputs_embeds is None:
665
+ inputs_embeds = self.embed_tokens(input_ids)
666
+ # Embed positions
667
+ if self._use_flash_attention_2:
668
+ # 2d mask is passed through the layers
669
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
670
+ else:
671
+ if attention_mask is None:
672
+ attention_mask = torch.ones(
673
+ (batch_size, seq_length_with_past),
674
+ dtype=torch.bool,
675
+ device=inputs_embeds.device,
676
+ )
677
+ attention_mask = self._prepare_decoder_attention_mask(
678
+ attention_mask,
679
+ (batch_size, seq_length),
680
+ inputs_embeds,
681
+ past_key_values_length,
682
+ )
683
+
684
+ hidden_states = inputs_embeds
685
+
686
+ if self.gradient_checkpointing and self.training:
687
+ if use_cache:
688
+ logger.warning(
689
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
690
+ )
691
+ use_cache = False
692
+
693
+ # Decoder layers
694
+ all_hidden_states = () if output_hidden_states else None
695
+ all_self_attns = () if output_attentions else None
696
+ next_decoder_cache = () if use_cache else None
697
+
698
+ for idx, decoder_layer in enumerate(self.layers):
699
+ if output_hidden_states:
700
+ all_hidden_states += (hidden_states,)
701
+
702
+ past_key_value = (
703
+ past_key_values[idx] if past_key_values is not None else None
704
+ )
705
+
706
+ if self.gradient_checkpointing and self.training:
707
+
708
+ def create_custom_forward(module):
709
+ def custom_forward(*inputs):
710
+ # None for past_key_value
711
+ return module(*inputs, past_key_value, output_attentions)
712
+
713
+ return custom_forward
714
+
715
+ layer_outputs = torch.utils.checkpoint.checkpoint(
716
+ create_custom_forward(decoder_layer),
717
+ hidden_states,
718
+ attention_mask,
719
+ position_ids,
720
+ )
721
+ else:
722
+ layer_outputs = decoder_layer(
723
+ hidden_states,
724
+ attention_mask=attention_mask,
725
+ position_ids=position_ids,
726
+ past_key_value=past_key_value,
727
+ output_attentions=output_attentions,
728
+ use_cache=use_cache,
729
+ )
730
+
731
+ hidden_states = layer_outputs[0]
732
+
733
+ if use_cache:
734
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
735
+
736
+ if output_attentions:
737
+ all_self_attns += (layer_outputs[1],)
738
+
739
+ hidden_states = self.norm(hidden_states)
740
+
741
+ # Add hidden states from the last decoder layer
742
+ if output_hidden_states:
743
+ all_hidden_states += (hidden_states,)
744
+
745
+ next_cache = next_decoder_cache if use_cache else None
746
+ if not return_dict:
747
+ return tuple(
748
+ v
749
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
750
+ if v is not None
751
+ )
752
+ return BaseModelOutputWithPast(
753
+ last_hidden_state=hidden_states,
754
+ past_key_values=next_cache,
755
+ hidden_states=all_hidden_states,
756
+ attentions=all_self_attns,
757
+ )
758
+
759
+
760
+ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
761
+ _tied_weights_keys = ["lm_head.weight"]
762
+
763
+ def __init__(self, config: StableLMEpochConfig):
764
+ super().__init__(config)
765
+
766
+ self.model = StableLMEpochModel(config)
767
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
768
+
769
+ # Initialize weights and apply final processing
770
+ self.post_init()
771
+
772
+ def get_input_embeddings(self):
773
+ return self.model.embed_tokens
774
+
775
+ def set_input_embeddings(self, value):
776
+ self.model.embed_tokens = value
777
+
778
+ def get_output_embeddings(self):
779
+ return self.lm_head
780
+
781
+ def set_output_embeddings(self, new_embeddings: nn.Module):
782
+ self.lm_head = new_embeddings
783
+
784
+ def get_decoder(self):
785
+ return self.model
786
+
787
+ def set_decoder(self, decoder):
788
+ self.model = decoder
789
+
790
+ def forward(
791
+ self,
792
+ input_ids: Optional[torch.LongTensor] = None,
793
+ attention_mask: Optional[torch.FloatTensor] = None,
794
+ position_ids: Optional[torch.LongTensor] = None,
795
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
796
+ inputs_embeds: Optional[torch.FloatTensor] = None,
797
+ labels: Optional[torch.LongTensor] = None,
798
+ use_cache: Optional[bool] = None,
799
+ output_attentions: Optional[bool] = None,
800
+ output_hidden_states: Optional[bool] = None,
801
+ return_dict: Optional[bool] = None,
802
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
803
+ output_attentions = (
804
+ output_attentions
805
+ if output_attentions is not None
806
+ else self.config.output_attentions
807
+ )
808
+ output_hidden_states = (
809
+ output_hidden_states
810
+ if output_hidden_states is not None
811
+ else self.config.output_hidden_states
812
+ )
813
+ return_dict = (
814
+ return_dict if return_dict is not None else self.config.use_return_dict
815
+ )
816
+
817
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
818
+ outputs = self.model(
819
+ input_ids,
820
+ attention_mask=attention_mask,
821
+ position_ids=position_ids,
822
+ past_key_values=past_key_values,
823
+ inputs_embeds=inputs_embeds,
824
+ use_cache=use_cache,
825
+ output_attentions=output_attentions,
826
+ output_hidden_states=output_hidden_states,
827
+ return_dict=return_dict,
828
+ )
829
+
830
+ hidden_states = outputs[0]
831
+ logits = self.lm_head(hidden_states).float()
832
+
833
+ loss = None
834
+ if labels is not None:
835
+ # Shift so that tokens < n predict n
836
+ shift_logits = logits[..., :-1, :].contiguous()
837
+ shift_labels = labels[..., 1:].contiguous()
838
+ # Flatten the tokens
839
+ loss_fct = CrossEntropyLoss()
840
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
841
+ shift_labels = shift_labels.view(-1)
842
+ # Enable model parallelism
843
+ shift_labels = shift_labels.to(shift_logits.device)
844
+ loss = loss_fct(shift_logits, shift_labels)
845
+
846
+ if not return_dict:
847
+ output = (logits,) + outputs[1:]
848
+ return (loss,) + output if loss is not None else output
849
+
850
+ return CausalLMOutputWithPast(
851
+ loss=loss,
852
+ logits=logits,
853
+ past_key_values=outputs.past_key_values,
854
+ hidden_states=outputs.hidden_states,
855
+ attentions=outputs.attentions,
856
+ )
857
+
858
+ def prepare_inputs_for_generation(
859
+ self,
860
+ input_ids,
861
+ past_key_values: Optional[torch.Tensor] = None,
862
+ attention_mask: Optional[torch.Tensor] = None,
863
+ inputs_embeds: Optional[torch.Tensor] = None,
864
+ **kwargs,
865
+ ):
866
+ # Trim decoder_input_ids if past is used
867
+ if past_key_values is not None:
868
+ past_length = past_key_values[0][0].shape[2]
869
+
870
+ # Some generation methods already pass only the last input ID
871
+ if input_ids.shape[1] > past_length:
872
+ remove_prefix_length = past_length
873
+ else:
874
+ # Default to old behavior: keep only final ID
875
+ remove_prefix_length = input_ids.shape[1] - 1
876
+
877
+ input_ids = input_ids[:, remove_prefix_length:]
878
+
879
+ position_ids = kwargs.get("position_ids", None)
880
+ if attention_mask is not None and position_ids is None:
881
+ # Create position_ids on the fly for batch generation
882
+ position_ids = attention_mask.long().cumsum(-1) - 1
883
+ position_ids.masked_fill_(attention_mask == 0, 1)
884
+ if past_key_values:
885
+ position_ids = position_ids[:, -1].unsqueeze(-1)
886
+
887
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
888
+ if inputs_embeds is not None and past_key_values is None:
889
+ model_inputs = {"inputs_embeds": inputs_embeds}
890
+ else:
891
+ model_inputs = {"input_ids": input_ids}
892
+
893
+ model_inputs.update(
894
+ {
895
+ "attention_mask": attention_mask,
896
+ "past_key_values": past_key_values,
897
+ "use_cache": kwargs.get("use_cache"),
898
+ "position_ids": position_ids,
899
+ }
900
+ )
901
+ return model_inputs
902
+
903
+ @staticmethod
904
+ def _reorder_cache(past_key_values, beam_idx):
905
+ reordered_past = ()
906
+ for layer_past in past_key_values:
907
+ reordered_past += (
908
+ tuple(
909
+ past_state.index_select(0, beam_idx.to(past_state.device))
910
+ for past_state in layer_past
911
+ ),
912
+ )
913
+ return reordered_past
914
+
915
+
916
+ StableLMEpochConfig.register_for_auto_class()
917
+ StableLMEpochForCausalLM.register_for_auto_class("AutoModelForCausalLM")
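The partial-rotary scheme used by `Attention._init_rope` (only the first `head_dim * rope_pct` dimensions are rotated; the rest are passed through) can be reproduced standalone. A minimal sketch with made-up shapes, reusing the module-level helpers defined above:

```python
import torch
from bunny.model.language_model.stable_lm.modeling_stablelm_epoch import (
    RotaryEmbedding, apply_rotary_pos_emb,
)

bsz, heads, seq, head_dim, rope_pct = 1, 2, 6, 80, 0.25
rotary_ndims = int(head_dim * rope_pct)          # 20 of 80 dims get rotated
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)

rope = RotaryEmbedding(rotary_ndims, max_position_embeddings=seq)
cos, sin = rope(k, seq_len=seq)
position_ids = torch.arange(seq).unsqueeze(0)

q_rot, k_rot = apply_rotary_pos_emb(
    q[..., :rotary_ndims], k[..., :rotary_ndims], cos, sin, position_ids
)
# Re-attach the untouched "pass" dimensions, exactly as Attention.forward does.
q = torch.cat((q_rot, q[..., rotary_ndims:]), dim=-1)
k = torch.cat((k_rot, k[..., rotary_ndims:]), dim=-1)
print(q.shape, k.shape)  # torch.Size([1, 2, 6, 80]) for both
```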
bunny/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ from .eva_clip.eva_clip_encoder import EvaClipVisionTower
3
+ from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
4
+ from .clip.clip_encoder import CLIPVisionTower
5
+
6
+
7
+ def build_vision_tower(vision_tower_cfg, **kwargs):
8
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9
+ use_s2 = getattr(vision_tower_cfg, 'use_s2', False)
10
+
11
+ if 'sig' in vision_tower.lower():
12
+ if use_s2:
13
+ return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
14
+ else:
15
+ return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16
+ elif 'eva' in vision_tower.lower():
17
+ if use_s2:
18
+ raise ValueError('S2 is currently not supported for EVA-CLIP')
19
+ else:
20
+ return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
21
+
22
+ elif 'clip' in vision_tower.lower():
23
+ if use_s2:
24
+ raise ValueError('S2 is currently not supported for CLIP')
25
+ else:
26
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
27
+
28
+ else:
29
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
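`build_vision_tower` only needs a config object exposing `mm_vision_tower` (or `vision_tower`) and, optionally, `use_s2`. A minimal sketch; the checkpoint id below is just an example, and `delay_load=True` means only the config is fetched here, not the weights:

```python
from types import SimpleNamespace
from bunny.model.multimodal_encoder.builder import build_vision_tower

# Dispatch is by substring of the checkpoint name: 'sig' -> SigLIP, 'eva' -> EVA-CLIP, 'clip' -> CLIP.
cfg = SimpleNamespace(mm_vision_tower="openai/clip-vit-large-patch14-336", use_s2=False)
tower = build_vision_tower(cfg, delay_load=True)
```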
bunny/model/multimodal_encoder/clip/clip_encoder.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = -2
15
+
16
+ if not delay_load:
17
+ self.load_model()
18
+ else:
19
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
20
+
21
+ def load_model(self):
22
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
23
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
24
+ self.vision_tower.requires_grad_(False)
25
+
26
+ self.is_loaded = True
27
+
28
+ def feature_select(self, image_forward_outs):
29
+ image_features = image_forward_outs.hidden_states[self.select_layer]
30
+
31
+ image_features = image_features[:, 1:]
32
+
33
+ return image_features
34
+
35
+ @torch.no_grad()
36
+ def forward(self, images):
37
+ if type(images) is list:
38
+ image_features = []
39
+ for image in images:
40
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41
+ output_hidden_states=True)
42
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
43
+ image_features.append(image_feature)
44
+ else:
45
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
46
+ output_hidden_states=True)
47
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
48
+
49
+ return image_features
50
+
51
+ @property
52
+ def dummy_feature(self):
53
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
54
+
55
+ @property
56
+ def dtype(self):
57
+ return self.vision_tower.dtype
58
+
59
+ @property
60
+ def device(self):
61
+ return self.vision_tower.device
62
+
63
+ @property
64
+ def config(self):
65
+ if self.is_loaded:
66
+ return self.vision_tower.config
67
+ else:
68
+ return self.cfg_only
69
+
70
+ @property
71
+ def hidden_size(self):
72
+ return self.config.hidden_size
73
+
74
+ @property
75
+ def num_patches(self):
76
+ return (self.config.image_size // self.config.patch_size) ** 2
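
A rough usage sketch of the tower in isolation (the checkpoint name is an assumption, and loading it pulls weights from the Hub):

    import torch
    from bunny.model.multimodal_encoder.clip.clip_encoder import CLIPVisionTower

    tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args=None)  # hypothetical checkpoint
    images = torch.zeros(2, 3, 336, 336)                                     # dummy pixel batch
    feats = tower(images)
    print(feats.shape)  # (2, num_patches, hidden_size), with the CLS token already dropped
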
bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .eva_clip_processors import EvaClipImageTrainProcessor
5
+ from .eva_vit import Eva2LargePlusEncoder
6
+
7
+
8
+ class EvaClipVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_path = vision_tower
15
+ self.config = VisionTowerConfig()
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ else:
20
+ self.cfg_only = self.config
21
+
22
+ def load_model(self):
23
+ self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
24
+ self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
25
+ self.vision_tower.requires_grad_(False)
26
+
27
+ self.is_loaded = True
28
+
29
+ @torch.no_grad()
30
+ def forward(self, images):
31
+ if type(images) is list:
32
+ image_features = []
33
+ for image in images:
34
+ image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(
35
+ image.dtype)
36
+ image_features.append(image_feature)
37
+ else:
38
+ image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
39
+
40
+ return image_features
41
+
42
+ @property
43
+ def dtype(self):
44
+ return self.vision_tower.dtype
45
+
46
+ @property
47
+ def device(self):
48
+ return self.vision_tower.device
49
+
50
+ @property
51
+ def hidden_size(self):
52
+ return self.config.hidden_size
53
+
54
+ @property
55
+ def num_patches(self):
56
+ return (self.config.image_size // self.config.patch_size) ** 2
57
+
58
+
59
+ class VisionTowerConfig():
60
+ def __init__(self):
61
+ self.image_size = 336
62
+ self.patch_size = 14
63
+ self.hidden_size = 1024
bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py ADDED
@@ -0,0 +1,68 @@
1
+ '''
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ '''
4
+
5
+ from torchvision import transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ from transformers.image_processing_utils import BatchFeature
8
+ from PIL import Image
9
+ from transformers.image_transforms import convert_to_rgb
10
+
11
+
12
+ class BaseProcessor:
13
+ def __init__(self):
14
+ self.transform = lambda x: x
15
+ return
16
+
17
+ def __call__(self, item):
18
+ return self.transform(item)
19
+
20
+
21
+ class EvaClipImageBaseProcessor(BaseProcessor):
22
+ def __init__(self, mean=None, std=None):
23
+ self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
24
+ self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
25
+
26
+ self.normalize = transforms.Normalize(self.mean, self.std)
27
+
28
+ @property
29
+ def image_mean(self):
30
+ return self.mean
31
+
32
+
33
+ class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
34
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
35
+ super().__init__(mean=mean, std=std)
36
+
37
+ self.transform = transforms.Compose(
38
+ [
39
+ convert_to_rgb,
40
+ transforms.Resize(
41
+ image_size,
42
+ interpolation=InterpolationMode.BICUBIC,
43
+ ),
44
+ transforms.CenterCrop(image_size),
45
+ transforms.ToTensor(),
46
+ self.normalize,
47
+ ]
48
+ )
49
+
50
+ self.image_size = image_size
51
+
52
+ def preprocess(self, images, return_tensors):
53
+ if isinstance(images, Image.Image):
54
+ images = [images]
55
+ else:
56
+ assert isinstance(images, list)
57
+
58
+ transformed_images = [self.transform(image).numpy() for image in images]
59
+ data = {"pixel_values": transformed_images}
60
+
61
+ return BatchFeature(data=data, tensor_type=return_tensors)
62
+
63
+ def __call__(self, item):
64
+ return self.transform(item)
65
+
66
+ @property
67
+ def crop_size(self):
68
+ return {'height': self.image_size, 'width': self.image_size}
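
A small sketch of the processor on its own; the input image here is synthetic:

    from PIL import Image
    from bunny.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor

    processor = EvaClipImageTrainProcessor(image_size=336)
    image = Image.new('RGB', (500, 400))                        # dummy image
    batch = processor.preprocess([image], return_tensors='pt')
    print(batch['pixel_values'].shape)                          # torch.Size([1, 3, 336, 336])
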
bunny/model/multimodal_encoder/eva_clip/eva_vit.py ADDED
@@ -0,0 +1,851 @@
1
+ '''
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ '''
4
+
5
+ from math import pi
6
+ import torch
7
+ from torch import nn
8
+ from einops import rearrange, repeat
9
+ import logging
10
+
11
+
12
+ def broadcat(tensors, dim=-1):
13
+ num_tensors = len(tensors)
14
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
15
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
16
+ shape_len = list(shape_lens)[0]
17
+ dim = (dim + shape_len) if dim < 0 else dim
18
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
19
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
20
+ assert all(
21
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
22
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
23
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
24
+ expanded_dims.insert(dim, (dim, dims[dim]))
25
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
26
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
27
+ return torch.cat(tensors, dim=dim)
28
+
29
+
30
+ def rotate_half(x):
31
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
32
+ x1, x2 = x.unbind(dim=-1)
33
+ x = torch.stack((-x2, x1), dim=-1)
34
+ return rearrange(x, '... d r -> ... (d r)')
35
+
36
+
37
+ class VisionRotaryEmbeddingFast(nn.Module):
38
+ def __init__(
39
+ self,
40
+ dim,
41
+ pt_seq_len,
42
+ ft_seq_len=None,
43
+ custom_freqs=None,
44
+ freqs_for='lang',
45
+ theta=10000,
46
+ max_freq=10,
47
+ num_freqs=1,
48
+ patch_dropout=0.
49
+ ):
50
+ super().__init__()
51
+ if custom_freqs:
52
+ freqs = custom_freqs
53
+ elif freqs_for == 'lang':
54
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
55
+ elif freqs_for == 'pixel':
56
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
57
+ elif freqs_for == 'constant':
58
+ freqs = torch.ones(num_freqs).float()
59
+ else:
60
+ raise ValueError(f'unknown modality {freqs_for}')
61
+
62
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
63
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
64
+
65
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
66
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
67
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
68
+
69
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
70
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
71
+
72
+ self.patch_dropout = patch_dropout
73
+
74
+ self.register_buffer("freqs_cos", freqs_cos)
75
+ self.register_buffer("freqs_sin", freqs_sin)
76
+
77
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
78
+
79
+ def forward(self, t, patch_indices_keep=None):
80
+ if patch_indices_keep is not None:
81
+ batch = t.size()[0]
82
+ batch_indices = torch.arange(batch)
83
+ batch_indices = batch_indices[..., None]
84
+
85
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
86
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
87
+
88
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
89
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
90
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
91
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
92
+
93
+ return t * freqs_cos + rotate_half(t) * freqs_sin
94
+
95
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
96
+
97
+
98
+ class LayerNorm(nn.LayerNorm):
99
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
100
+
101
+ def forward(self, x: torch.Tensor):
102
+ orig_type = x.dtype
103
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
104
+ return x.to(orig_type)
105
+
106
+
107
+ class PatchDropout(nn.Module):
108
+ """
109
+ https://arxiv.org/abs/2212.00794
110
+ """
111
+
112
+ def __init__(self, prob, exclude_first_token=True):
113
+ super().__init__()
114
+ assert 0 <= prob < 1.
115
+ self.prob = prob
116
+ self.exclude_first_token = exclude_first_token # exclude CLS token
117
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
118
+
119
+ def forward(self, x):
120
+ if not self.training or self.prob == 0.:
121
+ return x
122
+
123
+ if self.exclude_first_token:
124
+ cls_tokens, x = x[:, :1], x[:, 1:]
125
+ else:
126
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
127
+
128
+ batch = x.size()[0]
129
+ num_tokens = x.size()[1]
130
+
131
+ batch_indices = torch.arange(batch)
132
+ batch_indices = batch_indices[..., None]
133
+
134
+ keep_prob = 1 - self.prob
135
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
136
+
137
+ rand = torch.randn(batch, num_tokens)
138
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
139
+
140
+ x = x[batch_indices, patch_indices_keep]
141
+
142
+ if self.exclude_first_token:
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ if self.training and os.getenv('RoPE') == '1':
146
+ return x, patch_indices_keep
147
+
148
+ return x
149
+
150
+
151
+ # --------------------------------------------------------
152
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
153
+ # --------------------------------------------------------
154
+ import math
155
+ import os
156
+ from functools import partial
157
+ import torch.nn as nn
158
+ import torch.nn.functional as F
159
+
160
+ try:
161
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
162
+ except ImportError:
163
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
164
+
165
+ if os.getenv('ENV_TYPE') == 'deepspeed':
166
+ try:
167
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
168
+ except ImportError:
169
+ from torch.utils.checkpoint import checkpoint
170
+ else:
171
+ from torch.utils.checkpoint import checkpoint
172
+
173
+ import xformers.ops as xops
174
+
175
+
176
+ class DropPath(nn.Module):
177
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
178
+ """
179
+
180
+ def __init__(self, drop_prob=None):
181
+ super(DropPath, self).__init__()
182
+ self.drop_prob = drop_prob
183
+
184
+ def forward(self, x):
185
+ return drop_path(x, self.drop_prob, self.training)
186
+
187
+ def extra_repr(self) -> str:
188
+ return 'p={}'.format(self.drop_prob)
189
+
190
+
191
+ class Mlp(nn.Module):
192
+ def __init__(
193
+ self,
194
+ in_features,
195
+ hidden_features=None,
196
+ out_features=None,
197
+ act_layer=nn.GELU,
198
+ norm_layer=nn.LayerNorm,
199
+ drop=0.,
200
+ subln=False,
201
+
202
+ ):
203
+ super().__init__()
204
+ out_features = out_features or in_features
205
+ hidden_features = hidden_features or in_features
206
+ self.fc1 = nn.Linear(in_features, hidden_features)
207
+ self.act = act_layer()
208
+
209
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
210
+
211
+ self.fc2 = nn.Linear(hidden_features, out_features)
212
+ self.drop = nn.Dropout(drop)
213
+
214
+ def forward(self, x):
215
+ x = self.fc1(x)
216
+ x = self.act(x)
217
+ # x = self.drop(x)
218
+ # commented out to match the original BERT implementation
219
+ x = self.ffn_ln(x)
220
+
221
+ x = self.fc2(x)
222
+ x = self.drop(x)
223
+ return x
224
+
225
+
226
+ class SwiGLU(nn.Module):
227
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
228
+ norm_layer=nn.LayerNorm, subln=False):
229
+ super().__init__()
230
+ out_features = out_features or in_features
231
+ hidden_features = hidden_features or in_features
232
+
233
+ self.w1 = nn.Linear(in_features, hidden_features)
234
+ self.w2 = nn.Linear(in_features, hidden_features)
235
+
236
+ self.act = act_layer()
237
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
238
+ self.w3 = nn.Linear(hidden_features, out_features)
239
+
240
+ self.drop = nn.Dropout(drop)
241
+
242
+ def forward(self, x):
243
+ x1 = self.w1(x)
244
+ x2 = self.w2(x)
245
+ hidden = self.act(x1) * x2
246
+ x = self.ffn_ln(hidden)
247
+ x = self.w3(x)
248
+ x = self.drop(x)
249
+ return x
250
+
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
255
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False,
256
+ norm_layer=nn.LayerNorm):
257
+ super().__init__()
258
+ self.num_heads = num_heads
259
+ head_dim = dim // num_heads
260
+ if attn_head_dim is not None:
261
+ head_dim = attn_head_dim
262
+ all_head_dim = head_dim * self.num_heads
263
+ self.scale = qk_scale or head_dim ** -0.5
264
+
265
+ self.subln = subln
266
+ if self.subln:
267
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
268
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
269
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
270
+ else:
271
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
272
+
273
+ if qkv_bias:
274
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
275
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
276
+ else:
277
+ self.q_bias = None
278
+ self.v_bias = None
279
+
280
+ if window_size:
281
+ self.window_size = window_size
282
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
283
+ self.relative_position_bias_table = nn.Parameter(
284
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
285
+ # cls to token & token 2 cls & cls to cls
286
+
287
+ # get pair-wise relative position index for each token inside the window
288
+ coords_h = torch.arange(window_size[0])
289
+ coords_w = torch.arange(window_size[1])
290
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
291
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
292
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
293
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
294
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
295
+ relative_coords[:, :, 1] += window_size[1] - 1
296
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
297
+ relative_position_index = \
298
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
299
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
300
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
301
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
302
+ relative_position_index[0, 0] = self.num_relative_distance - 1
303
+
304
+ self.register_buffer("relative_position_index", relative_position_index)
305
+ else:
306
+ self.window_size = None
307
+ self.relative_position_bias_table = None
308
+ self.relative_position_index = None
309
+
310
+ self.attn_drop = nn.Dropout(attn_drop)
311
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
312
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
313
+ self.proj = nn.Linear(all_head_dim, dim)
314
+ self.proj_drop = nn.Dropout(proj_drop)
315
+ self.xattn = xattn
316
+ self.xattn_drop = attn_drop
317
+
318
+ self.rope = rope
319
+
320
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
321
+ B, N, C = x.shape
322
+ if self.subln:
323
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
324
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
325
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
326
+
327
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
328
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
329
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
330
+ else:
331
+
332
+ qkv_bias = None
333
+ if self.q_bias is not None:
334
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
335
+
336
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
337
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
338
+ q, k, v = qkv[0], qkv[1], qkv[2]
339
+
340
+ if self.rope:
341
+ # slightly fast impl
342
+ q_t = q[:, :, 1:, :]
343
+ ro_q_t = self.rope(q_t)
344
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
345
+
346
+ k_t = k[:, :, 1:, :]
347
+ ro_k_t = self.rope(k_t)
348
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
349
+
350
+ if self.xattn:
351
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
352
+ k = k.permute(0, 2, 1, 3)
353
+ v = v.permute(0, 2, 1, 3)
354
+
355
+ x = xops.memory_efficient_attention(
356
+ q, k, v,
357
+ p=self.xattn_drop,
358
+ scale=self.scale,
359
+ )
360
+ x = x.reshape(B, N, -1)
361
+ x = self.inner_attn_ln(x)
362
+ x = self.proj(x)
363
+ x = self.proj_drop(x)
364
+ else:
365
+ q = q * self.scale
366
+ attn = (q @ k.transpose(-2, -1))
367
+
368
+ if self.relative_position_bias_table is not None:
369
+ relative_position_bias = \
370
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
371
+ self.window_size[0] * self.window_size[1] + 1,
372
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
373
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
374
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
375
+
376
+ if rel_pos_bias is not None:
377
+ attn = attn + rel_pos_bias.type_as(attn)
378
+
379
+ if attn_mask is not None:
380
+ attn_mask = attn_mask.bool()
381
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
382
+
383
+ attn = attn.softmax(dim=-1)
384
+ attn = self.attn_drop(attn)
385
+
386
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
387
+ x = self.inner_attn_ln(x)
388
+ x = self.proj(x)
389
+ x = self.proj_drop(x)
390
+ return x
391
+
392
+
393
+ class Block(nn.Module):
394
+
395
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
396
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
397
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
398
+ subln=False, naiveswiglu=False):
399
+ super().__init__()
400
+ self.norm1 = norm_layer(dim)
401
+ self.attn = Attention(
402
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
403
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
404
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
405
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
406
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
407
+ self.norm2 = norm_layer(dim)
408
+ mlp_hidden_dim = int(dim * mlp_ratio)
409
+
410
+ if naiveswiglu:
411
+ self.mlp = SwiGLU(
412
+ in_features=dim,
413
+ hidden_features=mlp_hidden_dim,
414
+ subln=subln,
415
+ norm_layer=norm_layer,
416
+ )
417
+ else:
418
+ self.mlp = Mlp(
419
+ in_features=dim,
420
+ hidden_features=mlp_hidden_dim,
421
+ act_layer=act_layer,
422
+ subln=subln,
423
+ drop=drop
424
+ )
425
+
426
+ if init_values is not None and init_values > 0:
427
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
428
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
429
+ else:
430
+ self.gamma_1, self.gamma_2 = None, None
431
+
432
+ self.postnorm = postnorm
433
+
434
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
435
+ if self.gamma_1 is None:
436
+ if self.postnorm:
437
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
438
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
439
+ else:
440
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
441
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
442
+ else:
443
+ if self.postnorm:
444
+ x = x + self.drop_path(
445
+ self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
446
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
447
+ else:
448
+ x = x + self.drop_path(
449
+ self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
450
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
451
+ return x
452
+
453
+
454
+ class PatchEmbed(nn.Module):
455
+ """ Image to Patch Embedding
456
+ """
457
+
458
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
459
+ super().__init__()
460
+ img_size = to_2tuple(img_size)
461
+ patch_size = to_2tuple(patch_size)
462
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
463
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
464
+ self.img_size = img_size
465
+ self.patch_size = patch_size
466
+ self.num_patches = num_patches
467
+
468
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
469
+
470
+ def forward(self, x, **kwargs):
471
+ B, C, H, W = x.shape
472
+ # FIXME look at relaxing size constraints
473
+ assert H == self.img_size[0] and W == self.img_size[1], \
474
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
475
+ x = self.proj(x).flatten(2).transpose(1, 2)
476
+ return x
477
+
478
+
479
+ class RelativePositionBias(nn.Module):
480
+
481
+ def __init__(self, window_size, num_heads):
482
+ super().__init__()
483
+ self.window_size = window_size
484
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
485
+ self.relative_position_bias_table = nn.Parameter(
486
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
487
+ # cls to token & token 2 cls & cls to cls
488
+
489
+ # get pair-wise relative position index for each token inside the window
490
+ coords_h = torch.arange(window_size[0])
491
+ coords_w = torch.arange(window_size[1])
492
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
493
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
494
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
495
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
496
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
497
+ relative_coords[:, :, 1] += window_size[1] - 1
498
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
499
+ relative_position_index = \
500
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
501
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
502
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
503
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
504
+ relative_position_index[0, 0] = self.num_relative_distance - 1
505
+
506
+ self.register_buffer("relative_position_index", relative_position_index)
507
+
508
+ def forward(self):
509
+ relative_position_bias = \
510
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
511
+ self.window_size[0] * self.window_size[1] + 1,
512
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
513
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
514
+
515
+
516
+ class EVAVisionTransformer(nn.Module):
517
+ """ Vision Transformer with support for patch or hybrid CNN input stage
518
+ """
519
+
520
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
521
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
522
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
523
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
524
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
525
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
526
+ super().__init__()
527
+ self.image_size = img_size
528
+ self.num_classes = num_classes
529
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
530
+
531
+ self.patch_embed = PatchEmbed(
532
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
533
+ num_patches = self.patch_embed.num_patches
534
+
535
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
536
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
537
+ if use_abs_pos_emb:
538
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
539
+ else:
540
+ self.pos_embed = None
541
+ self.pos_drop = nn.Dropout(p=drop_rate)
542
+
543
+ if use_shared_rel_pos_bias:
544
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
545
+ else:
546
+ self.rel_pos_bias = None
547
+
548
+ if rope:
549
+ half_head_dim = embed_dim // num_heads // 2
550
+ hw_seq_len = img_size // patch_size
551
+ self.rope = VisionRotaryEmbeddingFast(
552
+ dim=half_head_dim,
553
+ pt_seq_len=pt_hw_seq_len,
554
+ ft_seq_len=hw_seq_len if intp_freq else None,
555
+ # patch_dropout=patch_dropout
556
+ )
557
+ else:
558
+ self.rope = None
559
+
560
+ self.naiveswiglu = naiveswiglu
561
+
562
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
563
+ self.use_rel_pos_bias = use_rel_pos_bias
564
+ self.blocks = nn.ModuleList([
565
+ Block(
566
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
567
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
568
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
569
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
570
+ for i in range(depth)])
571
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
572
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
573
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
574
+
575
+ if self.pos_embed is not None:
576
+ trunc_normal_(self.pos_embed, std=.02)
577
+
578
+ trunc_normal_(self.cls_token, std=.02)
579
+ # trunc_normal_(self.mask_token, std=.02)
580
+
581
+ self.apply(self._init_weights)
582
+ self.fix_init_weight()
583
+
584
+ if isinstance(self.head, nn.Linear):
585
+ trunc_normal_(self.head.weight, std=.02)
586
+ self.head.weight.data.mul_(init_scale)
587
+ self.head.bias.data.mul_(init_scale)
588
+
589
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
590
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
591
+
592
+ self.grad_checkpointing = grad_checkpointing
593
+
594
+ def fix_init_weight(self):
595
+ def rescale(param, layer_id):
596
+ param.div_(math.sqrt(2.0 * layer_id))
597
+
598
+ for layer_id, layer in enumerate(self.blocks):
599
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
600
+ if self.naiveswiglu:
601
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
602
+ else:
603
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
604
+
605
+ def get_cast_dtype(self) -> torch.dtype:
606
+ return self.blocks[0].mlp.fc2.weight.dtype
607
+
608
+ def _init_weights(self, m):
609
+ if isinstance(m, nn.Linear):
610
+ trunc_normal_(m.weight, std=.02)
611
+ if m.bias is not None:
612
+ nn.init.constant_(m.bias, 0)
613
+ elif isinstance(m, nn.LayerNorm):
614
+ nn.init.constant_(m.bias, 0)
615
+ nn.init.constant_(m.weight, 1.0)
616
+
617
+ def get_num_layers(self):
618
+ return len(self.blocks)
619
+
620
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
621
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
622
+ for param in self.parameters():
623
+ param.requires_grad = False
624
+
625
+ @torch.jit.ignore
626
+ def set_grad_checkpointing(self, enable=True):
627
+ self.grad_checkpointing = enable
628
+
629
+ @torch.jit.ignore
630
+ def no_weight_decay(self):
631
+ return {'pos_embed', 'cls_token'}
632
+
633
+ def get_classifier(self):
634
+ return self.head
635
+
636
+ def reset_classifier(self, num_classes, global_pool=''):
637
+ self.num_classes = num_classes
638
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
639
+
640
+ def forward_features(self, x, return_all_features=False):
641
+
642
+ x = self.patch_embed(x)
643
+ batch_size, seq_len, _ = x.size()
644
+
645
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
646
+ x = torch.cat((cls_tokens, x), dim=1)
647
+ if self.pos_embed is not None:
648
+ x = x + self.pos_embed
649
+ x = self.pos_drop(x)
650
+
651
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
652
+ if os.getenv('RoPE') == '1':
653
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
654
+ x, patch_indices_keep = self.patch_dropout(x)
655
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
656
+ else:
657
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
658
+ x = self.patch_dropout(x)
659
+ else:
660
+ x = self.patch_dropout(x)
661
+
662
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
663
+ for i, blk in enumerate(self.blocks):
664
+ if i == len(self.blocks) - 1:
665
+ continue
666
+ if self.grad_checkpointing:
667
+ x = checkpoint(blk, x, (rel_pos_bias,))
668
+ else:
669
+ x = blk(x, rel_pos_bias=rel_pos_bias)
670
+
671
+ if not return_all_features:
672
+ x = self.norm(x)
673
+ if self.fc_norm is not None:
674
+ return self.fc_norm(x.mean(1))
675
+ else:
676
+ return x[:, 0]
677
+ return x
678
+
679
+ def forward(self, x, return_all_features=False):
680
+ if return_all_features:
681
+ return self.forward_features(x, return_all_features)
682
+ x = self.forward_features(x)
683
+ x = self.head(x)
684
+ return x
685
+
686
+
687
+ def load_state_dict(checkpoint_path: str, map_location: str = 'cpu', model_key: str = 'model|module|state_dict',
688
+ is_openai: bool = False, skip_list: list = []):
689
+ if is_openai:
690
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
691
+ state_dict = model.state_dict()
692
+ for key in ["input_resolution", "context_length", "vocab_size"]:
693
+ state_dict.pop(key, None)
694
+ else:
695
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
696
+ for mk in model_key.split('|'):
697
+ if isinstance(checkpoint, dict) and mk in checkpoint:
698
+ state_dict = checkpoint[mk]
699
+ break
700
+ else:
701
+ state_dict = checkpoint
702
+ if next(iter(state_dict.items()))[0].startswith('module'):
703
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
704
+
705
+ for k in skip_list:
706
+ if k in list(state_dict.keys()):
707
+ logging.info(f"Removing key {k} from pretrained checkpoint")
708
+ del state_dict[k]
709
+
710
+ if os.getenv('RoPE') == '1':
711
+ for k in list(state_dict.keys()):
712
+ if 'freqs_cos' in k or 'freqs_sin' in k:
713
+ del state_dict[k]
714
+ return state_dict
715
+
716
+
717
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = 'cpu', is_openai: bool = False,
718
+ skip_list: list = []):
719
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
720
+
721
+ for k in list(state_dict.keys()):
722
+ if not k.startswith('visual.'):
723
+ del state_dict[k]
724
+ for k in list(state_dict.keys()):
725
+ if k.startswith('visual.'):
726
+ new_k = k[7:]
727
+ state_dict[new_k] = state_dict[k]
728
+ del state_dict[k]
729
+ return state_dict
730
+
731
+
732
+ from dataclasses import dataclass
733
+ from typing import Optional, Tuple, Union
734
+
735
+ try:
736
+ from apex.normalization import FusedLayerNorm
737
+ except ImportError:
738
+ FusedLayerNorm = LayerNorm
739
+ print(
740
+ "Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
741
+
742
+
743
+ @dataclass
744
+ class CLIPVisionCfg:
745
+ layers: Union[Tuple[int, int, int, int], int] = 12
746
+ width: int = 768
747
+ head_width: int = 64
748
+ mlp_ratio: float = 4.0
749
+ patch_size: int = 16
750
+ image_size: Union[Tuple[int, int], int] = 224
751
+ ls_init_value: Optional[float] = None # layer scale initial value
752
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
753
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
754
+ drop_path_rate: Optional[float] = None # drop path rate
755
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
756
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
757
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
758
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
759
+ timm_proj_bias: bool = False # enable bias final projection
760
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
761
+ qkv_bias: bool = True
762
+ fusedLN: bool = False
763
+ xattn: bool = False
764
+ postnorm: bool = False
765
+ rope: bool = False
766
+ pt_hw_seq_len: int = 16 # 224/14
767
+ intp_freq: bool = False
768
+ naiveswiglu: bool = False
769
+ subln: bool = False
770
+
771
+
772
+ def _build_vision_tower(
773
+ vision_tower_path: str,
774
+ embed_dim: int,
775
+ vision_cfg: CLIPVisionCfg
776
+ ):
777
+ if isinstance(vision_cfg, dict):
778
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
779
+
780
+ if vision_cfg.eva_model_name:
781
+ vision_heads = vision_cfg.width // vision_cfg.head_width
782
+ norm_layer = LayerNorm
783
+
784
+ visual = EVAVisionTransformer(
785
+ img_size=vision_cfg.image_size,
786
+ patch_size=vision_cfg.patch_size,
787
+ num_classes=embed_dim,
788
+ use_mean_pooling=vision_cfg.global_average_pool, # False
789
+ init_values=vision_cfg.ls_init_value,
790
+ patch_dropout=vision_cfg.patch_dropout,
791
+ embed_dim=vision_cfg.width,
792
+ depth=vision_cfg.layers,
793
+ num_heads=vision_heads,
794
+ mlp_ratio=vision_cfg.mlp_ratio,
795
+ qkv_bias=vision_cfg.qkv_bias,
796
+ drop_path_rate=vision_cfg.drop_path_rate,
797
+ norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
798
+ xattn=vision_cfg.xattn,
799
+ rope=vision_cfg.rope,
800
+ postnorm=vision_cfg.postnorm,
801
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
802
+ intp_freq=vision_cfg.intp_freq,
803
+ naiveswiglu=vision_cfg.naiveswiglu,
804
+ subln=vision_cfg.subln
805
+ )
806
+
807
+ state_dict = load_clip_visual_state_dict(vision_tower_path)
808
+ incompatible_keys = visual.load_state_dict(state_dict, strict=False)
809
+ print('EVA-CLIP incompatible_keys:', incompatible_keys)
810
+
811
+ return visual
812
+
813
+
814
+ class Eva2LargePlusEncoder(nn.Module):
815
+ def __init__(self, vision_tower_path):
816
+ super(Eva2LargePlusEncoder, self).__init__()
817
+ self.config = {
818
+ "embed_dim": 768,
819
+ "vision_cfg": {
820
+ "image_size": 336,
821
+ "layers": 24,
822
+ "width": 1024,
823
+ "drop_path_rate": 0,
824
+ "head_width": 64,
825
+ "mlp_ratio": 2.6667,
826
+ "patch_size": 14,
827
+ "eva_model_name": "eva-clip-l-14-336",
828
+ "xattn": True,
829
+ "fusedLN": True,
830
+ "rope": True,
831
+ "pt_hw_seq_len": 16,
832
+ "intp_freq": True,
833
+ "naiveswiglu": True,
834
+ "subln": True
835
+ }
836
+ }
837
+
838
+ self.config['vision_tower_path'] = vision_tower_path
839
+ self.model = _build_vision_tower(**self.config)
840
+
841
+ def forward(self, image, **kwargs):
842
+ encode = self.model(image, return_all_features=True)[:, 1:, :]
843
+ return encode
844
+
845
+ @property
846
+ def dtype(self):
847
+ return list(self.parameters())[-1].dtype
848
+
849
+ @property
850
+ def device(self):
851
+ return list(self.parameters())[-1].device
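
As a quick sanity check of the rotary helper defined above, a minimal sketch on a dummy tensor (it assumes the repository's optional dependencies such as xformers are installed, since importing this module pulls them in):

    import torch
    from bunny.model.multimodal_encoder.eva_clip.eva_vit import rotate_half

    t = torch.tensor([[1., 2., 3., 4.]])  # two feature pairs: (1, 2) and (3, 4)
    print(rotate_half(t))                 # tensor([[-2., 1., -4., 3.]]) -- each pair rotated by 90 degrees
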
bunny/model/multimodal_encoder/siglip/siglip_encoder.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
5
+ from bunny.util.s2wrapper import forward as multiscale_forward
6
+
7
+
8
+ class SiglipVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_name = vision_tower
15
+ self.select_layer = -2
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ else:
20
+ self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
21
+
22
+ def load_model(self):
23
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
24
+ self.image_processor.crop_size = self.image_processor.size
25
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
26
+ self.vision_tower.requires_grad_(False)
27
+
28
+ self.is_loaded = True
29
+
30
+ def feature_select(self, image_forward_outs):
31
+ image_features = image_forward_outs.hidden_states[self.select_layer]
32
+
33
+ return image_features
34
+
35
+ @torch.no_grad()
36
+ def forward(self, images):
37
+ if type(images) is list:
38
+ image_features = []
39
+ for image in images:
40
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41
+ output_hidden_states=True)
42
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
43
+ image_features.append(image_feature)
44
+ else:
45
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
46
+ output_hidden_states=True)
47
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
48
+
49
+ return image_features
50
+
51
+ @property
52
+ def dummy_feature(self):
53
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
54
+
55
+ @property
56
+ def dtype(self):
57
+ return self.vision_tower.dtype
58
+
59
+ @property
60
+ def device(self):
61
+ return self.vision_tower.device
62
+
63
+ @property
64
+ def config(self):
65
+ if self.is_loaded:
66
+ return self.vision_tower.config
67
+ else:
68
+ return self.cfg_only
69
+
70
+ @property
71
+ def hidden_size(self):
72
+ return self.config.hidden_size
73
+
74
+ @property
75
+ def num_patches(self):
76
+ return (self.config.image_size // self.config.patch_size) ** 2
77
+
78
+
79
+ class SiglipVisionTowerS2(SiglipVisionTower):
80
+ def __init__(self, vision_tower, args, delay_load=False):
81
+ self.s2_scales = getattr(args, 's2_scales', '384,768,1152')
82
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
83
+ self.s2_scales.sort()
84
+ self.s2_split_size = self.s2_scales[0]
85
+ self.s2_image_size = self.s2_scales[-1]
86
+
87
+ super().__init__(vision_tower, args, delay_load)
88
+
89
+ self.multiscale_forward = multiscale_forward
90
+
91
+ if not delay_load:
92
+ self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
93
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
94
+
95
+ def load_model(self):
96
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
97
+ self.image_processor.crop_size = self.image_processor.size
98
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
99
+ self.vision_tower.requires_grad_(False)
100
+
101
+ self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
102
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
103
+
104
+ self.is_loaded = True
105
+
106
+ @torch.no_grad()
107
+ def forward_feature(self, images):
108
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
109
+ output_hidden_states=True)
110
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
111
+ return image_features
112
+
113
+ @torch.no_grad()
114
+ def forward(self, images):
115
+ if type(images) is list:
116
+ image_features = []
117
+ for image in images:
118
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0),
119
+ img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
120
+ image_features.append(image_feature)
121
+ else:
122
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales,
123
+ max_split_size=self.s2_split_size)
124
+
125
+ return image_features
126
+
127
+ @property
128
+ def hidden_size(self):
129
+ return self.config.hidden_size * len(self.s2_scales)
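
Note that the S2 variant concatenates features from every scale, so downstream projectors see a wider hidden size. A minimal sketch of just that bookkeeping (the base hidden size is an assumed example value):

    s2_scales = list(map(int, '384,768,1152'.split(',')))     # default scales from the class above
    base_hidden_size = 1152                                    # e.g. a SigLIP-SO400M-like width (assumed)
    print(len(s2_scales), base_hidden_size * len(s2_scales))   # 3 scales -> 3456-dim features
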
bunny/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,183 @@
1
+ import re
2
+ import math
3
+ from torch import nn
4
+ from functools import partial
5
+ from timm.layers.norm_act import LayerNormAct2d
6
+ from torchvision.ops.misc import SqueezeExcitation as SELayer
7
+ from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
8
+
9
+
10
+ class IdentityMap(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, x, *args, **kwargs):
15
+ return x
16
+
17
+ @property
18
+ def config(self):
19
+ return {"mm_projector_type": 'identity'}
20
+
21
+
22
+ class Minigpt(nn.Module):
23
+ def __init__(self, config=None):
24
+ super(Minigpt, self).__init__()
25
+ # the linear layer takes 4 concatenated tokens (mm_hidden_size * 4) and maps them to hidden_size
26
+ inc, ouc = config.mm_hidden_size, config.hidden_size
27
+ self.linear = nn.Linear(inc * 4, ouc)
28
+
29
+ def forward(self, x):
30
+ # x is the input tensor with shape [b, num_tokens, c]
31
+ b, num_tokens, c = x.shape
32
+
33
+ # Check if num_tokens is divisible by 4
34
+ if num_tokens % 4 != 0:
35
+ raise ValueError("num_tokens must be divisible by 4")
36
+
37
+ # Reshape x to [b, num_tokens/4, c*4]
38
+ x = x.view(b, num_tokens // 4, c * 4)
39
+
40
+ # Apply the linear transformation
41
+ x = self.linear(x)
42
+ return x
43
+
44
+
45
+ class Vanilla(nn.Module):
46
+ def __init__(self, config=None):
47
+ super(Vanilla, self).__init__()
48
+ # the linear layer takes 4 concatenated tokens (mm_hidden_size * 4) and maps them to hidden_size
49
+ inc, ouc = config.mm_hidden_size, config.hidden_size
50
+ self.linear = nn.Linear(inc * 4, ouc)
51
+
52
+ def forward(self, x):
53
+ b, num_tokens, c = x.shape
54
+
55
+ # Check if num_tokens is divisible by 4
56
+ if num_tokens % 4 != 0:
57
+ raise ValueError("num_tokens must be divisible by 4")
58
+
59
+ # First, reshape to [b, num_tokens//4, 4, c]
60
+ x = x.view(b, num_tokens // 4, 4, c)
61
+
62
+ # Then, permute to interleave the tokens
63
+ x = x.permute(0, 1, 3, 2).contiguous()
64
+
65
+ # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
66
+ x = x.view(b, num_tokens // 4, c * 4)
67
+
68
+ # Apply the linear transformation
69
+ x = self.linear(x)
70
+ return x
71
+
72
+
73
+ class LDPBlock(nn.Module):
74
+ # Lightweight Downsample Projector Block
75
+
76
+ def __init__(self, config=None):
77
+ super().__init__()
78
+
79
+ inc, ouc = config.mm_hidden_size, config.hidden_size
80
+ layer_norm = partial(LayerNormAct2d, act_layer=None)
81
+ se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
82
+ self.mlp = nn.Sequential(
83
+ nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
84
+ )
85
+ self.mb_block = nn.Sequential(
86
+ nn.Identity(),
87
+ InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer),
88
+ InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer)
89
+ )
90
+
91
+ def forward(self, x):
92
+ b, num_tokens, c = x.shape
93
+ h = int(math.sqrt(num_tokens))
94
+ x = self.mlp(x)
95
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
96
+ x = self.mb_block(x)
97
+ x = x.flatten(2).permute(0, 2, 1)
98
+ return x
99
+
100
+
101
+ class LDPNetProjector(nn.Module):
102
+
103
+ def __init__(self, config=None):
104
+ super().__init__()
105
+ self.model = LDPBlock(config)
106
+
107
+ def forward(self, x):
108
+ return self.model(x)
109
+
110
+
111
+ class SPP(nn.Module):
112
+
113
+ def __init__(self, config=None, projector_type='v1'):
114
+ super().__init__()
115
+
116
+ self.projector_type = projector_type
117
+
118
+ inc, ouc = config.mm_hidden_size, config.hidden_size
119
+ self.linear_0 = nn.Linear(inc, inc)
120
+
121
+ self.linear_1 = nn.Linear(inc, ouc)
122
+
123
+ self.pooling = nn.AvgPool2d(kernel_size=2)
124
+
125
+ self.linear_2 = nn.Linear(ouc, ouc)
126
+
127
+ def forward(self, x):
128
+ b, num_tokens, c = x.shape
129
+ h = int(math.sqrt(num_tokens))
130
+ if 'v1' in self.projector_type:
131
+ x = self.linear_1(x)
132
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
133
+ x = self.pooling(x)
134
+ x = x.flatten(2).permute(0, 2, 1)
135
+ x = self.linear_2(x)
136
+ elif 'v2' in self.projector_type:
137
+ x = self.linear_1(x)
138
+ x = self.linear_2(x)
139
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
140
+ x = self.pooling(x)
141
+ x = x.flatten(2).permute(0, 2, 1)
142
+ elif 'v3' in self.projector_type:
143
+ x = self.linear_0(x)
144
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
145
+ x = self.pooling(x)
146
+ x = x.flatten(2).permute(0, 2, 1)
147
+ x = self.linear_1(x)
148
+ x = self.linear_2(x)
149
+ return x
150
+
151
+
152
+ def build_vision_projector(config, delay_load=False, **kwargs):
153
+ projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu')
154
+
155
+ if projector_type == 'linear':
156
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
157
+
158
+ elif projector_type.startswith('mlp'):
159
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
160
+ if mlp_gelu_match:
161
+ mlp_depth = int(mlp_gelu_match.group(1))
162
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
163
+ for _ in range(1, mlp_depth):
164
+ modules.append(nn.GELU())
165
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
166
+ return nn.Sequential(*modules)
167
+
168
+ elif projector_type.startswith('spp'):
169
+ return SPP(config, projector_type)
170
+
171
+ elif projector_type == 'ldp':
172
+ return LDPNetProjector(config)
173
+
174
+ elif projector_type == 'vanilla':
175
+ return Vanilla(config)
176
+
177
+ elif projector_type == 'minigpt':
178
+ return Minigpt(config)
179
+
180
+ elif projector_type == 'identity':
181
+ return IdentityMap()
182
+
183
+ raise ValueError(f'Unknown projector type: {projector_type}')
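
A quick sketch of the default `mlp2x_gelu` path; the sizes are placeholders chosen for illustration:

    import torch
    from types import SimpleNamespace
    from bunny.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1152, hidden_size=2048)
    projector = build_vision_projector(cfg)                  # Linear -> GELU -> Linear
    out = projector(torch.zeros(2, 576, cfg.mm_hidden_size))
    print(out.shape)                                         # torch.Size([2, 576, 2048])
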
bunny/serve/cli.py ADDED
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import torch
3
+ import requests
4
+
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from transformers import TextStreamer
8
+
9
+ from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
10
+ from bunny.conversation import conv_templates, SeparatorStyle
11
+ from bunny.model.builder import load_pretrained_model
12
+ from bunny.util.utils import disable_torch_init
13
+ from bunny.util.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, \
14
+ KeywordsStoppingCriteria
15
+
16
+
17
+ def load_image(image_file):
18
+ if image_file.startswith('http://') or image_file.startswith('https://'):
19
+ response = requests.get(image_file)
20
+ image = Image.open(BytesIO(response.content)).convert('RGB')
21
+ else:
22
+ image = Image.open(image_file).convert('RGB')
23
+ return image
24
+
25
+
26
+ def main(args):
27
+ # Model
28
+ disable_torch_init()
29
+
30
+ model_name = get_model_name_from_path(args.model_path)
31
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
32
+ args.model_type, args.load_8bit,
33
+ args.load_4bit, device=args.device)
34
+
35
+ conv_mode = "bunny"
36
+
37
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
38
+ print(
39
+ '[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode,
40
+ args.conv_mode,
41
+ args.conv_mode))
42
+ else:
43
+ args.conv_mode = conv_mode
44
+
45
+ conv = conv_templates[args.conv_mode].copy()
46
+ roles = conv.roles
47
+
48
+ image = load_image(args.image_file)
49
+ # Similar operation in model_worker.py
50
+ image_tensor = process_images([image], image_processor, model.config)
51
+ if type(image_tensor) is list:
52
+ image_tensor = [image.to(model.device, dtype=model.dtype) for image in image_tensor]
53
+ else:
54
+ image_tensor = image_tensor.to(model.device, dtype=model.dtype)
55
+
56
+ while True:
57
+ try:
58
+ inp = input(f"{roles[0]}: ")
59
+ except EOFError:
60
+ inp = ""
61
+ if not inp:
62
+ print("exit...")
63
+ break
64
+
65
+ print(f"{roles[1]}: ", end="")
66
+
67
+ if image is not None:
68
+ # first message
69
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
70
+ conv.append_message(conv.roles[0], inp)
71
+ image = None
72
+ else:
73
+ conv.append_message(conv.roles[0], inp)
74
+ conv.append_message(conv.roles[1], None)
75
+ prompt = conv.get_prompt()
76
+
77
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
78
+ model.device)
79
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
80
+ keywords = [stop_str]
81
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
82
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
83
+
84
+ with torch.inference_mode():
85
+ output_ids = model.generate(
86
+ input_ids,
87
+ images=image_tensor,
88
+ do_sample=True if args.temperature > 0 else False,
89
+ temperature=args.temperature,
90
+ max_new_tokens=args.max_new_tokens,
91
+ streamer=streamer,
92
+ use_cache=True,
93
+ repetition_penalty=args.repetition_penalty,
94
+ stopping_criteria=[stopping_criteria])
95
+
96
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
97
+ conv.messages[-1][-1] = outputs
98
+
99
+ if args.debug:
100
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument("--model-path", type=str, default=None)
106
+ parser.add_argument("--model-base", type=str, default=None)
107
+ parser.add_argument("--model-type", type=str, default=None)
108
+ parser.add_argument("--image-file", type=str, required=True)
109
+ parser.add_argument("--device", type=str, default="cuda")
110
+ parser.add_argument("--conv-mode", type=str, default=None)
111
+ parser.add_argument("--temperature", type=float, default=0.2)
112
+ parser.add_argument("--repetition-penalty", type=float, default=1.0)
113
+ parser.add_argument("--max-new-tokens", type=int, default=512)
114
+ parser.add_argument("--load-8bit", action="store_true")
115
+ parser.add_argument("--load-4bit", action="store_true")
116
+ parser.add_argument("--debug", action="store_true")
117
+ args = parser.parse_args()
118
+ main(args)
bunny/serve/controller.py ADDED
@@ -0,0 +1,277 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+import argparse
+import dataclasses
+import threading
+import json
+import time
+import numpy as np
+import requests
+import uvicorn
+
+from typing import List
+from enum import Enum, auto
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+
+from bunny.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from bunny.util.utils import build_logger, server_error_msg
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+    LOTTERY = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, name):
+        if name == "lottery":
+            return cls.LOTTERY
+        elif name == "shortest_queue":
+            return cls.SHORTEST_QUEUE
+        else:
+            raise ValueError("Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+    model_names: List[str]
+    speed: int
+    queue_length: int
+    check_heart_beat: bool
+    last_heart_beat: float
+
+
+def heart_beat_controller(controller):
+    while True:
+        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+        controller.remove_stale_workers_by_expiration()
+
+
+class Controller:
+    def __init__(self, dispatch_method: str):
+        # Dict[str -> WorkerInfo]
+        self.worker_info = {}
+        self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+        self.heart_beat_thread = threading.Thread(
+            target=heart_beat_controller, args=(self,))
+        self.heart_beat_thread.start()
+
+        logger.info("Init controller")
+
+    def register_worker(self, worker_name: str, check_heart_beat: bool,
+                        worker_status: dict):
+        if worker_name not in self.worker_info:
+            logger.info(f"Register a new worker: {worker_name}")
+        else:
+            logger.info(f"Register an existing worker: {worker_name}")
+
+        if not worker_status:
+            worker_status = self.get_worker_status(worker_name)
+        if not worker_status:
+            return False
+
+        self.worker_info[worker_name] = WorkerInfo(
+            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+            check_heart_beat, time.time())
+
+        logger.info(f"Register done: {worker_name}, {worker_status}")
+        return True
+
+    def get_worker_status(self, worker_name: str):
+        try:
+            r = requests.post(worker_name + "/worker_get_status", timeout=5)
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Get status fails: {worker_name}, {e}")
+            return None
+
+        if r.status_code != 200:
+            logger.error(f"Get status fails: {worker_name}, {r}")
+            return None
+
+        return r.json()
+
+    def remove_worker(self, worker_name: str):
+        del self.worker_info[worker_name]
+
+    def refresh_all_workers(self):
+        old_info = dict(self.worker_info)
+        self.worker_info = {}
+
+        for w_name, w_info in old_info.items():
+            if not self.register_worker(w_name, w_info.check_heart_beat, None):
+                logger.info(f"Remove stale worker: {w_name}")
+
+    def list_models(self):
+        model_names = set()
+
+        for w_name, w_info in self.worker_info.items():
+            model_names.update(w_info.model_names)
+
+        return list(model_names)
+
+    def get_worker_address(self, model_name: str):
+        if self.dispatch_method == DispatchMethod.LOTTERY:
+            worker_names = []
+            worker_speeds = []
+            for w_name, w_info in self.worker_info.items():
+                if model_name in w_info.model_names:
+                    worker_names.append(w_name)
+                    worker_speeds.append(w_info.speed)
+            worker_speeds = np.array(worker_speeds, dtype=np.float32)
+            norm = np.sum(worker_speeds)
+            if norm < 1e-4:
+                return ""
+            worker_speeds = worker_speeds / norm
+
+            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+            worker_name = worker_names[pt]
+            return worker_name
+
+        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+            worker_names = []
+            worker_qlen = []
+            for w_name, w_info in self.worker_info.items():
+                if model_name in w_info.model_names:
+                    worker_names.append(w_name)
+                    worker_qlen.append(w_info.queue_length / w_info.speed)
+            if len(worker_names) == 0:
+                return ""
+            min_index = np.argmin(worker_qlen)
+            w_name = worker_names[min_index]
+            self.worker_info[w_name].queue_length += 1
+            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+            return w_name
+        else:
+            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+    def receive_heart_beat(self, worker_name: str, queue_length: int):
+        if worker_name not in self.worker_info:
+            logger.info(f"Receive unknown heart beat. {worker_name}")
+            return False
+
+        self.worker_info[worker_name].queue_length = queue_length
+        self.worker_info[worker_name].last_heart_beat = time.time()
+        # logger.info(f"Receive heart beat. {worker_name}")
+        return True
+
+    def remove_stale_workers_by_expiration(self):
+        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+        to_delete = []
+        for worker_name, w_info in self.worker_info.items():
+            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+                to_delete.append(worker_name)
+
+        for worker_name in to_delete:
+            self.remove_worker(worker_name)
+
+    def worker_api_generate_stream(self, params):
+        worker_addr = self.get_worker_address(params["model"])
+        if not worker_addr:
+            logger.info(f"no worker: {params['model']}")
+            ret = {
+                "text": server_error_msg,
+                "error_code": 2,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+            return
+
+        try:
+            response = requests.post(worker_addr + "/worker_generate_stream",
+                                     json=params, stream=True, timeout=5)
+            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                if chunk:
+                    yield chunk + b"\0"
+        except requests.exceptions.RequestException as e:
+            logger.info(f"worker timeout: {worker_addr}")
+            ret = {
+                "text": server_error_msg,
+                "error_code": 3,
+            }
+            yield json.dumps(ret).encode() + b"\0"
+
+    # Let the controller act as a worker to achieve hierarchical
+    # management. This can be used to connect isolated sub networks.
+    def worker_api_get_status(self):
+        model_names = set()
+        speed = 0
+        queue_length = 0
+
+        for w_name in self.worker_info:
+            worker_status = self.get_worker_status(w_name)
+            if worker_status is not None:
+                model_names.update(worker_status["model_names"])
+                speed += worker_status["speed"]
+                queue_length += worker_status["queue_length"]
+
+        return {
+            "model_names": list(model_names),
+            "speed": speed,
+            "queue_length": queue_length,
+        }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+    data = await request.json()
+    controller.register_worker(
+        data["worker_name"], data["check_heart_beat"],
+        data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+    models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+    models = controller.list_models()
+    return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+    data = await request.json()
+    addr = controller.get_worker_address(data["model"])
+    return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+    data = await request.json()
+    exist = controller.receive_heart_beat(
+        data["worker_name"], data["queue_length"])
+    return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+    params = await request.json()
+    generator = controller.worker_api_generate_stream(params)
+    return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+    return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21001)
+    parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    controller = Controller(args.dispatch_method)
+    log_config = uvicorn.config.LOGGING_CONFIG
+    log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
+    uvicorn.run(app, host=args.host, port=args.port, log_level="critical")
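
The controller above exposes all of its coordination logic as plain HTTP POST routes, so a worker or client only needs a few JSON requests to participate. The following is a minimal sketch of that exchange, assuming the controller is running locally on its argparse defaults (localhost:21001); the worker address http://localhost:40000 and the model name "bunny" are placeholders, not values defined by this commit.

    import requests

    CONTROLLER = "http://localhost:21001"   # default host/port from the argparse block above
    WORKER = "http://localhost:40000"       # placeholder worker address

    # Register the worker, supplying the status fields WorkerInfo expects.
    requests.post(CONTROLLER + "/register_worker", json={
        "worker_name": WORKER,
        "check_heart_beat": True,
        "worker_status": {"model_names": ["bunny"], "speed": 1, "queue_length": 0},
    })

    # The worker must keep reporting its queue length, or the heartbeat thread
    # will expire it after CONTROLLER_HEART_BEAT_EXPIRATION seconds.
    requests.post(CONTROLLER + "/receive_heart_beat", json={
        "worker_name": WORKER, "queue_length": 0,
    })

    # Clients discover models and resolve a model name to a worker address.
    print(requests.post(CONTROLLER + "/list_models").json())          # e.g. {'models': ['bunny']}
    print(requests.post(CONTROLLER + "/get_worker_address",
                        json={"model": "bunny"}).json())               # e.g. {'address': 'http://localhost:40000'}
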
bunny/serve/examples/example_1.png ADDED
bunny/serve/examples/example_2.png ADDED