clfegg committed on
Commit 1643107 · verified · 1 Parent(s): b152350

Delete hub/models--vikhyatk--moondream2/blobs

hub/models--vikhyatk--moondream2/blobs/0204ed10c186a4c7c68f55dff8f26087a45898d6 DELETED
@@ -1,5 +0,0 @@
{
  "bos_token": "<|endoftext|>",
  "eos_token": "<|endoftext|>",
  "unk_token": "<|endoftext|>"
}
 
hub/models--vikhyatk--moondream2/blobs/226b0752cac7789c48f0cb3ec53eda48b7be36cc DELETED
The diff for this file is too large to render. See raw diff
 
hub/models--vikhyatk--moondream2/blobs/4b1f9051605c296344c271b6d21c1e2e412a99e8 DELETED
@@ -1,96 +0,0 @@
from transformers import PretrainedConfig


class PhiConfig(PretrainedConfig):
    model_type = "phi"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=51200,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=None,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attention_dropout=0.0,
        hidden_act="gelu_new",
        max_position_embeddings=2048,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.5,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self._rope_scaling_validation()

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, float)
            or rope_scaling_factor <= 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
            )


class MoondreamConfig(PretrainedConfig):
    model_type = "moondream1"

    def __init__(self, **kwargs):
        self.text_config = PhiConfig(**kwargs.pop("text_config", {}))
        super().__init__(**kwargs)
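To make the relationship between the two classes concrete, here is a hypothetical construction sketch (field values are illustrative, not taken from the released checkpoint): MoondreamConfig pops the nested text_config dict and turns it into a PhiConfig, whose validator accepts only "linear"/"dynamic" RoPE scaling with a float factor greater than 1.

# Illustrative only; assumes the classes above are importable.
config = MoondreamConfig(
    text_config={
        "num_hidden_layers": 24,
        "hidden_size": 2048,
        "rope_scaling": {"type": "linear", "factor": 2.0},  # passes _rope_scaling_validation
    }
)

# num_key_value_heads falls back to num_attention_heads when not given.
assert config.text_config.num_key_value_heads == config.text_config.num_attention_heads
print(config.text_config.rope_theta)  # 10000.0 (default)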
 
hub/models--vikhyatk--moondream2/blobs/4bf7aed8ba4325d23fa7cd348d795a27f3b272682536f08aca4cdd62cde79293 DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bf7aed8ba4325d23fa7cd348d795a27f3b272682536f08aca4cdd62cde79293
size 3736040266
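This blob is a Git LFS pointer, not the weights themselves: it records only the SHA-256 and byte size (about 3.7 GB) of the actual payload. A small sketch of parsing such a pointer file (the local file name is hypothetical):

# Parse a Git LFS pointer file ("key value" per line) into a dict.
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

pointer = parse_lfs_pointer("weights.lfs")   # hypothetical local copy of this blob
print(pointer["oid"], int(pointer["size"]))  # sha256:4bf7aed8..., 3736040266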
 
hub/models--vikhyatk--moondream2/blobs/5145e0895f2fe7f1ccb3eb9da69ec74ec9c680db DELETED
@@ -1,323 +0,0 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "50257": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50258": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50259": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50260": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50261": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50262": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50263": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50264": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50265": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50266": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50267": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50268": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50269": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50270": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50271": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50272": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50273": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50274": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50275": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50276": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50277": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50278": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50279": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50280": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50281": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50282": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50283": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50284": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50285": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50286": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50287": {
      "content": "\t\t\t\t\t\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50288": {
      "content": "\t\t\t\t\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50289": {
      "content": "\t\t\t\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50290": {
      "content": "\t\t\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50291": {
      "content": "\t\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50292": {
      "content": "\t\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50293": {
      "content": "\t\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50294": {
      "content": "\t\t",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "model_max_length": 2048,
  "tokenizer_class": "CodeGenTokenizer",
  "unk_token": "<|endoftext|>"
}
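The practical effect of the added_tokens_decoder entries above is that runs of indentation whitespace (the space and tab tokens in the 50257-50294 range, whose exact lengths are collapsed in this rendering) are single tokens for this CodeGenTokenizer, and the context window is capped at model_max_length = 2048. A short, illustrative sketch (loading by repo id is assumed to work as for other Hugging Face checkpoints):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("vikhyatk/moondream2")

print(tok.eos_token, tok.model_max_length)  # <|endoftext|> 2048

# Indentation-heavy text keeps prompts short because whitespace runs map to
# single added tokens (e.g. "\t\t" is id 50294 in added_tokens.json below).
print(tok("\t\treturn x").input_ids)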
 
hub/models--vikhyatk--moondream2/blobs/619b6765140cdfaa9b9d20619cae17643a28265f DELETED
@@ -1,6 +0,0 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.44.0"
}
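These generation defaults pin bos_token_id=1 and eos_token_id=2, matching the PhiConfig defaults above. A minimal sketch of reading them back through the standard transformers API (illustrative):

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("vikhyatk/moondream2")
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id)  # 1 2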
 
hub/models--vikhyatk--moondream2/blobs/6ac7b4364eba1fdd1d3981e4669aed01a2b0cec4 DELETED
@@ -1,43 +0,0 @@
import torch
import torch.nn as nn
from .fourier_features import FourierFeatures

class RegionModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.position_features = FourierFeatures(2, 256)
        self.position_encoder = nn.Linear(256, 2048)
        self.size_features = FourierFeatures(2, 256)
        self.size_encoder = nn.Linear(256, 2048)

        self.position_decoder = nn.Linear(2048, 2)
        self.size_decoder = nn.Linear(2048, 2)
        self.confidence_decoder = nn.Linear(2048, 1)

    def encode_position(self, position):
        return self.position_encoder(self.position_features(position))

    def encode_size(self, size):
        return self.size_encoder(self.size_features(size))

    def decode_position(self, x):
        return self.position_decoder(x)

    def decode_size(self, x):
        return self.size_decoder(x)

    def decode_confidence(self, x):
        return self.confidence_decoder(x)

    def encode(self, position, size):
        return torch.stack(
            [self.encode_position(position), self.encode_size(size)], dim=0
        )

    def decode(self, position_logits, size_logits):
        return (
            self.decode_position(position_logits),
            self.decode_size(size_logits),
            self.decode_confidence(size_logits),
        )
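A shape-level usage sketch for the region head above. It assumes the module and its fourier_features dependency are importable; the import path region_model is hypothetical.

import torch
from region_model import RegionModel  # hypothetical import path

model = RegionModel()

# Normalized box centers (x, y) and sizes (w, h) for a batch of 4 detections.
center = torch.rand(4, 2)
size = torch.rand(4, 2)

tokens = model.encode(center, size)     # (2, 4, 2048): stacked position / size embeddings
pos, sz, conf = model.decode(tokens[0], tokens[1])
print(pos.shape, sz.shape, conf.shape)  # (4, 2), (4, 2), (4, 1)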
 
hub/models--vikhyatk--moondream2/blobs/7debb4784a7d53328d4d021fc46314bec4af3833 DELETED
@@ -1,40 +0,0 @@
{
  "\t\t": 50294,
  "\t\t\t": 50293,
  "\t\t\t\t": 50292,
  "\t\t\t\t\t": 50291,
  "\t\t\t\t\t\t": 50290,
  "\t\t\t\t\t\t\t": 50289,
  "\t\t\t\t\t\t\t\t": 50288,
  "\t\t\t\t\t\t\t\t\t": 50287,
  " ": 50286,
  " ": 50285,
  " ": 50284,
  " ": 50283,
  " ": 50282,
  " ": 50281,
  " ": 50280,
  " ": 50279,
  " ": 50278,
  " ": 50277,
  " ": 50276,
  " ": 50275,
  " ": 50274,
  " ": 50273,
  " ": 50272,
  " ": 50271,
  " ": 50270,
  " ": 50269,
  " ": 50268,
  " ": 50267,
  " ": 50266,
  " ": 50265,
  " ": 50264,
  " ": 50263,
  " ": 50262,
  " ": 50261,
  " ": 50260,
  " ": 50259,
  " ": 50258,
  " ": 50257
}
 
hub/models--vikhyatk--moondream2/blobs/84ef7fb594b5c0979e48bdeddb60a0adef33df0b DELETED
The diff for this file is too large to render. See raw diff
 
hub/models--vikhyatk--moondream2/blobs/923ea295017e96fb15774a11a903f99adff3bd4b DELETED
@@ -1,230 +0,0 @@
import torch

from typing import List, Union, Literal, Optional
from transformers import PreTrainedModel
from PIL import Image

from .configuration_moondream import PhiConfig
from .configuration_moondream import MoondreamConfig
from .vision_encoder import VisionEncoder
from .region_model import RegionModel
from .modeling_phi import PhiForCausalLM

class Moondream(PreTrainedModel):
    config_class = MoondreamConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder(
            use_flash_attn=config._attn_implementation == "flash_attention_2"
        )
        self.region_model = RegionModel()

        if type(config.text_config) == dict:
            phi_config = PhiConfig(
                **config.text_config, attn_implementation=config._attn_implementation
            )
        else:
            phi_config = config.text_config
        self.text_model = PhiForCausalLM(phi_config)

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        with torch.no_grad():
            return self.vision_encoder(image)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        # Add BOS token
        embeds = []
        embeds.append(
            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
        )

        if "<image>" not in prompt:
            embeds.append(text_emb(_tokenize(prompt)))
        else:
            assert prompt.count("<image>") == 1
            before, after = prompt.split("<image>")
            if len(before) > 0:
                embeds.append(text_emb(_tokenize(before)))
            embeds.append(image_embeds.to(self.device))
            if len(after) > 0:
                embeds.append(text_emb(_tokenize(after)))

        return torch.cat(embeds, dim=1)

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        **kwargs,
    ):
        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Note: Not ready for use yet, intended for September release.
    def caption(
        self,
        images: List[Image.Image],
        tokenizer,
        length: Optional[Literal["short"]] = None,
        **kwargs,
    ):
        image_embeds = self.encode_image(images)

        templated_prompts = [
            f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:" for _ in images
        ]
        inputs_embeds = torch.stack([
            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
            for prompt, image_embed in zip(templated_prompts, image_embeds)
        ])
        attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)

        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            "repetition_penalty": 1.2,
            "max_new_tokens": 512,
            **kwargs,
        }

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs,
    ):
        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
        answer = self.generate(
            image_embeds,
            prompt,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )[0]
        cleaned_answer = answer.strip()

        # Use the result_queue to pass the result if it is provided
        if result_queue:
            result_queue.put(cleaned_answer)
        else:
            return cleaned_answer

    def batch_answer(
        self,
        images,
        prompts,
        tokenizer,
        **kwargs,
    ):
        image_embeds = self.encode_image(images)

        templated_prompts = [
            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
        ]
        prompt_embs = [
            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
            for prompt, image_embed in zip(templated_prompts, image_embeds)
        ]

        bos_emb = prompt_embs[0][0]
        max_len = max([p.shape[0] for p in prompt_embs])

        inputs_embeds = torch.cat(
            [
                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
                for p in prompt_embs
            ],
            dim=0,
        )
        attention_mask = torch.cat(
            [
                torch.cat(
                    [
                        torch.zeros(
                            1,
                            max_len - p.shape[0],
                            device=self.device,
                            dtype=torch.long,
                        ),
                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
                    ],
                    dim=1,
                )
                for p in prompt_embs
            ],
            dim=0,
        )

        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            "max_new_tokens": 512,
            **kwargs,
        }

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]

    def detect(self, image: Image.Image, query: str, tokenizer):
        pass
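A short usage sketch for the Moondream class above, following the usual remote-code loading pattern for this repository; treat the exact loading call as illustrative rather than tied to this specific revision.

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")

image = Image.open("example.jpg")          # hypothetical local image
image_embeds = model.encode_image(image)   # runs the VisionEncoder under no_grad

answer = model.answer_question(
    image_embeds, "What is in this picture?", tokenizer, max_new_tokens=128
)
print(answer)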
 
hub/models--vikhyatk--moondream2/blobs/98dd65a59581dac66a3601da9aadd1534f019006 DELETED
@@ -1,325 +0,0 @@
from typing import Union

import PIL.Image
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import PIL
from torchvision.transforms.v2 import (
    Compose,
    Resize,
    InterpolationMode,
    ToImage,
    ToDtype,
    Normalize,
)
from transformers.utils import is_flash_attn_2_available

try:
    if is_flash_attn_2_available():
        from flash_attn.modules.mha import FlashSelfAttention
    else:
        FlashSelfAttention = None
except ImportError:
    FlashSelfAttention = None


class Attention(nn.Module):

    def __init__(self, dim, num_heads=16, use_flash_attn=False):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        if use_flash_attn and FlashSelfAttention is not None:
            self.flash_attn = FlashSelfAttention()
        else:
            self.flash_attn = None

        torch.nn.init.kaiming_normal_(
            self.qkv.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_normal_(
            self.proj.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.flash_attn is not None:
            qkv = self.qkv(x)
            qkv = rearrange(
                qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
            )
            attn_output = self.flash_attn(qkv)
            output = rearrange(attn_output, "... h d -> ... (h d)")
            output = self.proj(output)
            return output
        else:
            B, N, C = x.shape
            qkv = (
                self.qkv(x)
                .reshape(B, N, 3, self.num_heads, self.head_dim)
                .permute(2, 0, 3, 1, 4)
            )
            q, k, v = qkv.unbind(0)

            x = F.scaled_dot_product_attention(q, k, v)

            x = x.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            return x


class VitBlock(nn.Module):

    def __init__(self, embed_dim, use_flash_attn=False):
        super().__init__()
        self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
        self.mlp = MLP(embed_dim, 4304)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):

    def __init__(self, use_flash_attn=False):
        super().__init__()

        embed_len = 729
        embed_dim = 1152

        self.patch_embed = LinearPatchEmbedding()
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.blocks = nn.Sequential(
            *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        return self.norm(x)


class EncoderWrapper(nn.Module):

    def __init__(self, use_flash_attn=False):
        super().__init__()
        self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})

    def forward(self, x):
        return self.model["visual"](x)


class LinearPatchEmbedding(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(588, 1152)

    def forward(self, x):
        b, c, hp1, wp2 = x.shape
        p1, p2 = 14, 14
        h, w = hp1 // p1, wp2 // p2
        x = x.reshape(b, c, h, p1, w, p2)
        x = x.permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(b, h * w, c * p1 * p2)

        return self.linear(x)


class MLP(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        out_features: int = None,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)

        torch.nn.init.kaiming_normal_(
            self.fc1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_normal_(
            self.fc2.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class VisionProjection(nn.Module):
    def __init__(self):
        super().__init__()

        image_embedding_dim = 1152
        model_dim = 2048
        hidden_dim = model_dim * 4

        self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)

    @property
    def device(self):
        return self.mlp.fc1.weight.device

    def forward(self, x):
        return self.mlp(x)


def create_patches(image, patch_size=(378, 378)):
    assert image.dim() == 3, "Image must be in CHW format"

    _, height, width = image.shape  # Channels, Height, Width
    patch_height, patch_width = patch_size

    if height == patch_height and width == patch_width:
        return []

    # Iterate over the image and create patches
    patches = []
    for i in range(0, height, patch_height):
        row_patches = []
        for j in range(0, width, patch_width):
            patch = image[:, i : i + patch_height, j : j + patch_width]
            row_patches.append(patch)
        patches.append(torch.stack(row_patches))
    return patches


class VisionEncoder(nn.Module):

    def __init__(self, use_flash_attn=False):
        super().__init__()

        self.encoder = EncoderWrapper(use_flash_attn)
        self.projection = VisionProjection()
        self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]

    @property
    def device(self):
        return self.projection.mlp.fc1.weight.device

    @property
    def dtype(self):
        return self.projection.mlp.fc1.weight.dtype

    def preprocess(self, image: PIL.Image.Image):
        width, height = image.size
        max_dim = max(width, height)
        if max_dim < 512:
            im_size = (378, 378)
        else:
            aspect_ratio = width / height
            im_size = min(
                self.supported_sizes,
                key=lambda size: (
                    abs((size[1] / size[0]) - aspect_ratio),
                    abs(size[0] - width) + abs(size[1] - height),
                ),
            )

        return Compose(
            [
                Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
                ToImage(),
                ToDtype(torch.float32, scale=True),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )(image)

    def forward(
        self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
    ) -> torch.Tensor:
        im_list = None
        if isinstance(images, torch.Tensor):
            # Input must have dimensions (B, C, H, W)
            assert (
                len(images.shape) == 4
            ), "Tensor input must have dimensions (B, C, H, W)"
            im_list = list(images)
        elif isinstance(images, PIL.Image.Image):
            im_list = [images]
        elif isinstance(images, list):
            im_list = images
        else:
            raise ValueError(
                "Input must be a PIL image, list of PIL images, or a tensor"
            )

        # Preprocess unless the images are already tensors (indicating that
        # they have already been preprocessed)
        if not isinstance(im_list[0], torch.Tensor):
            im_list = [self.preprocess(im.convert("RGB")) for im in im_list]

        patches = [create_patches(im) for im in im_list]
        flat_patches = [patch for image_patches in patches for patch in image_patches]

        # Images may be variable size, and need to be resized to a common size after
        # creating patches.
        resized_images = [
            F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
            for im in im_list
        ]

        combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
        combined_images = combined_images.to(self.device, dtype=self.dtype)

        combined_features = self.encoder(combined_images)

        full_img_features = combined_features[: len(im_list)]
        patch_features = (
            combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
        )

        # Reshape patch features back to their original structure
        reshaped_patch_features = []
        patch_idx = 0
        for i, patch_set in enumerate(patches):
            if len(patch_set) == 0:
                reshaped_patch_features.append(
                    full_img_features[i].transpose(0, 1).view(1152, 27, 27)
                )
            else:
                sample_features = []
                for row_patches in patch_set:
                    row_len = len(row_patches)
                    row_features = patch_features[
                        patch_idx : patch_idx + row_len
                    ]  # row_len, T, C
                    row_features = torch.cat(
                        list(row_features), dim=2
                    )  # T, C * row_len
                    patch_idx += row_len
                    sample_features.append(row_features)
                sample_features = torch.cat(sample_features, dim=1)
                sample_features = F.interpolate(
                    sample_features.unsqueeze(0), size=(27, 27), mode="bilinear"
                ).squeeze(0)
                reshaped_patch_features.append(sample_features)
        reshaped_patch_features = (
            torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
        )

        final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)

        return self.projection(final_features)
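To make the tiling scheme concrete, a small sketch of what preprocess and create_patches produce for a wide input (shapes only; the ViT forward pass is not run). It assumes VisionEncoder and create_patches from the file above are importable.

import torch
from PIL import Image

enc = VisionEncoder()
img = Image.new("RGB", (1600, 900))        # hypothetical wide photo, PIL size is (W, H)

# preprocess() picks the closest supported (H, W) size, here (378, 756).
tensor = enc.preprocess(img)
print(tensor.shape)                        # torch.Size([3, 378, 756])

# A 378x756 image becomes one row of two 378x378 tiles; an exactly 378x378
# image returns [] and only the resized full image is encoded.
patches = create_patches(tensor)
print(len(patches), patches[0].shape)      # 1 torch.Size([2, 3, 378, 378])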
 
hub/models--vikhyatk--moondream2/blobs/a4878a2253d32f2dcd950cde16ebedffb9644ae6 DELETED
@@ -1,1463 +0,0 @@
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Phi model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    get_torch_version,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from .configuration_moondream import PhiConfig


if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "PhiConfig"


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

    return causal_mask


# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
class PhiRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
                / self.dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=torch.int64
        ).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
    """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=torch.int64
        ).type_as(self.inv_freq)
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
    """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
                    / self.dim
                )
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=torch.int64
        ).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
class PhiMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class PhiAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.Wqkv = nn.Linear(
            self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
        )
        self.out_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=True
        )

        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = PhiRotaryEmbedding(
                int(self.partial_rotary_factor * self.head_dim),
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = PhiLinearScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
            3, dim=-1
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(
            query_rot, key_rot, cos, sin, position_ids
        )

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb.dim,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
        attn_weights = torch.matmul(
            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights += causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(value_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class PhiFlashAttention2(PhiAttention):
    """
    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # PhiFlashAttention2 attention does not support output_attentions

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
            3, dim=-1
        )

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
        query_rot, key_rot = apply_rotary_pos_emb(
            query_rot, key_rot, cos, sin, position_ids
        )

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_emb.dim,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_dropout = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.

        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=attn_dropout,
            softmax_scale=None,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class PhiSdpaAttention(PhiAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.require_contiguous_qkv = version.parse(
            get_torch_version()
        ) < version.parse("2.2.0")

    """
    SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from PhiAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
                'be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
            3, dim=-1
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
-
694
- # Partial rotary embedding
695
- query_rot, query_pass = (
696
- query_states[..., : self.rotary_emb.dim],
697
- query_states[..., self.rotary_emb.dim :],
698
- )
699
- key_rot, key_pass = (
700
- key_states[..., : self.rotary_emb.dim],
701
- key_states[..., self.rotary_emb.dim :],
702
- )
703
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
- query_rot, key_rot = apply_rotary_pos_emb(
705
- query_rot, key_rot, cos, sin, position_ids
706
- )
707
-
708
- # [batch_size, seq_length, num_heads, head_dim]
709
- query_states = torch.cat((query_rot, query_pass), dim=-1)
710
- key_states = torch.cat((key_rot, key_pass), dim=-1)
711
-
712
- if past_key_value is not None:
713
- cache_kwargs = {
714
- "sin": sin,
715
- "cos": cos,
716
- "partial_rotation_size": self.rotary_emb.dim,
717
- "cache_position": cache_position,
718
- }
719
- key_states, value_states = past_key_value.update(
720
- key_states, value_states, self.layer_idx, cache_kwargs
721
- )
722
-
723
- key_states = repeat_kv(key_states, self.num_key_value_groups)
724
- value_states = repeat_kv(value_states, self.num_key_value_groups)
725
-
726
- causal_mask = attention_mask
727
- if attention_mask is not None:
728
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
-
730
- # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
- # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
- # Reference: https://github.com/pytorch/pytorch/issues/112577
733
- if (
734
- self.require_contiguous_qkv
735
- and query_states.device.type == "cuda"
736
- and attention_mask is not None
737
- ):
738
- query_states = query_states.contiguous()
739
- key_states = key_states.contiguous()
740
- value_states = value_states.contiguous()
741
-
742
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
- is_causal = True if causal_mask is None and q_len > 1 else False
745
-
746
- attn_output = torch.nn.functional.scaled_dot_product_attention(
747
- query_states,
748
- key_states,
749
- value_states,
750
- attn_mask=causal_mask,
751
- dropout_p=self.attention_dropout if self.training else 0.0,
752
- is_causal=is_causal,
753
- )
754
-
755
- attn_output = attn_output.transpose(1, 2).contiguous()
756
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
-
758
- attn_output = self.out_proj(attn_output)
759
-
760
- return attn_output, None, past_key_value
761
-
762
-
763
- PHI_ATTENTION_CLASSES = {
764
- "eager": PhiAttention,
765
- "flash_attention_2": PhiFlashAttention2,
766
- "sdpa": PhiSdpaAttention,
767
- }
768
-
769
-
770
- class PhiDecoderLayer(nn.Module):
771
- def __init__(self, config: PhiConfig, layer_idx: int):
772
- super().__init__()
773
- self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
774
- config, layer_idx=layer_idx
775
- )
776
- self.mlp = PhiMLP(config)
777
- self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
778
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
779
-
780
- def forward(
781
- self,
782
- hidden_states: torch.Tensor,
783
- attention_mask: Optional[torch.Tensor] = None,
784
- position_ids: Optional[torch.LongTensor] = None,
785
- output_attentions: Optional[bool] = False,
786
- use_cache: Optional[bool] = False,
787
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
- cache_position: Optional[torch.LongTensor] = None,
789
- **kwargs,
790
- ) -> Tuple[
791
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
- ]:
793
- """
794
- Args:
795
- hidden_states (`torch.FloatTensor`):
796
- input to the layer of shape `(batch, seq_len, embed_dim)`
797
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
798
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
799
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
800
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
801
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
802
- output_attentions (`bool`, *optional*):
803
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
804
- returned tensors for more detail.
805
- use_cache (`bool`, *optional*):
806
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
- (see `past_key_values`).
808
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
- Indices depicting the position of the input sequence tokens in the sequence
811
- kwargs (`dict`, *optional*):
812
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
813
- into the model
814
- """
815
-
816
- residual = hidden_states
817
-
818
- hidden_states = self.ln(hidden_states)
819
-
820
- # Self Attention
821
- attn_outputs, self_attn_weights, present_key_value = self.mixer(
822
- hidden_states=hidden_states,
823
- attention_mask=attention_mask,
824
- position_ids=position_ids,
825
- past_key_value=past_key_value,
826
- output_attentions=output_attentions,
827
- use_cache=use_cache,
828
- cache_position=cache_position,
829
- )
830
- attn_outputs = self.resid_dropout(attn_outputs)
831
-
832
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
- outputs = (hidden_states,)
835
-
836
- if output_attentions:
837
- outputs += (self_attn_weights,)
838
-
839
- if use_cache:
840
- outputs += (present_key_value,)
841
-
842
- return outputs
843
-
844
-
845
- PHI_START_DOCSTRING = r"""
846
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
848
- etc.)
849
-
850
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
852
- and behavior.
853
-
854
- Parameters:
855
- config ([`PhiConfig`]):
856
- Model configuration class with all the parameters of the model. Initializing with a config file does not
857
- load the weights associated with the model, only the configuration. Check out the
858
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
- """
860
-
861
-
862
- @add_start_docstrings(
863
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
- PHI_START_DOCSTRING,
865
- )
866
- class PhiPreTrainedModel(PreTrainedModel):
867
- config_class = PhiConfig
868
- base_model_prefix = "model"
869
- supports_gradient_checkpointing = True
870
- _no_split_modules = ["PhiDecoderLayer"]
871
- _skip_keys_device_placement = "past_key_values"
872
- _supports_flash_attn_2 = True
873
- _supports_sdpa = True
874
- _supports_cache_class = True
875
-
876
- def _init_weights(self, module):
877
- std = self.config.initializer_range
878
- if isinstance(module, nn.Linear):
879
- module.weight.data.normal_(mean=0.0, std=std)
880
- if module.bias is not None:
881
- module.bias.data.zero_()
882
- elif isinstance(module, nn.Embedding):
883
- module.weight.data.normal_(mean=0.0, std=std)
884
- if module.padding_idx is not None:
885
- module.weight.data[module.padding_idx].zero_()
886
-
887
-
888
- class Embedding(nn.Module):
889
- def __init__(self, config: PhiConfig):
890
- super().__init__()
891
- self.wte = nn.Embedding(
892
- config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
893
- )
894
-
895
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
- return self.wte(input_ids)
897
-
898
- PHI_INPUTS_DOCSTRING = r"""
899
- Args:
900
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
- it.
903
-
904
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
- [`PreTrainedTokenizer.__call__`] for details.
906
-
907
- [What are input IDs?](../glossary#input-ids)
908
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
-
911
- - 1 for tokens that are **not masked**,
912
- - 0 for tokens that are **masked**.
913
-
914
- [What are attention masks?](../glossary#attention-mask)
915
-
916
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
- [`PreTrainedTokenizer.__call__`] for details.
918
-
919
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
- `past_key_values`).
921
-
922
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
- information on the default strategy.
925
-
926
- - 1 indicates the head is **not masked**,
927
- - 0 indicates the head is **masked**.
928
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
- config.n_positions - 1]`.
931
-
932
- [What are position IDs?](../glossary#position-ids)
933
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
936
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
-
938
- Two formats are allowed:
939
- - a [`~cache_utils.Cache`] instance;
940
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
- cache format.
943
-
944
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
- legacy cache format will be returned.
946
-
947
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
- of shape `(batch_size, sequence_length)`.
950
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
- model's internal embedding lookup matrix.
954
- use_cache (`bool`, *optional*):
955
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
- `past_key_values`).
957
- output_attentions (`bool`, *optional*):
958
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
- tensors for more detail.
960
- output_hidden_states (`bool`, *optional*):
961
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
- more detail.
963
- return_dict (`bool`, *optional*):
964
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
- the complete sequence length.
969
- """
970
-
971
-
972
- @add_start_docstrings(
973
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
- PHI_START_DOCSTRING,
975
- )
976
- class PhiModel(PhiPreTrainedModel):
977
- """
978
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
-
980
- Args:
981
- config: PhiConfig
982
- """
983
-
984
- def __init__(self, config: PhiConfig):
985
- super().__init__(config)
986
- self.padding_idx = config.pad_token_id
987
- self.vocab_size = config.vocab_size
988
-
989
- self.embd = Embedding(config)
990
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
- self.h = nn.ModuleList(
992
- [
993
- PhiDecoderLayer(config, layer_idx)
994
- for layer_idx in range(config.num_hidden_layers)
995
- ]
996
- )
997
-
998
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
- self._use_sdpa = config._attn_implementation == "sdpa"
1000
-
1001
- self.gradient_checkpointing = False
1002
- # Initialize weights and apply final processing
1003
- self.post_init()
1004
-
1005
- def get_input_embeddings(self):
1006
- return self.embd.wte
1007
-
1008
- def set_input_embeddings(self, value):
1009
- self.embd.wte = value
1010
-
1011
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
- def forward(
1013
- self,
1014
- input_ids: torch.LongTensor = None,
1015
- attention_mask: Optional[torch.Tensor] = None,
1016
- position_ids: Optional[torch.LongTensor] = None,
1017
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
- inputs_embeds: Optional[torch.FloatTensor] = None,
1019
- use_cache: Optional[bool] = None,
1020
- output_attentions: Optional[bool] = None,
1021
- output_hidden_states: Optional[bool] = None,
1022
- return_dict: Optional[bool] = None,
1023
- cache_position: Optional[torch.LongTensor] = None,
1024
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
- output_attentions = (
1026
- output_attentions
1027
- if output_attentions is not None
1028
- else self.config.output_attentions
1029
- )
1030
- output_hidden_states = (
1031
- output_hidden_states
1032
- if output_hidden_states is not None
1033
- else self.config.output_hidden_states
1034
- )
1035
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
-
1037
- return_dict = (
1038
- return_dict if return_dict is not None else self.config.use_return_dict
1039
- )
1040
-
1041
- if (input_ids is None) ^ (inputs_embeds is not None):
1042
- raise ValueError(
1043
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
- )
1045
-
1046
- if self.gradient_checkpointing and self.training:
1047
- if use_cache:
1048
- logger.warning_once(
1049
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
- )
1051
- use_cache = False
1052
-
1053
- use_legacy_cache = False
1054
- if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
- use_legacy_cache = True
1056
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
- logger.warning_once(
1058
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
- )
1061
-
1062
- if inputs_embeds is None:
1063
- inputs_embeds = self.embd(input_ids)
1064
-
1065
- if cache_position is None:
1066
- past_seen_tokens = (
1067
- past_key_values.get_seq_length() if past_key_values is not None else 0
1068
- )
1069
- cache_position = torch.arange(
1070
- past_seen_tokens,
1071
- past_seen_tokens + inputs_embeds.shape[1],
1072
- device=inputs_embeds.device,
1073
- )
1074
- if position_ids is None:
1075
- position_ids = cache_position.unsqueeze(0)
1076
-
1077
- causal_mask = self._update_causal_mask(
1078
- attention_mask,
1079
- inputs_embeds,
1080
- cache_position,
1081
- past_key_values,
1082
- output_attentions,
1083
- )
1084
-
1085
- hidden_states = inputs_embeds
1086
-
1087
- # decoder layers
1088
- all_hidden_states = () if output_hidden_states else None
1089
- all_self_attns = () if output_attentions else None
1090
- next_decoder_cache = None
1091
-
1092
- for decoder_layer in self.h:
1093
- if output_hidden_states:
1094
- all_hidden_states += (hidden_states,)
1095
-
1096
- if self.gradient_checkpointing and self.training:
1097
- layer_outputs = self._gradient_checkpointing_func(
1098
- decoder_layer.__call__,
1099
- hidden_states,
1100
- causal_mask,
1101
- position_ids,
1102
- output_attentions,
1103
- use_cache,
1104
- past_key_values,
1105
- cache_position,
1106
- )
1107
- else:
1108
- layer_outputs = decoder_layer(
1109
- hidden_states,
1110
- attention_mask=causal_mask,
1111
- position_ids=position_ids,
1112
- past_key_value=past_key_values,
1113
- output_attentions=output_attentions,
1114
- use_cache=use_cache,
1115
- cache_position=cache_position,
1116
- )
1117
-
1118
- hidden_states = layer_outputs[0]
1119
-
1120
- if use_cache:
1121
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
-
1123
- if output_attentions:
1124
- all_self_attns += (layer_outputs[1],)
1125
-
1126
- # add hidden states from the last decoder layer
1127
- if output_hidden_states:
1128
- all_hidden_states += (hidden_states,)
1129
-
1130
- next_cache = None
1131
- if use_cache:
1132
- next_cache = (
1133
- next_decoder_cache.to_legacy_cache()
1134
- if use_legacy_cache
1135
- else next_decoder_cache
1136
- )
1137
- if not return_dict:
1138
- return tuple(
1139
- v
1140
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
- if v is not None
1142
- )
1143
- return BaseModelOutputWithPast(
1144
- last_hidden_state=hidden_states,
1145
- past_key_values=next_cache,
1146
- hidden_states=all_hidden_states,
1147
- attentions=all_self_attns,
1148
- )
1149
-
1150
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
- def _update_causal_mask(
1152
- self,
1153
- attention_mask: torch.Tensor,
1154
- input_tensor: torch.Tensor,
1155
- cache_position: torch.Tensor,
1156
- past_key_values: Cache,
1157
- output_attentions: bool,
1158
- ):
1159
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
-
1164
- if self.config._attn_implementation == "flash_attention_2":
1165
- if attention_mask is not None and 0.0 in attention_mask:
1166
- return attention_mask
1167
- return None
1168
-
1169
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
- # to infer the attention mask.
1172
- past_seen_tokens = (
1173
- past_key_values.get_seq_length() if past_key_values is not None else 0
1174
- )
1175
- using_static_cache = isinstance(past_key_values, StaticCache)
1176
-
1177
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
- if (
1179
- self.config._attn_implementation == "sdpa"
1180
- and not using_static_cache
1181
- and not output_attentions
1182
- ):
1183
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
- attention_mask,
1185
- inputs_embeds=input_tensor,
1186
- past_key_values_length=past_seen_tokens,
1187
- is_training=self.training,
1188
- ):
1189
- return None
1190
-
1191
- dtype, device = input_tensor.dtype, input_tensor.device
1192
- min_dtype = torch.finfo(dtype).min
1193
- sequence_length = input_tensor.shape[1]
1194
- if using_static_cache:
1195
- target_length = past_key_values.get_max_length()
1196
- else:
1197
- target_length = (
1198
- attention_mask.shape[-1]
1199
- if isinstance(attention_mask, torch.Tensor)
1200
- else past_seen_tokens + sequence_length + 1
1201
- )
1202
-
1203
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
- causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
- attention_mask,
1206
- sequence_length=sequence_length,
1207
- target_length=target_length,
1208
- dtype=dtype,
1209
- device=device,
1210
- min_dtype=min_dtype,
1211
- cache_position=cache_position,
1212
- batch_size=input_tensor.shape[0],
1213
- )
1214
-
1215
- if (
1216
- self.config._attn_implementation == "sdpa"
1217
- and attention_mask is not None
1218
- and attention_mask.device.type == "cuda"
1219
- and not output_attentions
1220
- ):
1221
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
- # Details: https://github.com/pytorch/pytorch/issues/110213
1224
- causal_mask = AttentionMaskConverter._unmask_unattended(
1225
- causal_mask, min_dtype
1226
- )
1227
-
1228
- return causal_mask
1229
-
1230
-
1231
- class CausalLMHead(nn.Module):
1232
- """Causal Language Modeling head. Simplified version."""
1233
-
1234
- def __init__(self, config):
1235
- super().__init__()
1236
- self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
- self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
-
1239
- def forward(self, hidden_states):
1240
- return self.linear(self.ln(hidden_states))
1241
-
1242
-
1243
- class PhiForCausalLM(PhiPreTrainedModel):
1244
-
1245
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
- def __init__(self, config):
1247
- super().__init__(config)
1248
- self.transformer = PhiModel(config)
1249
- self.vocab_size = config.vocab_size
1250
- self.lm_head = CausalLMHead(config)
1251
-
1252
- # Initialize weights and apply final processing
1253
- self.post_init()
1254
-
1255
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
- def get_input_embeddings(self):
1257
- return self.transformer.embd.wte
1258
-
1259
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
- def set_input_embeddings(self, value):
1261
- self.transformer.embd.wte = value
1262
-
1263
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
- def get_output_embeddings(self):
1265
- return self.lm_head.linear
1266
-
1267
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
- def set_output_embeddings(self, new_embeddings):
1269
- self.lm_head.linear = new_embeddings
1270
-
1271
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
- def set_decoder(self, decoder):
1273
- self.model = decoder
1274
-
1275
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
- def get_decoder(self):
1277
- return self.model
1278
-
1279
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
- @replace_return_docstrings(
1281
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
- )
1283
- def forward(
1284
- self,
1285
- input_ids: torch.LongTensor = None,
1286
- attention_mask: Optional[torch.Tensor] = None,
1287
- position_ids: Optional[torch.LongTensor] = None,
1288
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
- inputs_embeds: Optional[torch.FloatTensor] = None,
1290
- labels: Optional[torch.LongTensor] = None,
1291
- use_cache: Optional[bool] = None,
1292
- output_attentions: Optional[bool] = None,
1293
- output_hidden_states: Optional[bool] = None,
1294
- return_dict: Optional[bool] = None,
1295
- cache_position: Optional[torch.LongTensor] = None,
1296
- num_logits_to_keep: int = 0,
1297
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
- r"""
1299
- Args:
1300
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
-
1305
- num_logits_to_keep (`int`, *optional*):
1306
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
-
1310
- Returns:
1311
-
1312
- Example:
1313
-
1314
- ```python
1315
- >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
-
1317
- >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
- >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
-
1320
- >>> prompt = "This is an example script ."
1321
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
-
1323
- >>> # Generate
1324
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
- 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
- ```"""
1328
-
1329
- output_attentions = (
1330
- output_attentions
1331
- if output_attentions is not None
1332
- else self.config.output_attentions
1333
- )
1334
- output_hidden_states = (
1335
- output_hidden_states
1336
- if output_hidden_states is not None
1337
- else self.config.output_hidden_states
1338
- )
1339
- return_dict = (
1340
- return_dict if return_dict is not None else self.config.use_return_dict
1341
- )
1342
-
1343
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1344
- outputs = self.transformer(
1345
- input_ids=input_ids,
1346
- attention_mask=attention_mask,
1347
- position_ids=position_ids,
1348
- past_key_values=past_key_values,
1349
- inputs_embeds=inputs_embeds,
1350
- use_cache=use_cache,
1351
- output_attentions=output_attentions,
1352
- output_hidden_states=output_hidden_states,
1353
- return_dict=return_dict,
1354
- cache_position=cache_position,
1355
- )
1356
-
1357
- hidden_states = outputs[0]
1358
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
-
1360
- loss = None
1361
- if labels is not None:
1362
- # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
- logits = logits.float()
1364
- # Shift so that tokens < n predict n
1365
- shift_logits = logits[..., :-1, :].contiguous()
1366
- shift_labels = labels[..., 1:].contiguous()
1367
- # Flatten the tokens
1368
- loss_fct = CrossEntropyLoss()
1369
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
- shift_labels = shift_labels.view(-1)
1371
- # Enable model parallelism
1372
- shift_labels = shift_labels.to(shift_logits.device)
1373
- loss = loss_fct(shift_logits, shift_labels)
1374
-
1375
- if not return_dict:
1376
- output = (logits,) + outputs[1:]
1377
- return (loss,) + output if loss is not None else output
1378
-
1379
- return CausalLMOutputWithPast(
1380
- loss=loss,
1381
- logits=logits,
1382
- past_key_values=outputs.past_key_values,
1383
- hidden_states=outputs.hidden_states,
1384
- attentions=outputs.attentions,
1385
- )
1386
-
1387
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
- def prepare_inputs_for_generation(
1389
- self,
1390
- input_ids,
1391
- past_key_values=None,
1392
- attention_mask=None,
1393
- inputs_embeds=None,
1394
- cache_position=None,
1395
- position_ids=None,
1396
- use_cache=True,
1397
- num_logits_to_keep=0,
1398
- **kwargs,
1399
- ):
1400
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
- # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
- if past_key_values is not None:
1404
- if inputs_embeds is not None: # Exception 1
1405
- input_ids = input_ids[:, -cache_position.shape[0] :]
1406
- elif (
1407
- input_ids.shape[1] != cache_position.shape[0]
1408
- ): # Default case (the "else", a no op, is Exception 2)
1409
- input_ids = input_ids[:, cache_position]
1410
-
1411
- if attention_mask is not None and position_ids is None:
1412
- # create position_ids on the fly for batch generation
1413
- position_ids = attention_mask.long().cumsum(-1) - 1
1414
- position_ids.masked_fill_(attention_mask == 0, 1)
1415
- if past_key_values:
1416
- position_ids = position_ids[:, -input_ids.shape[1] :]
1417
-
1418
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
1419
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
-
1421
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
- if inputs_embeds is not None and cache_position[0] == 0:
1423
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
- else:
1425
- # The clone here is for the same reason as for `position_ids`.
1426
- model_inputs = {
1427
- "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
- "inputs_embeds": None,
1429
- }
1430
-
1431
- if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
- if model_inputs["inputs_embeds"] is not None:
1433
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
- device = model_inputs["inputs_embeds"].device
1435
- else:
1436
- batch_size, sequence_length = model_inputs["input_ids"].shape
1437
- device = model_inputs["input_ids"].device
1438
-
1439
- dtype = self.lm_head.weight.dtype
1440
- min_dtype = torch.finfo(dtype).min
1441
-
1442
- attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
- attention_mask,
1444
- sequence_length=sequence_length,
1445
- target_length=past_key_values.get_max_length(),
1446
- dtype=dtype,
1447
- device=device,
1448
- min_dtype=min_dtype,
1449
- cache_position=cache_position,
1450
- batch_size=batch_size,
1451
- )
1452
-
1453
- model_inputs.update(
1454
- {
1455
- "position_ids": position_ids,
1456
- "cache_position": cache_position,
1457
- "past_key_values": past_key_values,
1458
- "use_cache": use_cache,
1459
- "attention_mask": attention_mask,
1460
- "num_logits_to_keep": num_logits_to_keep,
1461
- }
1462
- )
1463
- return model_inputs
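The deleted modeling code above applies a partial rotary embedding: only the first `head_dim * partial_rotary_factor` channels of each query/key head are rotated, and the remaining channels pass through unchanged (with a factor of 0.5, half of each head is rotated). Below is a minimal, self-contained sketch of that step, assuming standard HF-style cos/sin position tables; the helper names (`rotate_half`, `apply_partial_rope`, `rotary_dim`) are illustrative and not taken from the deleted file.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dimension in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rope(q, k, cos, sin, rotary_dim):
    # q, k: [batch, num_heads, seq_len, head_dim]; cos, sin: [seq_len, rotary_dim].
    # Only the first `rotary_dim` channels are rotated; the rest pass through untouched.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)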
 
hub/models--vikhyatk--moondream2/blobs/ae1ab764382e24c65d906c16fba36650b634426a DELETED
@@ -1,15 +0,0 @@
1
- {
2
- "architectures": [
3
- "Moondream"
4
- ],
5
- "auto_map": {
6
- "AutoConfig": "configuration_moondream.MoondreamConfig",
7
- "AutoModelForCausalLM": "moondream.Moondream"
8
- },
9
- "model_type": "moondream1",
10
- "text_config": {
11
- "model_type": "phi"
12
- },
13
- "torch_dtype": "float16",
14
- "transformers_version": "4.44.0"
15
- }
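For context, the `auto_map` block in this deleted config is what routes the Transformers Auto classes to the custom code shipped in the repository. A minimal loading sketch, assuming the upstream `vikhyatk/moondream2` repo and a recent transformers release; the dtype choice is illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "vikhyatk/moondream2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True lets auto_map resolve AutoModelForCausalLM to moondream.Moondream
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # matches the torch_dtype recorded in the deleted config
)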
 
hub/models--vikhyatk--moondream2/blobs/b93162eb8252d2d937a69f17971c76b8be87aedd DELETED
@@ -1,18 +0,0 @@
1
- # Adopted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
2
-
3
- import torch
4
- import torch.nn as nn
5
- import math
6
-
7
-
8
- class FourierFeatures(nn.Module):
9
- def __init__(self, in_features, out_features, std=1.0):
10
- super().__init__()
11
- assert out_features % 2 == 0
12
- self.register_buffer(
13
- "weight", torch.randn([out_features // 2, in_features]) * std
14
- )
15
-
16
- def forward(self, input):
17
- f = 2 * math.pi * input @ self.weight.T
18
- return torch.cat([f.cos(), f.sin()], dim=-1)
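A short usage sketch for the FourierFeatures module deleted above, assuming the class definition shown in the diff; the sizes and inputs are illustrative:

import torch

ff = FourierFeatures(in_features=1, out_features=256, std=1.0)
t = torch.rand(8, 1)   # e.g. a batch of 8 scalar inputs
emb = ff(t)            # 128 cosine features concatenated with 128 sine features
print(emb.shape)       # torch.Size([8, 256])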
 
hub/models--vikhyatk--moondream2/blobs/c1148447551675ea739c440ee3e247df9f354d8f DELETED
The diff for this file is too large to render. See raw diff