Sasmiit commited on
Commit
4d28950
1 Parent(s): 34db749

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "vikhyatk/moondream2",
3
+ "architectures": [
4
+ "Moondream"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "vikhyatk/moondream2--configuration_moondream.MoondreamConfig",
8
+ "AutoModelForCausalLM": "vikhyatk/moondream2--moondream.Moondream"
9
+ },
10
+ "model_type": "moondream1",
11
+ "text_config": {
12
+ "model_type": "phi"
13
+ },
14
+ "torch_dtype": "float16",
15
+ "transformers_version": "4.44.2"
16
+ }
configuration_moondream.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class PhiConfig(PretrainedConfig):
5
+ model_type = "phi"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=51200,
11
+ hidden_size=2048,
12
+ intermediate_size=8192,
13
+ num_hidden_layers=24,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ resid_pdrop=0.0,
17
+ embd_pdrop=0.0,
18
+ attention_dropout=0.0,
19
+ hidden_act="gelu_new",
20
+ max_position_embeddings=2048,
21
+ initializer_range=0.02,
22
+ layer_norm_eps=1e-5,
23
+ use_cache=True,
24
+ tie_word_embeddings=False,
25
+ rope_theta=10000.0,
26
+ rope_scaling=None,
27
+ partial_rotary_factor=0.5,
28
+ bos_token_id=1,
29
+ eos_token_id=2,
30
+ **kwargs,
31
+ ):
32
+ self.vocab_size = vocab_size
33
+ self.hidden_size = hidden_size
34
+ self.intermediate_size = intermediate_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+
38
+ if num_key_value_heads is None:
39
+ num_key_value_heads = num_attention_heads
40
+
41
+ self.num_key_value_heads = num_key_value_heads
42
+ self.resid_pdrop = resid_pdrop
43
+ self.embd_pdrop = embd_pdrop
44
+ self.attention_dropout = attention_dropout
45
+ self.hidden_act = hidden_act
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.initializer_range = initializer_range
48
+ self.layer_norm_eps = layer_norm_eps
49
+ self.use_cache = use_cache
50
+ self.rope_theta = rope_theta
51
+ self.rope_scaling = rope_scaling
52
+ self.partial_rotary_factor = partial_rotary_factor
53
+ self._rope_scaling_validation()
54
+
55
+ super().__init__(
56
+ bos_token_id=bos_token_id,
57
+ eos_token_id=eos_token_id,
58
+ tie_word_embeddings=tie_word_embeddings,
59
+ **kwargs,
60
+ )
61
+
62
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
63
+ def _rope_scaling_validation(self):
64
+ """
65
+ Validate the `rope_scaling` configuration.
66
+ """
67
+ if self.rope_scaling is None:
68
+ return
69
+
70
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
71
+ raise ValueError(
72
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
73
+ f"got {self.rope_scaling}"
74
+ )
75
+ rope_scaling_type = self.rope_scaling.get("type", None)
76
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
77
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
78
+ raise ValueError(
79
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
80
+ )
81
+ if (
82
+ rope_scaling_factor is None
83
+ or not isinstance(rope_scaling_factor, float)
84
+ or rope_scaling_factor <= 1.0
85
+ ):
86
+ raise ValueError(
87
+ f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
88
+ )
89
+
90
+
91
+ class MoondreamConfig(PretrainedConfig):
92
+ model_type = "moondream1"
93
+
94
+ def __init__(self, **kwargs):
95
+ self.text_config = PhiConfig(**kwargs.pop("text_config", {}))
96
+ super().__init__(**kwargs)
fourier_features.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ class FourierFeatures(nn.Module):
9
+ def __init__(self, in_features, out_features, std=1.0):
10
+ super().__init__()
11
+ assert out_features % 2 == 0
12
+ self.register_buffer(
13
+ "weight", torch.randn([out_features // 2, in_features]) * std
14
+ )
15
+
16
+ def forward(self, input):
17
+ f = 2 * math.pi * input @ self.weight.T
18
+ return torch.cat([f.cos(), f.sin()], dim=-1)
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.44.2"
6
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9fe382ec0e9fdaeb26f3b498f01e40047443c9ec77a5373a363ecc41d2c6924
3
+ size 3736040266
modeling_phi.py ADDED
@@ -0,0 +1,1463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import (
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ get_torch_version,
39
+ is_flash_attn_2_available,
40
+ is_flash_attn_greater_or_equal_2_10,
41
+ is_torchdynamo_compiling,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from .configuration_moondream import PhiConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "PhiConfig"
55
+
56
+
57
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
58
+ def _prepare_4d_causal_attention_mask_with_cache_position(
59
+ attention_mask: torch.Tensor,
60
+ sequence_length: int,
61
+ target_length: int,
62
+ dtype: torch.dtype,
63
+ device: torch.device,
64
+ min_dtype: float,
65
+ cache_position: torch.Tensor,
66
+ batch_size: int,
67
+ ):
68
+ """
69
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
70
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
71
+
72
+ Args:
73
+ attention_mask (`torch.Tensor`):
74
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
75
+ sequence_length (`int`):
76
+ The sequence length being processed.
77
+ target_length (`int`):
78
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
79
+ dtype (`torch.dtype`):
80
+ The dtype to use for the 4D attention mask.
81
+ device (`torch.device`):
82
+ The device to plcae the 4D attention mask on.
83
+ min_dtype (`float`):
84
+ The minimum value representable with the dtype `dtype`.
85
+ cache_position (`torch.Tensor`):
86
+ Indices depicting the position of the input sequence tokens in the sequence.
87
+ batch_size (`torch.Tensor`):
88
+ Batch size.
89
+ """
90
+ if attention_mask is not None and attention_mask.dim() == 4:
91
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
92
+ causal_mask = attention_mask
93
+ else:
94
+ causal_mask = torch.full(
95
+ (sequence_length, target_length),
96
+ fill_value=min_dtype,
97
+ dtype=dtype,
98
+ device=device,
99
+ )
100
+ if sequence_length != 1:
101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
102
+ causal_mask *= torch.arange(
103
+ target_length, device=device
104
+ ) > cache_position.reshape(-1, 1)
105
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
106
+ if attention_mask is not None:
107
+ causal_mask = (
108
+ causal_mask.clone()
109
+ ) # copy to contiguous memory for in-place edit
110
+ mask_length = attention_mask.shape[-1]
111
+ padding_mask = (
112
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
113
+ )
114
+ padding_mask = padding_mask == 0
115
+ causal_mask[:, :, :, :mask_length] = causal_mask[
116
+ :, :, :, :mask_length
117
+ ].masked_fill(padding_mask, min_dtype)
118
+
119
+ return causal_mask
120
+
121
+
122
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
123
+ class PhiRotaryEmbedding(nn.Module):
124
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
+ super().__init__()
126
+
127
+ self.dim = dim
128
+ self.max_position_embeddings = max_position_embeddings
129
+ self.base = base
130
+ inv_freq = 1.0 / (
131
+ self.base
132
+ ** (
133
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
134
+ / self.dim
135
+ )
136
+ )
137
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
138
+
139
+ # Build here to make `torch.jit.trace` work.
140
+ self._set_cos_sin_cache(
141
+ seq_len=max_position_embeddings,
142
+ device=self.inv_freq.device,
143
+ dtype=torch.get_default_dtype(),
144
+ )
145
+
146
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
147
+ self.max_seq_len_cached = seq_len
148
+ t = torch.arange(
149
+ self.max_seq_len_cached, device=device, dtype=torch.int64
150
+ ).type_as(self.inv_freq)
151
+
152
+ freqs = torch.outer(t, self.inv_freq)
153
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
154
+ emb = torch.cat((freqs, freqs), dim=-1)
155
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
157
+
158
+ def forward(self, x, seq_len=None):
159
+ # x: [bs, num_attention_heads, seq_len, head_size]
160
+ if seq_len > self.max_seq_len_cached:
161
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
162
+
163
+ return (
164
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
165
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
166
+ )
167
+
168
+
169
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
170
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
171
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
+
173
+ def __init__(
174
+ self,
175
+ dim,
176
+ max_position_embeddings=2048,
177
+ base=10000,
178
+ device=None,
179
+ scaling_factor=1.0,
180
+ ):
181
+ self.scaling_factor = scaling_factor
182
+ super().__init__(dim, max_position_embeddings, base, device)
183
+
184
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
185
+ self.max_seq_len_cached = seq_len
186
+ t = torch.arange(
187
+ self.max_seq_len_cached, device=device, dtype=torch.int64
188
+ ).type_as(self.inv_freq)
189
+ t = t / self.scaling_factor
190
+
191
+ freqs = torch.outer(t, self.inv_freq)
192
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
+ emb = torch.cat((freqs, freqs), dim=-1)
194
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
195
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
+
197
+
198
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
199
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
200
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
201
+
202
+ def __init__(
203
+ self,
204
+ dim,
205
+ max_position_embeddings=2048,
206
+ base=10000,
207
+ device=None,
208
+ scaling_factor=1.0,
209
+ ):
210
+ self.scaling_factor = scaling_factor
211
+ super().__init__(dim, max_position_embeddings, base, device)
212
+
213
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
214
+ self.max_seq_len_cached = seq_len
215
+
216
+ if seq_len > self.max_position_embeddings:
217
+ base = self.base * (
218
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
219
+ - (self.scaling_factor - 1)
220
+ ) ** (self.dim / (self.dim - 2))
221
+ inv_freq = 1.0 / (
222
+ base
223
+ ** (
224
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
225
+ / self.dim
226
+ )
227
+ )
228
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
229
+
230
+ t = torch.arange(
231
+ self.max_seq_len_cached, device=device, dtype=torch.int64
232
+ ).type_as(self.inv_freq)
233
+
234
+ freqs = torch.outer(t, self.inv_freq)
235
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
236
+ emb = torch.cat((freqs, freqs), dim=-1)
237
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
238
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
239
+
240
+
241
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
242
+ def rotate_half(x):
243
+ """Rotates half the hidden dims of the input."""
244
+ x1 = x[..., : x.shape[-1] // 2]
245
+ x2 = x[..., x.shape[-1] // 2 :]
246
+ return torch.cat((-x2, x1), dim=-1)
247
+
248
+
249
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
250
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
251
+ """Applies Rotary Position Embedding to the query and key tensors.
252
+
253
+ Args:
254
+ q (`torch.Tensor`): The query tensor.
255
+ k (`torch.Tensor`): The key tensor.
256
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
257
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
258
+ position_ids (`torch.Tensor`):
259
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
260
+ used to pass offsetted position ids when working with a KV-cache.
261
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
262
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
263
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
264
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
265
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
266
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
267
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
268
+ Returns:
269
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
270
+ """
271
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
272
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
273
+ q_embed = (q * cos) + (rotate_half(q) * sin)
274
+ k_embed = (k * cos) + (rotate_half(k) * sin)
275
+ return q_embed, k_embed
276
+
277
+
278
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
279
+ class PhiMLP(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.config = config
283
+ self.activation_fn = ACT2FN[config.hidden_act]
284
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
285
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
286
+
287
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
288
+ hidden_states = self.fc1(hidden_states)
289
+ hidden_states = self.activation_fn(hidden_states)
290
+ hidden_states = self.fc2(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
295
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
296
+ """
297
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
298
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
299
+ """
300
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
301
+ if n_rep == 1:
302
+ return hidden_states
303
+ hidden_states = hidden_states[:, :, None, :, :].expand(
304
+ batch, num_key_value_heads, n_rep, slen, head_dim
305
+ )
306
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
307
+
308
+
309
+ class PhiAttention(nn.Module):
310
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
311
+
312
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
313
+ super().__init__()
314
+ self.config = config
315
+ self.layer_idx = layer_idx
316
+ if layer_idx is None:
317
+ logger.warning_once(
318
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
319
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
320
+ "when creating this class."
321
+ )
322
+
323
+ self.attention_dropout = config.attention_dropout
324
+ self.hidden_size = config.hidden_size
325
+ self.num_heads = config.num_attention_heads
326
+ self.head_dim = self.hidden_size // self.num_heads
327
+ self.num_key_value_heads = config.num_key_value_heads
328
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
329
+ self.max_position_embeddings = config.max_position_embeddings
330
+ self.rope_theta = config.rope_theta
331
+ self.partial_rotary_factor = config.partial_rotary_factor
332
+ self.is_causal = True
333
+
334
+ if (self.head_dim * self.num_heads) != self.hidden_size:
335
+ raise ValueError(
336
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
337
+ f" and `num_heads`: {self.num_heads})."
338
+ )
339
+
340
+ self.Wqkv = nn.Linear(
341
+ self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
342
+ )
343
+ self.out_proj = nn.Linear(
344
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
345
+ )
346
+
347
+ self._init_rope()
348
+
349
+ def _init_rope(self):
350
+ if self.config.rope_scaling is None:
351
+ self.rotary_emb = PhiRotaryEmbedding(
352
+ int(self.partial_rotary_factor * self.head_dim),
353
+ max_position_embeddings=self.max_position_embeddings,
354
+ base=self.rope_theta,
355
+ )
356
+ else:
357
+ scaling_type = self.config.rope_scaling["type"]
358
+ scaling_factor = self.config.rope_scaling["factor"]
359
+ if scaling_type == "linear":
360
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
361
+ int(self.partial_rotary_factor * self.head_dim),
362
+ max_position_embeddings=self.max_position_embeddings,
363
+ scaling_factor=scaling_factor,
364
+ base=self.rope_theta,
365
+ )
366
+ elif scaling_type == "dynamic":
367
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
368
+ int(self.partial_rotary_factor * self.head_dim),
369
+ max_position_embeddings=self.max_position_embeddings,
370
+ scaling_factor=scaling_factor,
371
+ base=self.rope_theta,
372
+ )
373
+ else:
374
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
375
+
376
+ def forward(
377
+ self,
378
+ hidden_states: torch.Tensor,
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ position_ids: Optional[torch.LongTensor] = None,
381
+ past_key_value: Optional[Cache] = None,
382
+ output_attentions: bool = False,
383
+ use_cache: bool = False,
384
+ cache_position: Optional[torch.LongTensor] = None,
385
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
+ bsz, q_len, _ = hidden_states.size()
387
+
388
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
389
+ 3, dim=-1
390
+ )
391
+
392
+ query_states = query_states.view(
393
+ bsz, q_len, self.num_heads, self.head_dim
394
+ ).transpose(1, 2)
395
+ key_states = key_states.view(
396
+ bsz, q_len, self.num_key_value_heads, self.head_dim
397
+ ).transpose(1, 2)
398
+ value_states = value_states.view(
399
+ bsz, q_len, self.num_key_value_heads, self.head_dim
400
+ ).transpose(1, 2)
401
+
402
+ kv_seq_len = key_states.shape[-2]
403
+ if past_key_value is not None:
404
+ if self.layer_idx is None:
405
+ raise ValueError(
406
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
407
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
408
+ "with a layer index."
409
+ )
410
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
411
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
412
+
413
+ # Partial rotary embedding
414
+ query_rot, query_pass = (
415
+ query_states[..., : self.rotary_emb.dim],
416
+ query_states[..., self.rotary_emb.dim :],
417
+ )
418
+ key_rot, key_pass = (
419
+ key_states[..., : self.rotary_emb.dim],
420
+ key_states[..., self.rotary_emb.dim :],
421
+ )
422
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
423
+ query_rot, key_rot = apply_rotary_pos_emb(
424
+ query_rot, key_rot, cos, sin, position_ids
425
+ )
426
+
427
+ # [batch_size, seq_length, num_heads, head_dim]
428
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
429
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
430
+
431
+ if past_key_value is not None:
432
+ cache_kwargs = {
433
+ "sin": sin,
434
+ "cos": cos,
435
+ "partial_rotation_size": self.rotary_emb.dim,
436
+ "cache_position": cache_position,
437
+ }
438
+ key_states, value_states = past_key_value.update(
439
+ key_states, value_states, self.layer_idx, cache_kwargs
440
+ )
441
+
442
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
443
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
444
+
445
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
446
+ attn_weights = torch.matmul(
447
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
448
+ ) / math.sqrt(self.head_dim)
449
+
450
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
451
+ raise ValueError(
452
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
453
+ f" {attn_weights.size()}"
454
+ )
455
+
456
+ if attention_mask is not None:
457
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
458
+ attn_weights += causal_mask
459
+
460
+ # upcast attention to fp32
461
+ attn_weights = nn.functional.softmax(
462
+ attn_weights, dim=-1, dtype=torch.float32
463
+ ).to(value_states.dtype)
464
+ attn_weights = nn.functional.dropout(
465
+ attn_weights, p=self.attention_dropout, training=self.training
466
+ )
467
+
468
+ attn_output = torch.matmul(attn_weights, value_states)
469
+
470
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
+ raise ValueError(
472
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
+ f" {attn_output.size()}"
474
+ )
475
+
476
+ attn_output = attn_output.transpose(1, 2).contiguous()
477
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
478
+
479
+ attn_output = self.out_proj(attn_output)
480
+
481
+ if not output_attentions:
482
+ attn_weights = None
483
+
484
+ return attn_output, attn_weights, past_key_value
485
+
486
+
487
+ class PhiFlashAttention2(PhiAttention):
488
+ """
489
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
490
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
491
+ flash attention and deal with padding tokens in case the input contains any of them.
492
+ """
493
+
494
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
495
+ def __init__(self, *args, **kwargs):
496
+ super().__init__(*args, **kwargs)
497
+
498
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
499
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
500
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
501
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
502
+
503
+ def forward(
504
+ self,
505
+ hidden_states: torch.Tensor,
506
+ attention_mask: Optional[torch.LongTensor] = None,
507
+ position_ids: Optional[torch.LongTensor] = None,
508
+ past_key_value: Optional[Cache] = None,
509
+ output_attentions: bool = False,
510
+ use_cache: bool = False,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
+ **kwargs,
513
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
514
+ # PhiFlashAttention2 attention does not support output_attentions
515
+
516
+ output_attentions = False
517
+
518
+ bsz, q_len, _ = hidden_states.size()
519
+
520
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
521
+ 3, dim=-1
522
+ )
523
+
524
+ # Flash attention requires the input to have the shape
525
+ # batch_size x seq_length x head_dim x hidden_dim
526
+ # therefore we just need to keep the original shape
527
+ query_states = query_states.view(
528
+ bsz, q_len, self.num_heads, self.head_dim
529
+ ).transpose(1, 2)
530
+ key_states = key_states.view(
531
+ bsz, q_len, self.num_key_value_heads, self.head_dim
532
+ ).transpose(1, 2)
533
+ value_states = value_states.view(
534
+ bsz, q_len, self.num_key_value_heads, self.head_dim
535
+ ).transpose(1, 2)
536
+
537
+ kv_seq_len = key_states.shape[-2]
538
+ if past_key_value is not None:
539
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
540
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
541
+
542
+ # Partial rotary embedding
543
+ query_rot, query_pass = (
544
+ query_states[..., : self.rotary_emb.dim],
545
+ query_states[..., self.rotary_emb.dim :],
546
+ )
547
+ key_rot, key_pass = (
548
+ key_states[..., : self.rotary_emb.dim],
549
+ key_states[..., self.rotary_emb.dim :],
550
+ )
551
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
552
+ query_rot, key_rot = apply_rotary_pos_emb(
553
+ query_rot, key_rot, cos, sin, position_ids
554
+ )
555
+
556
+ # [batch_size, seq_length, num_heads, head_dim]
557
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
558
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
559
+
560
+ if past_key_value is not None:
561
+ cache_kwargs = {
562
+ "sin": sin,
563
+ "cos": cos,
564
+ "partial_rotation_size": self.rotary_emb.dim,
565
+ "cache_position": cache_position,
566
+ }
567
+ key_states, value_states = past_key_value.update(
568
+ key_states, value_states, self.layer_idx, cache_kwargs
569
+ )
570
+
571
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
572
+ # to be able to avoid many of these transpose/reshape/view.
573
+ query_states = query_states.transpose(1, 2)
574
+ key_states = key_states.transpose(1, 2)
575
+ value_states = value_states.transpose(1, 2)
576
+
577
+ attn_dropout = self.attention_dropout if self.training else 0.0
578
+
579
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
580
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
581
+ # cast them back in the correct dtype just to be sure everything works as expected.
582
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
583
+ # in fp32.
584
+
585
+ if query_states.dtype == torch.float32:
586
+ if torch.is_autocast_enabled():
587
+ target_dtype = torch.get_autocast_gpu_dtype()
588
+ # Handle the case where the model is quantized
589
+ elif hasattr(self.config, "_pre_quantization_dtype"):
590
+ target_dtype = self.config._pre_quantization_dtype
591
+ else:
592
+ target_dtype = self.q_proj.weight.dtype
593
+
594
+ logger.warning_once(
595
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
596
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
597
+ f" {target_dtype}."
598
+ )
599
+
600
+ query_states = query_states.to(target_dtype)
601
+ key_states = key_states.to(target_dtype)
602
+ value_states = value_states.to(target_dtype)
603
+
604
+ attn_output = _flash_attention_forward(
605
+ query_states,
606
+ key_states,
607
+ value_states,
608
+ attention_mask,
609
+ q_len,
610
+ position_ids=position_ids,
611
+ dropout=attn_dropout,
612
+ softmax_scale=None,
613
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
614
+ is_causal=self.is_causal,
615
+ )
616
+
617
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
618
+ attn_output = self.out_proj(attn_output)
619
+
620
+ if not output_attentions:
621
+ attn_weights = None
622
+
623
+ return attn_output, attn_weights, past_key_value
624
+
625
+
626
+ class PhiSdpaAttention(PhiAttention):
627
+ def __init__(self, *args, **kwargs):
628
+ super().__init__(*args, **kwargs)
629
+ self.require_contiguous_qkv = version.parse(
630
+ get_torch_version()
631
+ ) < version.parse("2.2.0")
632
+
633
+ """
634
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
635
+ `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
636
+ SDPA API.
637
+ """
638
+
639
+ # Adapted from PhiAttention.forward
640
+ def forward(
641
+ self,
642
+ hidden_states: torch.Tensor,
643
+ attention_mask: Optional[torch.Tensor] = None,
644
+ position_ids: Optional[torch.LongTensor] = None,
645
+ past_key_value: Optional[Cache] = None,
646
+ output_attentions: bool = False,
647
+ use_cache: bool = False,
648
+ cache_position: Optional[torch.LongTensor] = None,
649
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
650
+ if output_attentions:
651
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
652
+ logger.warning_once(
653
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
654
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
655
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
656
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
657
+ )
658
+ return super().forward(
659
+ hidden_states=hidden_states,
660
+ attention_mask=attention_mask,
661
+ position_ids=position_ids,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ )
666
+
667
+ bsz, q_len, _ = hidden_states.size()
668
+
669
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
670
+ 3, dim=-1
671
+ )
672
+
673
+ query_states = query_states.view(
674
+ bsz, q_len, self.num_heads, self.head_dim
675
+ ).transpose(1, 2)
676
+ key_states = key_states.view(
677
+ bsz, q_len, self.num_key_value_heads, self.head_dim
678
+ ).transpose(1, 2)
679
+ value_states = value_states.view(
680
+ bsz, q_len, self.num_key_value_heads, self.head_dim
681
+ ).transpose(1, 2)
682
+
683
+ kv_seq_len = key_states.shape[-2]
684
+ if past_key_value is not None:
685
+ if self.layer_idx is None:
686
+ raise ValueError(
687
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
688
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
689
+ "with a layer index."
690
+ )
691
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
+
694
+ # Partial rotary embedding
695
+ query_rot, query_pass = (
696
+ query_states[..., : self.rotary_emb.dim],
697
+ query_states[..., self.rotary_emb.dim :],
698
+ )
699
+ key_rot, key_pass = (
700
+ key_states[..., : self.rotary_emb.dim],
701
+ key_states[..., self.rotary_emb.dim :],
702
+ )
703
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
+ query_rot, key_rot = apply_rotary_pos_emb(
705
+ query_rot, key_rot, cos, sin, position_ids
706
+ )
707
+
708
+ # [batch_size, seq_length, num_heads, head_dim]
709
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
710
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
711
+
712
+ if past_key_value is not None:
713
+ cache_kwargs = {
714
+ "sin": sin,
715
+ "cos": cos,
716
+ "partial_rotation_size": self.rotary_emb.dim,
717
+ "cache_position": cache_position,
718
+ }
719
+ key_states, value_states = past_key_value.update(
720
+ key_states, value_states, self.layer_idx, cache_kwargs
721
+ )
722
+
723
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
724
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
725
+
726
+ causal_mask = attention_mask
727
+ if attention_mask is not None:
728
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
+
730
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
733
+ if (
734
+ self.require_contiguous_qkv
735
+ and query_states.device.type == "cuda"
736
+ and attention_mask is not None
737
+ ):
738
+ query_states = query_states.contiguous()
739
+ key_states = key_states.contiguous()
740
+ value_states = value_states.contiguous()
741
+
742
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
+ is_causal = True if causal_mask is None and q_len > 1 else False
745
+
746
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
747
+ query_states,
748
+ key_states,
749
+ value_states,
750
+ attn_mask=causal_mask,
751
+ dropout_p=self.attention_dropout if self.training else 0.0,
752
+ is_causal=is_causal,
753
+ )
754
+
755
+ attn_output = attn_output.transpose(1, 2).contiguous()
756
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
+
758
+ attn_output = self.out_proj(attn_output)
759
+
760
+ return attn_output, None, past_key_value
761
+
762
+
763
+ PHI_ATTENTION_CLASSES = {
764
+ "eager": PhiAttention,
765
+ "flash_attention_2": PhiFlashAttention2,
766
+ "sdpa": PhiSdpaAttention,
767
+ }
768
+
769
+
770
+ class PhiDecoderLayer(nn.Module):
771
+ def __init__(self, config: PhiConfig, layer_idx: int):
772
+ super().__init__()
773
+ self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
774
+ config, layer_idx=layer_idx
775
+ )
776
+ self.mlp = PhiMLP(config)
777
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
778
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
779
+
780
+ def forward(
781
+ self,
782
+ hidden_states: torch.Tensor,
783
+ attention_mask: Optional[torch.Tensor] = None,
784
+ position_ids: Optional[torch.LongTensor] = None,
785
+ output_attentions: Optional[bool] = False,
786
+ use_cache: Optional[bool] = False,
787
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
+ cache_position: Optional[torch.LongTensor] = None,
789
+ **kwargs,
790
+ ) -> Tuple[
791
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
+ ]:
793
+ """
794
+ Args:
795
+ hidden_states (`torch.FloatTensor`):
796
+ input to the layer of shape `(batch, seq_len, embed_dim)`
797
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
798
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
799
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
800
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
801
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
802
+ output_attentions (`bool`, *optional*):
803
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
804
+ returned tensors for more detail.
805
+ use_cache (`bool`, *optional*):
806
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
+ (see `past_key_values`).
808
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
+ Indices depicting the position of the input sequence tokens in the sequence
811
+ kwargs (`dict`, *optional*):
812
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
813
+ into the model
814
+ """
815
+
816
+ residual = hidden_states
817
+
818
+ hidden_states = self.ln(hidden_states)
819
+
820
+ # Self Attention
821
+ attn_outputs, self_attn_weights, present_key_value = self.mixer(
822
+ hidden_states=hidden_states,
823
+ attention_mask=attention_mask,
824
+ position_ids=position_ids,
825
+ past_key_value=past_key_value,
826
+ output_attentions=output_attentions,
827
+ use_cache=use_cache,
828
+ cache_position=cache_position,
829
+ )
830
+ attn_outputs = self.resid_dropout(attn_outputs)
831
+
832
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
+ outputs = (hidden_states,)
835
+
836
+ if output_attentions:
837
+ outputs += (self_attn_weights,)
838
+
839
+ if use_cache:
840
+ outputs += (present_key_value,)
841
+
842
+ return outputs
843
+
844
+
845
+ PHI_START_DOCSTRING = r"""
846
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
848
+ etc.)
849
+
850
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
852
+ and behavior.
853
+
854
+ Parameters:
855
+ config ([`PhiConfig`]):
856
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
857
+ load the weights associated with the model, only the configuration. Check out the
858
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
+ """
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
+ PHI_START_DOCSTRING,
865
+ )
866
+ class PhiPreTrainedModel(PreTrainedModel):
867
+ config_class = PhiConfig
868
+ base_model_prefix = "model"
869
+ supports_gradient_checkpointing = True
870
+ _no_split_modules = ["PhiDecoderLayer"]
871
+ _skip_keys_device_placement = "past_key_values"
872
+ _supports_flash_attn_2 = True
873
+ _supports_sdpa = True
874
+ _supports_cache_class = True
875
+
876
+ def _init_weights(self, module):
877
+ std = self.config.initializer_range
878
+ if isinstance(module, nn.Linear):
879
+ module.weight.data.normal_(mean=0.0, std=std)
880
+ if module.bias is not None:
881
+ module.bias.data.zero_()
882
+ elif isinstance(module, nn.Embedding):
883
+ module.weight.data.normal_(mean=0.0, std=std)
884
+ if module.padding_idx is not None:
885
+ module.weight.data[module.padding_idx].zero_()
886
+
887
+
888
+ class Embedding(nn.Module):
889
+ def __init__(self, config: PhiConfig):
890
+ super().__init__()
891
+ self.wte = nn.Embedding(
892
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
893
+ )
894
+
895
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
+ return self.wte(input_ids)
897
+
898
+ PHI_INPUTS_DOCSTRING = r"""
899
+ Args:
900
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
+ it.
903
+
904
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
+ [`PreTrainedTokenizer.__call__`] for details.
906
+
907
+ [What are input IDs?](../glossary#input-ids)
908
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
+
911
+ - 1 for tokens that are **not masked**,
912
+ - 0 for tokens that are **masked**.
913
+
914
+ [What are attention masks?](../glossary#attention-mask)
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
+ `past_key_values`).
921
+
922
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
+ information on the default strategy.
925
+
926
+ - 1 indicates the head is **not masked**,
927
+ - 0 indicates the head is **masked**.
928
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
+ config.n_positions - 1]`.
931
+
932
+ [What are position IDs?](../glossary#position-ids)
933
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
936
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
+
938
+ Two formats are allowed:
939
+ - a [`~cache_utils.Cache`] instance;
940
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
+ cache format.
943
+
944
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
+ legacy cache format will be returned.
946
+
947
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
+ of shape `(batch_size, sequence_length)`.
950
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
+ model's internal embedding lookup matrix.
954
+ use_cache (`bool`, *optional*):
955
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
+ `past_key_values`).
957
+ output_attentions (`bool`, *optional*):
958
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
+ tensors for more detail.
960
+ output_hidden_states (`bool`, *optional*):
961
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
+ more detail.
963
+ return_dict (`bool`, *optional*):
964
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
+ the complete sequence length.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
+ PHI_START_DOCSTRING,
975
+ )
976
+ class PhiModel(PhiPreTrainedModel):
977
+ """
978
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
+
980
+ Args:
981
+ config: PhiConfig
982
+ """
983
+
984
+ def __init__(self, config: PhiConfig):
985
+ super().__init__(config)
986
+ self.padding_idx = config.pad_token_id
987
+ self.vocab_size = config.vocab_size
988
+
989
+ self.embd = Embedding(config)
990
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
+ self.h = nn.ModuleList(
992
+ [
993
+ PhiDecoderLayer(config, layer_idx)
994
+ for layer_idx in range(config.num_hidden_layers)
995
+ ]
996
+ )
997
+
998
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
+ self._use_sdpa = config._attn_implementation == "sdpa"
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ # Initialize weights and apply final processing
1003
+ self.post_init()
1004
+
1005
+ def get_input_embeddings(self):
1006
+ return self.embd.wte
1007
+
1008
+ def set_input_embeddings(self, value):
1009
+ self.embd.wte = value
1010
+
1011
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
+ def forward(
1013
+ self,
1014
+ input_ids: torch.LongTensor = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
1016
+ position_ids: Optional[torch.LongTensor] = None,
1017
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1019
+ use_cache: Optional[bool] = None,
1020
+ output_attentions: Optional[bool] = None,
1021
+ output_hidden_states: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
+ output_attentions = (
1026
+ output_attentions
1027
+ if output_attentions is not None
1028
+ else self.config.output_attentions
1029
+ )
1030
+ output_hidden_states = (
1031
+ output_hidden_states
1032
+ if output_hidden_states is not None
1033
+ else self.config.output_hidden_states
1034
+ )
1035
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
+
1037
+ return_dict = (
1038
+ return_dict if return_dict is not None else self.config.use_return_dict
1039
+ )
1040
+
1041
+ if (input_ids is None) ^ (inputs_embeds is not None):
1042
+ raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
+ )
1045
+
1046
+ if self.gradient_checkpointing and self.training:
1047
+ if use_cache:
1048
+ logger.warning_once(
1049
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
+ )
1051
+ use_cache = False
1052
+
1053
+ use_legacy_cache = False
1054
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
+ use_legacy_cache = True
1056
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
+ logger.warning_once(
1058
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
+ )
1061
+
1062
+ if inputs_embeds is None:
1063
+ inputs_embeds = self.embd(input_ids)
1064
+
1065
+ if cache_position is None:
1066
+ past_seen_tokens = (
1067
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1068
+ )
1069
+ cache_position = torch.arange(
1070
+ past_seen_tokens,
1071
+ past_seen_tokens + inputs_embeds.shape[1],
1072
+ device=inputs_embeds.device,
1073
+ )
1074
+ if position_ids is None:
1075
+ position_ids = cache_position.unsqueeze(0)
1076
+
1077
+ causal_mask = self._update_causal_mask(
1078
+ attention_mask,
1079
+ inputs_embeds,
1080
+ cache_position,
1081
+ past_key_values,
1082
+ output_attentions,
1083
+ )
1084
+
1085
+ hidden_states = inputs_embeds
1086
+
1087
+ # decoder layers
1088
+ all_hidden_states = () if output_hidden_states else None
1089
+ all_self_attns = () if output_attentions else None
1090
+ next_decoder_cache = None
1091
+
1092
+ for decoder_layer in self.h:
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+
1096
+ if self.gradient_checkpointing and self.training:
1097
+ layer_outputs = self._gradient_checkpointing_func(
1098
+ decoder_layer.__call__,
1099
+ hidden_states,
1100
+ causal_mask,
1101
+ position_ids,
1102
+ output_attentions,
1103
+ use_cache,
1104
+ past_key_values,
1105
+ cache_position,
1106
+ )
1107
+ else:
1108
+ layer_outputs = decoder_layer(
1109
+ hidden_states,
1110
+ attention_mask=causal_mask,
1111
+ position_ids=position_ids,
1112
+ past_key_value=past_key_values,
1113
+ output_attentions=output_attentions,
1114
+ use_cache=use_cache,
1115
+ cache_position=cache_position,
1116
+ )
1117
+
1118
+ hidden_states = layer_outputs[0]
1119
+
1120
+ if use_cache:
1121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
+
1123
+ if output_attentions:
1124
+ all_self_attns += (layer_outputs[1],)
1125
+
1126
+ # add hidden states from the last decoder layer
1127
+ if output_hidden_states:
1128
+ all_hidden_states += (hidden_states,)
1129
+
1130
+ next_cache = None
1131
+ if use_cache:
1132
+ next_cache = (
1133
+ next_decoder_cache.to_legacy_cache()
1134
+ if use_legacy_cache
1135
+ else next_decoder_cache
1136
+ )
1137
+ if not return_dict:
1138
+ return tuple(
1139
+ v
1140
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
+ if v is not None
1142
+ )
1143
+ return BaseModelOutputWithPast(
1144
+ last_hidden_state=hidden_states,
1145
+ past_key_values=next_cache,
1146
+ hidden_states=all_hidden_states,
1147
+ attentions=all_self_attns,
1148
+ )
1149
+
1150
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
+ def _update_causal_mask(
1152
+ self,
1153
+ attention_mask: torch.Tensor,
1154
+ input_tensor: torch.Tensor,
1155
+ cache_position: torch.Tensor,
1156
+ past_key_values: Cache,
1157
+ output_attentions: bool,
1158
+ ):
1159
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
+
1164
+ if self.config._attn_implementation == "flash_attention_2":
1165
+ if attention_mask is not None and 0.0 in attention_mask:
1166
+ return attention_mask
1167
+ return None
1168
+
1169
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
+ # to infer the attention mask.
1172
+ past_seen_tokens = (
1173
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1174
+ )
1175
+ using_static_cache = isinstance(past_key_values, StaticCache)
1176
+
1177
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
+ if (
1179
+ self.config._attn_implementation == "sdpa"
1180
+ and not using_static_cache
1181
+ and not output_attentions
1182
+ ):
1183
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
+ attention_mask,
1185
+ inputs_embeds=input_tensor,
1186
+ past_key_values_length=past_seen_tokens,
1187
+ is_training=self.training,
1188
+ ):
1189
+ return None
1190
+
1191
+ dtype, device = input_tensor.dtype, input_tensor.device
1192
+ min_dtype = torch.finfo(dtype).min
1193
+ sequence_length = input_tensor.shape[1]
1194
+ if using_static_cache:
1195
+ target_length = past_key_values.get_max_length()
1196
+ else:
1197
+ target_length = (
1198
+ attention_mask.shape[-1]
1199
+ if isinstance(attention_mask, torch.Tensor)
1200
+ else past_seen_tokens + sequence_length + 1
1201
+ )
1202
+
1203
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
+ attention_mask,
1206
+ sequence_length=sequence_length,
1207
+ target_length=target_length,
1208
+ dtype=dtype,
1209
+ device=device,
1210
+ min_dtype=min_dtype,
1211
+ cache_position=cache_position,
1212
+ batch_size=input_tensor.shape[0],
1213
+ )
1214
+
1215
+ if (
1216
+ self.config._attn_implementation == "sdpa"
1217
+ and attention_mask is not None
1218
+ and attention_mask.device.type == "cuda"
1219
+ and not output_attentions
1220
+ ):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype
1226
+ )
1227
+
1228
+ return causal_mask
1229
+
1230
+
1231
+ class CausalLMHead(nn.Module):
1232
+ """Causal Language Modeling head. Simplified version."""
1233
+
1234
+ def __init__(self, config):
1235
+ super().__init__()
1236
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
+ self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
+
1239
+ def forward(self, hidden_states):
1240
+ return self.linear(self.ln(hidden_states))
1241
+
1242
+
1243
+ class PhiForCausalLM(PhiPreTrainedModel):
1244
+
1245
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
+ def __init__(self, config):
1247
+ super().__init__(config)
1248
+ self.transformer = PhiModel(config)
1249
+ self.vocab_size = config.vocab_size
1250
+ self.lm_head = CausalLMHead(config)
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
+ def get_input_embeddings(self):
1257
+ return self.transformer.embd.wte
1258
+
1259
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
+ def set_input_embeddings(self, value):
1261
+ self.transformer.embd.wte = value
1262
+
1263
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
+ def get_output_embeddings(self):
1265
+ return self.lm_head.linear
1266
+
1267
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
+ def set_output_embeddings(self, new_embeddings):
1269
+ self.lm_head.linear = new_embeddings
1270
+
1271
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
+ def set_decoder(self, decoder):
1273
+ self.model = decoder
1274
+
1275
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
+ def get_decoder(self):
1277
+ return self.model
1278
+
1279
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
+ @replace_return_docstrings(
1281
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
+ )
1283
+ def forward(
1284
+ self,
1285
+ input_ids: torch.LongTensor = None,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ position_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1290
+ labels: Optional[torch.LongTensor] = None,
1291
+ use_cache: Optional[bool] = None,
1292
+ output_attentions: Optional[bool] = None,
1293
+ output_hidden_states: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ cache_position: Optional[torch.LongTensor] = None,
1296
+ num_logits_to_keep: int = 0,
1297
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
+ r"""
1299
+ Args:
1300
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
+
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
+ Returns:
1311
+
1312
+ Example:
1313
+
1314
+ ```python
1315
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
+
1317
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
+
1320
+ >>> prompt = "This is an example script ."
1321
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
+
1323
+ >>> # Generate
1324
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
+ ```"""
1328
+
1329
+ output_attentions = (
1330
+ output_attentions
1331
+ if output_attentions is not None
1332
+ else self.config.output_attentions
1333
+ )
1334
+ output_hidden_states = (
1335
+ output_hidden_states
1336
+ if output_hidden_states is not None
1337
+ else self.config.output_hidden_states
1338
+ )
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1344
+ outputs = self.transformer(
1345
+ input_ids=input_ids,
1346
+ attention_mask=attention_mask,
1347
+ position_ids=position_ids,
1348
+ past_key_values=past_key_values,
1349
+ inputs_embeds=inputs_embeds,
1350
+ use_cache=use_cache,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ cache_position=cache_position,
1355
+ )
1356
+
1357
+ hidden_states = outputs[0]
1358
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
+
1360
+ loss = None
1361
+ if labels is not None:
1362
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
+ logits = logits.float()
1364
+ # Shift so that tokens < n predict n
1365
+ shift_logits = logits[..., :-1, :].contiguous()
1366
+ shift_labels = labels[..., 1:].contiguous()
1367
+ # Flatten the tokens
1368
+ loss_fct = CrossEntropyLoss()
1369
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
+ shift_labels = shift_labels.view(-1)
1371
+ # Enable model parallelism
1372
+ shift_labels = shift_labels.to(shift_logits.device)
1373
+ loss = loss_fct(shift_logits, shift_labels)
1374
+
1375
+ if not return_dict:
1376
+ output = (logits,) + outputs[1:]
1377
+ return (loss,) + output if loss is not None else output
1378
+
1379
+ return CausalLMOutputWithPast(
1380
+ loss=loss,
1381
+ logits=logits,
1382
+ past_key_values=outputs.past_key_values,
1383
+ hidden_states=outputs.hidden_states,
1384
+ attentions=outputs.attentions,
1385
+ )
1386
+
1387
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
+ def prepare_inputs_for_generation(
1389
+ self,
1390
+ input_ids,
1391
+ past_key_values=None,
1392
+ attention_mask=None,
1393
+ inputs_embeds=None,
1394
+ cache_position=None,
1395
+ position_ids=None,
1396
+ use_cache=True,
1397
+ num_logits_to_keep=0,
1398
+ **kwargs,
1399
+ ):
1400
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
+ if past_key_values is not None:
1404
+ if inputs_embeds is not None: # Exception 1
1405
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1406
+ elif (
1407
+ input_ids.shape[1] != cache_position.shape[0]
1408
+ ): # Default case (the "else", a no op, is Exception 2)
1409
+ input_ids = input_ids[:, cache_position]
1410
+
1411
+ if attention_mask is not None and position_ids is None:
1412
+ # create position_ids on the fly for batch generation
1413
+ position_ids = attention_mask.long().cumsum(-1) - 1
1414
+ position_ids.masked_fill_(attention_mask == 0, 1)
1415
+ if past_key_values:
1416
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1417
+
1418
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1419
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
+
1421
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
+ if inputs_embeds is not None and cache_position[0] == 0:
1423
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
+ else:
1425
+ # The clone here is for the same reason as for `position_ids`.
1426
+ model_inputs = {
1427
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
+ "inputs_embeds": None,
1429
+ }
1430
+
1431
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
+ if model_inputs["inputs_embeds"] is not None:
1433
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
+ device = model_inputs["inputs_embeds"].device
1435
+ else:
1436
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1437
+ device = model_inputs["input_ids"].device
1438
+
1439
+ dtype = self.lm_head.weight.dtype
1440
+ min_dtype = torch.finfo(dtype).min
1441
+
1442
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
+ attention_mask,
1444
+ sequence_length=sequence_length,
1445
+ target_length=past_key_values.get_max_length(),
1446
+ dtype=dtype,
1447
+ device=device,
1448
+ min_dtype=min_dtype,
1449
+ cache_position=cache_position,
1450
+ batch_size=batch_size,
1451
+ )
1452
+
1453
+ model_inputs.update(
1454
+ {
1455
+ "position_ids": position_ids,
1456
+ "cache_position": cache_position,
1457
+ "past_key_values": past_key_values,
1458
+ "use_cache": use_cache,
1459
+ "attention_mask": attention_mask,
1460
+ "num_logits_to_keep": num_logits_to_keep,
1461
+ }
1462
+ )
1463
+ return model_inputs
moondream.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from typing import List, Union, Literal, Optional
4
+ from transformers import PreTrainedModel
5
+ from PIL import Image
6
+
7
+ from .configuration_moondream import PhiConfig
8
+ from .configuration_moondream import MoondreamConfig
9
+ from .vision_encoder import VisionEncoder
10
+ from .region_model import RegionModel
11
+ from .modeling_phi import PhiForCausalLM
12
+
13
+ class Moondream(PreTrainedModel):
14
+ config_class = MoondreamConfig
15
+ _supports_flash_attn_2 = True
16
+
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+ self.vision_encoder = VisionEncoder(
20
+ use_flash_attn=config._attn_implementation == "flash_attention_2"
21
+ )
22
+ self.region_model = RegionModel()
23
+
24
+ if type(config.text_config) == dict:
25
+ phi_config = PhiConfig(
26
+ **config.text_config, attn_implementation=config._attn_implementation
27
+ )
28
+ else:
29
+ phi_config = config.text_config
30
+ self.text_model = PhiForCausalLM(phi_config)
31
+
32
+ @property
33
+ def device(self):
34
+ return self.text_model.device
35
+
36
+ def encode_image(self, image):
37
+ with torch.no_grad():
38
+ return self.vision_encoder(image)
39
+
40
+ def input_embeds(self, prompt, image_embeds, tokenizer):
41
+ def _tokenize(txt):
42
+ return tokenizer(
43
+ txt, return_tensors="pt", add_special_tokens=False
44
+ ).input_ids.to(self.device)
45
+
46
+ text_emb = self.text_model.get_input_embeddings()
47
+
48
+ # Add BOS token
49
+ embeds = []
50
+ embeds.append(
51
+ text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
52
+ )
53
+
54
+ if "<image>" not in prompt:
55
+ embeds.append(text_emb(_tokenize(prompt)))
56
+ else:
57
+ assert prompt.count("<image>") == 1
58
+ before, after = prompt.split("<image>")
59
+ if len(before) > 0:
60
+ embeds.append(text_emb(_tokenize(before)))
61
+ embeds.append(image_embeds.to(self.device))
62
+ if len(after) > 0:
63
+ embeds.append(text_emb(_tokenize(after)))
64
+
65
+ return torch.cat(embeds, dim=1)
66
+
67
+ def get_input_embeddings(self):
68
+ return self.text_model.get_input_embeddings()
69
+
70
+ def generate(
71
+ self,
72
+ image_embeds,
73
+ prompt,
74
+ tokenizer,
75
+ max_new_tokens=128,
76
+ **kwargs,
77
+ ):
78
+ generate_config = {
79
+ "eos_token_id": tokenizer.eos_token_id,
80
+ "bos_token_id": tokenizer.bos_token_id,
81
+ "pad_token_id": tokenizer.bos_token_id,
82
+ "max_new_tokens": max_new_tokens,
83
+ **kwargs,
84
+ }
85
+
86
+ with torch.no_grad():
87
+ inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
88
+ attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
89
+ output_ids = self.text_model.generate(
90
+ inputs_embeds=inputs_embeds,
91
+ attention_mask=attention_mask,
92
+ **generate_config,
93
+ )
94
+
95
+ return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
96
+
97
+ # Note: Not ready for use yet, intended for September release.
98
+ def caption(
99
+ self,
100
+ images: List[Image.Image],
101
+ tokenizer,
102
+ length: Optional[Literal["short"]] = None,
103
+ **kwargs,
104
+ ):
105
+ image_embeds = self.encode_image(images)
106
+
107
+ templated_prompts = [
108
+ f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:" for _ in images
109
+ ]
110
+ inputs_embeds = torch.stack([
111
+ self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
112
+ for prompt, image_embed in zip(templated_prompts, image_embeds)
113
+ ])
114
+ attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
115
+
116
+ generate_config = {
117
+ "eos_token_id": tokenizer.eos_token_id,
118
+ "bos_token_id": tokenizer.bos_token_id,
119
+ "pad_token_id": tokenizer.bos_token_id,
120
+ "repetition_penalty": 1.2,
121
+ "max_new_tokens": 512,
122
+ **kwargs,
123
+ }
124
+
125
+ with torch.no_grad():
126
+ output_ids = self.text_model.generate(
127
+ inputs_embeds=inputs_embeds,
128
+ attention_mask=attention_mask,
129
+ **generate_config,
130
+ )
131
+
132
+ return [
133
+ x.strip()
134
+ for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
135
+ ]
136
+
137
+ def answer_question(
138
+ self,
139
+ image_embeds,
140
+ question,
141
+ tokenizer,
142
+ chat_history="",
143
+ result_queue=None,
144
+ max_new_tokens=256,
145
+ **kwargs,
146
+ ):
147
+ prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
148
+ answer = self.generate(
149
+ image_embeds,
150
+ prompt,
151
+ tokenizer=tokenizer,
152
+ max_new_tokens=max_new_tokens,
153
+ **kwargs,
154
+ )[0]
155
+ cleaned_answer = answer.strip()
156
+
157
+ # Use the result_queue to pass the result if it is provided
158
+ if result_queue:
159
+ result_queue.put(cleaned_answer)
160
+ else:
161
+ return cleaned_answer
162
+
163
+ def batch_answer(
164
+ self,
165
+ images,
166
+ prompts,
167
+ tokenizer,
168
+ **kwargs,
169
+ ):
170
+ image_embeds = self.encode_image(images)
171
+
172
+ templated_prompts = [
173
+ f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
174
+ ]
175
+ prompt_embs = [
176
+ self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
177
+ for prompt, image_embed in zip(templated_prompts, image_embeds)
178
+ ]
179
+
180
+ bos_emb = prompt_embs[0][0]
181
+ max_len = max([p.shape[0] for p in prompt_embs])
182
+
183
+ inputs_embeds = torch.cat(
184
+ [
185
+ torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
186
+ for p in prompt_embs
187
+ ],
188
+ dim=0,
189
+ )
190
+ attention_mask = torch.cat(
191
+ [
192
+ torch.cat(
193
+ [
194
+ torch.zeros(
195
+ 1,
196
+ max_len - p.shape[0],
197
+ device=self.device,
198
+ dtype=torch.long,
199
+ ),
200
+ torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
201
+ ],
202
+ dim=1,
203
+ )
204
+ for p in prompt_embs
205
+ ],
206
+ dim=0,
207
+ )
208
+
209
+ generate_config = {
210
+ "eos_token_id": tokenizer.eos_token_id,
211
+ "bos_token_id": tokenizer.bos_token_id,
212
+ "pad_token_id": tokenizer.bos_token_id,
213
+ "max_new_tokens": 512,
214
+ **kwargs,
215
+ }
216
+
217
+ with torch.no_grad():
218
+ output_ids = self.text_model.generate(
219
+ inputs_embeds=inputs_embeds,
220
+ attention_mask=attention_mask,
221
+ **generate_config,
222
+ )
223
+
224
+ return [
225
+ x.strip()
226
+ for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
227
+ ]
228
+
229
+ def detect(self, image: Image.Image, query: str, tokenizer):
230
+ pass
region_model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .fourier_features import FourierFeatures
4
+
5
+ class RegionModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ self.position_features = FourierFeatures(2, 256)
10
+ self.position_encoder = nn.Linear(256, 2048)
11
+ self.size_features = FourierFeatures(2, 256)
12
+ self.size_encoder = nn.Linear(256, 2048)
13
+
14
+ self.position_decoder = nn.Linear(2048, 2)
15
+ self.size_decoder = nn.Linear(2048, 2)
16
+ self.confidence_decoder = nn.Linear(2048, 1)
17
+
18
+ def encode_position(self, position):
19
+ return self.position_encoder(self.position_features(position))
20
+
21
+ def encode_size(self, size):
22
+ return self.size_encoder(self.size_features(size))
23
+
24
+ def decode_position(self, x):
25
+ return self.position_decoder(x)
26
+
27
+ def decode_size(self, x):
28
+ return self.size_decoder(x)
29
+
30
+ def decode_confidence(self, x):
31
+ return self.confidence_decoder(x)
32
+
33
+ def encode(self, position, size):
34
+ return torch.stack(
35
+ [self.encode_position(position), self.encode_size(size)], dim=0
36
+ )
37
+
38
+ def decode(self, position_logits, size_logits):
39
+ return (
40
+ self.decode_position(position_logits),
41
+ self.decode_size(size_logits),
42
+ self.decode_confidence(size_logits),
43
+ )
vision_encoder.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import PIL.Image
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from einops import rearrange
8
+ import PIL
9
+ from torchvision.transforms.v2 import (
10
+ Compose,
11
+ Resize,
12
+ InterpolationMode,
13
+ ToImage,
14
+ ToDtype,
15
+ Normalize,
16
+ )
17
+ from transformers.utils import is_flash_attn_2_available
18
+
19
+ try:
20
+ if is_flash_attn_2_available():
21
+ from flash_attn.modules.mha import FlashSelfAttention
22
+ else:
23
+ FlashSelfAttention = None
24
+ except ImportError:
25
+ FlashSelfAttention = None
26
+
27
+
28
+ class Attention(nn.Module):
29
+
30
+ def __init__(self, dim, num_heads=16, use_flash_attn=False):
31
+ super().__init__()
32
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
33
+
34
+ self.num_heads = num_heads
35
+ self.head_dim = dim // num_heads
36
+
37
+ self.qkv = nn.Linear(dim, dim * 3)
38
+ self.proj = nn.Linear(dim, dim)
39
+
40
+ if use_flash_attn and FlashSelfAttention is not None:
41
+ self.flash_attn = FlashSelfAttention()
42
+ else:
43
+ self.flash_attn = None
44
+
45
+ torch.nn.init.kaiming_normal_(
46
+ self.qkv.weight, mode="fan_in", nonlinearity="relu"
47
+ )
48
+ torch.nn.init.kaiming_normal_(
49
+ self.proj.weight, mode="fan_in", nonlinearity="relu"
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ if self.flash_attn is not None:
54
+ qkv = self.qkv(x)
55
+ qkv = rearrange(
56
+ qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
57
+ )
58
+ attn_output = self.flash_attn(qkv)
59
+ output = rearrange(attn_output, "... h d -> ... (h d)")
60
+ output = self.proj(output)
61
+ return output
62
+ else:
63
+ B, N, C = x.shape
64
+ qkv = (
65
+ self.qkv(x)
66
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
67
+ .permute(2, 0, 3, 1, 4)
68
+ )
69
+ q, k, v = qkv.unbind(0)
70
+
71
+ x = F.scaled_dot_product_attention(q, k, v)
72
+
73
+ x = x.transpose(1, 2).reshape(B, N, C)
74
+ x = self.proj(x)
75
+ return x
76
+
77
+
78
+ class VitBlock(nn.Module):
79
+
80
+ def __init__(self, embed_dim, use_flash_attn=False):
81
+ super().__init__()
82
+ self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
83
+ self.mlp = MLP(embed_dim, 4304)
84
+ self.norm1 = nn.LayerNorm(embed_dim)
85
+ self.norm2 = nn.LayerNorm(embed_dim)
86
+
87
+ def forward(self, x):
88
+ x = x + self.attn(self.norm1(x))
89
+ x = x + self.mlp(self.norm2(x))
90
+ return x
91
+
92
+
93
+ class VisionTransformer(nn.Module):
94
+
95
+ def __init__(self, use_flash_attn=False):
96
+ super().__init__()
97
+
98
+ embed_len = 729
99
+ embed_dim = 1152
100
+
101
+ self.patch_embed = LinearPatchEmbedding()
102
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
103
+ self.blocks = nn.Sequential(
104
+ *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
105
+ )
106
+ self.norm = nn.LayerNorm(embed_dim)
107
+
108
+ def forward(self, x):
109
+ x = self.patch_embed(x)
110
+ x = x + self.pos_embed
111
+ for block in self.blocks:
112
+ x = block(x)
113
+ return self.norm(x)
114
+
115
+
116
+ class EncoderWrapper(nn.Module):
117
+
118
+ def __init__(self, use_flash_attn=False):
119
+ super().__init__()
120
+ self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
121
+
122
+ def forward(self, x):
123
+ return self.model["visual"](x)
124
+
125
+
126
+ class LinearPatchEmbedding(nn.Module):
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+ self.linear = nn.Linear(588, 1152)
131
+
132
+ def forward(self, x):
133
+ b, c, hp1, wp2 = x.shape
134
+ p1, p2 = 14, 14
135
+ h, w = hp1 // p1, wp2 // p2
136
+ x = x.reshape(b, c, h, p1, w, p2)
137
+ x = x.permute(0, 2, 4, 1, 3, 5)
138
+ x = x.reshape(b, h * w, c * p1 * p2)
139
+
140
+ return self.linear(x)
141
+
142
+
143
+ class MLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ in_features: int,
147
+ hidden_features: int = None,
148
+ out_features: int = None,
149
+ ) -> None:
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = nn.GELU(approximate="tanh")
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+
157
+ torch.nn.init.kaiming_normal_(
158
+ self.fc1.weight, mode="fan_in", nonlinearity="relu"
159
+ )
160
+ torch.nn.init.kaiming_normal_(
161
+ self.fc2.weight, mode="fan_in", nonlinearity="relu"
162
+ )
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = self.fc1(x)
166
+ x = self.act(x)
167
+ x = self.fc2(x)
168
+ return x
169
+
170
+
171
+ class VisionProjection(nn.Module):
172
+ def __init__(self):
173
+ super().__init__()
174
+
175
+ image_embedding_dim = 1152
176
+ model_dim = 2048
177
+ hidden_dim = model_dim * 4
178
+
179
+ self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)
180
+
181
+ @property
182
+ def device(self):
183
+ return self.mlp.fc1.weight.device
184
+
185
+ def forward(self, x):
186
+ return self.mlp(x)
187
+
188
+
189
+ def create_patches(image, patch_size=(378, 378)):
190
+ assert image.dim() == 3, "Image must be in CHW format"
191
+
192
+ _, height, width = image.shape # Channels, Height, Width
193
+ patch_height, patch_width = patch_size
194
+
195
+ if height == patch_height and width == patch_width:
196
+ return []
197
+
198
+ # Iterate over the image and create patches
199
+ patches = []
200
+ for i in range(0, height, patch_height):
201
+ row_patches = []
202
+ for j in range(0, width, patch_width):
203
+ patch = image[:, i : i + patch_height, j : j + patch_width]
204
+ row_patches.append(patch)
205
+ patches.append(torch.stack(row_patches))
206
+ return patches
207
+
208
+
209
+ class VisionEncoder(nn.Module):
210
+
211
+ def __init__(self, use_flash_attn=False):
212
+ super().__init__()
213
+
214
+ self.encoder = EncoderWrapper(use_flash_attn)
215
+ self.projection = VisionProjection()
216
+ self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]
217
+
218
+ @property
219
+ def device(self):
220
+ return self.projection.mlp.fc1.weight.device
221
+
222
+ @property
223
+ def dtype(self):
224
+ return self.projection.mlp.fc1.weight.dtype
225
+
226
+ def preprocess(self, image: PIL.Image.Image):
227
+ width, height = image.size
228
+ max_dim = max(width, height)
229
+ if max_dim < 512:
230
+ im_size = (378, 378)
231
+ else:
232
+ aspect_ratio = width / height
233
+ im_size = min(
234
+ self.supported_sizes,
235
+ key=lambda size: (
236
+ abs((size[1] / size[0]) - aspect_ratio),
237
+ abs(size[0] - width) + abs(size[1] - height),
238
+ ),
239
+ )
240
+
241
+ return Compose(
242
+ [
243
+ Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
244
+ ToImage(),
245
+ ToDtype(torch.float32, scale=True),
246
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
247
+ ]
248
+ )(image)
249
+
250
+ def forward(
251
+ self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
252
+ ) -> torch.Tensor:
253
+ im_list = None
254
+ if isinstance(images, torch.Tensor):
255
+ # Input must have dimensions (B, C, H, W)
256
+ assert (
257
+ len(images.shape) == 4
258
+ ), "Tensor input must have dimensions (B, C, H, W)"
259
+ im_list = list(images)
260
+ elif isinstance(images, PIL.Image.Image):
261
+ im_list = [images]
262
+ elif isinstance(images, list):
263
+ im_list = images
264
+ else:
265
+ raise ValueError(
266
+ "Input must be a PIL image, list of PIL images, or a tensor"
267
+ )
268
+
269
+ # Preprocess unless the images are already tensors (indicating that
270
+ # they have already been preprocessed)
271
+ if not isinstance(im_list[0], torch.Tensor):
272
+ im_list = [self.preprocess(im.convert("RGB")) for im in im_list]
273
+
274
+ patches = [create_patches(im) for im in im_list]
275
+ flat_patches = [patch for image_patches in patches for patch in image_patches]
276
+
277
+ # Images may be variable size, and need to be resized to a common size after
278
+ # creating patches.
279
+ resized_images = [
280
+ F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
281
+ for im in im_list
282
+ ]
283
+
284
+ combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
285
+ combined_images = combined_images.to(self.device, dtype=self.dtype)
286
+
287
+ combined_features = self.encoder(combined_images)
288
+
289
+ full_img_features = combined_features[: len(im_list)]
290
+ patch_features = (
291
+ combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
292
+ )
293
+
294
+ # Reshape patch features back to their original structure
295
+ reshaped_patch_features = []
296
+ patch_idx = 0
297
+ for i, patch_set in enumerate(patches):
298
+ if len(patch_set) == 0:
299
+ reshaped_patch_features.append(
300
+ full_img_features[i].transpose(0, 1).view(1152, 27, 27)
301
+ )
302
+ else:
303
+ sample_features = []
304
+ for row_patches in patch_set:
305
+ row_len = len(row_patches)
306
+ row_features = patch_features[
307
+ patch_idx : patch_idx + row_len
308
+ ] # row_len, T, C
309
+ row_features = torch.cat(
310
+ list(row_features), dim=2
311
+ ) # T, C * row_len
312
+ patch_idx += row_len
313
+ sample_features.append(row_features)
314
+ sample_features = torch.cat(sample_features, dim=1)
315
+ sample_features = F.interpolate(
316
+ sample_features.unsqueeze(0), size=(27, 27), mode="bilinear"
317
+ ).squeeze(0)
318
+ reshaped_patch_features.append(sample_features)
319
+ reshaped_patch_features = (
320
+ torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
321
+ )
322
+
323
+ final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
324
+
325
+ return self.projection(final_features)