cyq1998 committed
Commit 579b496
1 Parent(s): db51e6b

upload 8 files

config.json ADDED
@@ -0,0 +1,46 @@
{
  "_name_or_path": "vit-phi-1.5",
  "activation_function": "gelu_new",
  "architecture": {
    "block_cls": "parallel",
    "mixer": {},
    "mlp": {
      "mlp_cls": "mlp"
    }
  },
  "architectures": [
    "MixFormerVLSequentialForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_mixformer_sequential.MixFormerVLSequentialConfig",
    "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerVLSequentialForCausalLM"
  },
  "embd_layer": "default",
  "embd_pdrop": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mixformer-sequential",
  "n_embd": 2048,
  "n_head": 32,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 2048,
  "phyagi_version": "0.0.4.dev",
  "resid_pdrop": 0.0,
  "rotary_dim": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.32.1",
  "tokenizer_type": "VitPhiTokenizer",
  "visual": {
    "heads": 16,
    "image_size": 448,
    "image_start_id": 50470,
    "layers": 48,
    "mlp_ratio": 4.9231,
    "output_dim": 4096,
    "patch_size": 14,
    "width": 1664
  },
  "vocab_size": 51200
}
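Because config.json routes AutoConfig and AutoModelForCausalLM to the custom classes through "auto_map", the checkpoint is meant to be loaded with remote code enabled. Below is a minimal loading sketch, not part of this commit: the local path is a placeholder, and it assumes the auto_map modules resolve to the configuration and modeling files uploaded here.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo = "./vit-phi-1.5"  # placeholder path to a local checkout of this repository
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    config=config,
    torch_dtype=torch.float16,  # matches "torch_dtype" in config.json
    trust_remote_code=True,
)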
configuration_vitphi.py ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Any, Dict, List, Optional, Union

from transformers import PretrainedConfig


class MixFormerVLSequentialConfig(PretrainedConfig):
    """MixFormer (sequential for DeepSpeed) configuration."""

    model_type = "mixformer-sequential"

    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
        "input_emb_layer": "embd_layer",  # `input_emb_layer` key is for backward compatibility
        "blocks": "architecture",  # `blocks` key is for backward compatibility
    }

    def __init__(
        self,
        vocab_size: Optional[int] = 50304,
        n_positions: Optional[int] = 2048,
        n_embd: Optional[int] = 1024,
        n_layer: Optional[int] = 20,
        n_inner: Optional[int] = None,
        n_head: Optional[int] = 16,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        embd_layer: Optional[str] = "default",
        architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
        embd_pdrop: Optional[float] = 0.0,
        resid_pdrop: Optional[float] = 0.0,
        layer_norm_epsilon: Optional[float] = 1e-5,
        initializer_range: Optional[float] = 0.02,
        tie_word_embeddings: Optional[bool] = False,
        pad_vocab_size_multiple: Optional[int] = 64,

        # vit_hidden_size: Optional[int] = 4096,

        **kwargs
    ) -> None:
        self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.rotary_dim = min(rotary_dim, n_embd // n_head)
        self.activation_function = activation_function
        self.embd_layer = embd_layer
        self.architecture = architecture
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        # self.vit_hidden_size = vit_hidden_size

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
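The constructor rounds vocab_size up to the next multiple of pad_vocab_size_multiple and clamps rotary_dim to the per-head dimension. A small sanity check of that arithmetic is sketched below; the argument values are illustrative, and the import assumes the file is importable from the working directory.

from configuration_vitphi import MixFormerVLSequentialConfig

cfg = MixFormerVLSequentialConfig(vocab_size=50257, n_embd=2048, n_head=32, rotary_dim=128)
assert cfg.vocab_size == 50304  # ceil(50257 / 64) * 64
assert cfg.rotary_dim == 64     # min(128, 2048 // 32)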
generation_config.json ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "eos_token_id": 50256,
  "transformers_version": "4.32.1"
}
modeling_vitphi.py ADDED
@@ -0,0 +1,796 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import annotations

import math
import copy
from typing import Any, Dict, Optional, Tuple
from dataclasses import dataclass, field

import torch
import torch.nn as nn

from einops import rearrange
from transformers.activations import ACT2FN
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_vitphi import MixFormerVLSequentialConfig
from .visual import VisionTransformer


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[torch.Tensor] = None


class Embedding(nn.Module):
    """Token embedding with dropout."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.wte(input_ids)
        hidden_states = self.drop(hidden_states)

        return hidden_states


class RotaryEmbedding(nn.Module):
    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    def __init__(
        self,
        dim: int,
        base: Optional[int] = 10000,
        scale_base: Optional[float] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if scale_base is not None:
            raise NotImplementedError

        # Generate and save the inverse frequency buffer (non-trainable)
        self.dim = dim
        self.base = base
        self.scale_base = scale_base
        self.device = device

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        seqlen = x.shape[1] + seqlen_offset

        # Re-generate the inverse frequency buffer if it's not fp32
        # (for instance if model.half() was called)
        if self.inv_freq.dtype != torch.float32:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
            )

        if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=torch.float32)

            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def apply_rotary_emb_qkv(
        self,
        qkv: torch.FloatTensor,
        sin: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin_k: Optional[torch.FloatTensor] = None,
        cos_k: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        _, seqlen, three, _, headdim = qkv.shape
        assert three == 3

        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen

        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)

        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]

        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]

        # Splits the queries and keys in half
        q1, q2 = q_rot.chunk(2, dim=-1)
        k1, k2 = k_rot.chunk(2, dim=-1)
        c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")

        # Casts to fp32 are necessary to prevent fp16 overflow issues
        q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]

        # Computes the new keys and queries, recasting to original dtype
        q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)

        k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)

        return torch.cat(
            [
                torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
                torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
                qkv[:, :, 2:3, :, :],
            ],
            axis=2,
        )

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform the forward pass.
        Args:
            qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
            seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
        Returns:
            New `qkv` and the cached sinusoids.
        """

        self._update_cos_sin_cache(qkv, seqlen_offset)

        return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
    Adapted from https://github.com/Dao-AILab/flash-attention."""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size, inference_params.max_sequence_len, 2,
            num_heads, head_dim, dtype=kv.dtype, device=kv.device
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]

    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]

    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
    return kv


class MLP(nn.Module):
    """Multi-Layer Perceptron.
    Reference:
        Attention Is All You Need.
        https://arxiv.org/pdf/1706.03762.pdf.
    """

    def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
        super().__init__()

        act_fn = config.activation_function if act_fn is None else act_fn
        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."

        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
        n_inner = n_inner if n_inner is not None else 4 * config.n_embd

        self.fc1 = nn.Linear(config.n_embd, n_inner)
        self.fc2 = nn.Linear(n_inner, config.n_embd)
        self.act = ACT2FN[act_fn]

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
        new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]

        if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
            # Older version of `MLP` saved with different key names.
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                             error_msgs)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)

        return hidden_states


class FusedMLP(nn.Module):
    """Fused Multi-Layer Perceptron from `flash-attn`.
    Reference:
        https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
    """

    def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
                 raise_on_missing: bool = False) -> None:
        super().__init__()

        act_fn = config.activation_function if act_fn is None else act_fn
        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."

        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
        n_inner = n_inner if n_inner is not None else 4 * config.n_embd

        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
        activation = "gelu_approx" if act_fn in gelu_activations else "relu"

        self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        return self.mlp(hidden_states)


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Adapted from https://github.com/Dao-AILab/flash-attention.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Adapted from https://github.com/Dao-AILab/flash-attention.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
                                                device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
        return output


def find_mha_dims(
    config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
) -> Tuple[int, int]:
    """Validate and return the number of heads and head dimension for multi-head attention.
    Args:
        config: Model configuration.
        n_head: Number of heads.
        head_dim: Head dimension.
    Returns:
        Number of heads and head dimension.
    """

    assert all(
        hasattr(config, attr) for attr in ["n_embd", "n_head"]
    ), "`config` must have `n_embd` and `n_head` attributes."

    if head_dim is None:
        assert (
            config.n_embd % config.n_head == 0
        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."

    if n_head is None and head_dim is None:
        head_dim = config.n_embd // config.n_head
        n_head = config.n_head
    elif n_head is None or head_dim is None:
        raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")

    return n_head, head_dim


class MHA(nn.Module):
    """Multi-head attention layer.
    Adapted from https://github.com/Dao-AILab/flash-attention."""

    def __init__(
        self,
        config: PretrainedConfig,
        rotary_dim: Optional[int] = None,
        n_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: Optional[bool] = True,
        dropout: Optional[float] = 0.0,
        softmax_scale: Optional[float] = None,
        causal: Optional[bool] = True,
        layer_idx: Optional[int] = None,
        rotary_emb_scale_base: Optional[float] = None,
        return_residual: Optional[bool] = False,
        checkpointing: Optional[bool] = False,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        fused_dense: Optional[bool] = True,
        flash_attn: Optional[bool] = True,
        cutlass_attn: Optional[bool] = False,
        flash_rotary: Optional[bool] = True,
        raise_on_missing: Optional[bool] = False
    ) -> None:
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}
        n_head, head_dim = find_mha_dims(config, n_head, head_dim)

        self.hidden_size = config.n_embd
        self.n_head = n_head
        self.head_dim = head_dim
        self.op_size = n_head * head_dim

        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
        self.fused_dense = fused_dense
        self.flash_attn = flash_attn
        self.cutlass_attn = cutlass_attn
        self.flash_rotary = flash_rotary
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        if self.rotary_emb_dim > 0:
            rotary_kwargs = {"device": device}
            if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                rotary_kwargs["scale_base"] = rotary_emb_scale_base

            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
        else:
            pass

        self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)

        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)

    def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        Adapted from https://github.com/Dao-AILab/flash-attention."""

        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"

        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(
        self,
        x: torch.FloatTensor,
        x_kv: Optional[torch.FloatTensor] = None,
        key_padding_mask: Optional[torch.BoolTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        mixer_subset: Optional[torch.LongTensor] = None,
        past_cache: Optional[InferenceParams] = None,
        **kwargs
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Perform the forward pass.
        Args:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            past_cache: For generation only.
        Returns:
            (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
            else (total, hidden_dim) where total is the sum of the sequence lengths
            in the batch.
        """

        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.flash_attn
            assert self.rotary_emb_dim == 0

        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.flash_attn

        if past_cache is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None

        attn_kwargs = {"key_padding_mask": key_padding_mask}

        assert x_kv is None and mixer_subset is None

        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        if past_cache is None:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv)
            context = self.inner_attn(qkv, **attn_kwargs)

        else:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
            q = qkv[:, :, 0]
            kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if past_cache.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        out = rearrange(context, "... h d -> ... (h d)")
        out = self.out_proj(out)

        return out if not self.return_residual else (out, x)


class ParallelBlock(nn.Module):
    """Parallel block.
    This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        mixer: Optional[Dict[str, Any]] = None,
        mlp: Optional[Dict[str, Any]] = None,
        block_idx: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.block_idx = block_idx

        self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
        mlp_cls = mlp.pop('mlp_cls')
        if mlp_cls == 'fused_mlp':
            self.mlp = FusedMLP(config=config, **mlp)
        else:
            self.mlp = MLP(config=config, **mlp)

    def forward(self, hidden_states: torch.FloatTensor,
                past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        residual = hidden_states
        hidden_states = self.ln(hidden_states)

        attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
        if isinstance(attn_outputs, tuple):
            attn_outputs = attn_outputs[0]

        attn_outputs = self.resid_dropout(attn_outputs)
        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))

        hidden_states = attn_outputs + feed_forward_hidden_states + residual

        return hidden_states


class CausalLMHead(nn.Module):
    """Causal Language Modeling head.
    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.ln(hidden_states)
        logits = self.linear(hidden_states).to(torch.float32)

        return logits


class CausalLMLoss(nn.Module):
    """Causal Language Modeling loss.
    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
    """

    def __init__(self, shift_labels: Optional[bool] = True) -> None:
        super().__init__()

        self.shift_labels = shift_labels
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
        if self.shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss


class MixFormerSequentialPreTrainedModel(PreTrainedModel):
    """MixFormer (sequential for DeepSpeed) pre-trained model."""

    config_class = MixFormerVLSequentialConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs) -> None:
        super().__init__(*inputs, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
        if "use_cache" in kwargs and not kwargs["use_cache"]:
            return {"input_ids": input_ids}

        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
            past_key_values = InferenceParams(
                max_batch_size=input_ids.shape[0],
                max_sequence_len=self.config.n_positions,
                sequence_len_offset=0,
                batch_size_offset=0,
                fused_ft_kernel=False,
                key_value_memory_dict={},
            )
        else:
            # Assume `past_key_values` has cached all but the last token in `input_ids`.
            past_key_values.sequence_len_offset = len(input_ids[0]) - 1
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}


class MixFormerVLSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
    """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
    _no_split_modules = ["ParallelBlock"]

    def __init__(self, config: MixFormerVLSequentialConfig) -> None:
        super().__init__(config)

        modules = [Embedding(config)]
        block_config = config.architecture

        if not isinstance(block_config, list):
            block_config = [block_config for _ in range(config.n_layer)]

        if config.n_layer != len(block_config):
            config.n_layer = len(block_config)

        for block_idx, block in enumerate(block_config):
            # `block_cls` with `legacy` value is for backward compatibility
            # `path` key is for backward compatibility
            block = copy.deepcopy(block) or {"block_cls": "parallel"}
            block_cls = block.pop("path", None) or block.pop("block_cls", None)

            block["block_idx"] = block_idx
            modules.append(ParallelBlock(config, **block))

        modules.append(CausalLMHead(config))

        self.layers = nn.Sequential(*modules)
        self.loss = CausalLMLoss()
        self.visual = VisionTransformer(**config.visual)
        self.switcher = nn.Linear(config.visual.output_dim, config.n_embd, bias=False)

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.layers[0].wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.layers[0].wte = new_embeddings

    def get_output_embeddings(self) -> nn.Linear:
        return self.layers[-1].linear

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.layers[-1].linear = new_embeddings

    def forward(
        self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[torch.FloatTensor] = None, **kwargs
    ) -> CausalLMOutputWithPast:
        if past_key_values is None and input_ids is not None \
                and torch.any(input_ids == self.config.visual['image_start_id']):
            bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
            eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)  # image_end_id = image_start_id + 1
            assert (bos_pos[0] == eos_pos[0]).all()  # every sample in the batch must carry both the image start and end tokens
            img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
            images = []
            for i, a, b in img_pos:
                image = input_ids[i][a + 1: b - 1].tolist()
                image = image[: image.index(self.config.visual['image_start_id'] + 2)]  # image_pad_id = image_start_id + 2
                images.append(bytes(image).decode('utf-8'))

            images = self.visual.encode(images)
            assert images.shape[0] == len(images)
        else:
            images = None

        hidden_states = self.layers[0](input_ids)
        if images is not None:
            for idx, (i, a, b) in enumerate(img_pos):
                hidden_states[i][a + 1: b] = self.switcher(images[idx])
        if not past_key_values:
            for module in self.layers[1:-1]:
                hidden_states = module(hidden_states)
        else:
            for module in self.layers[1:-1]:
                hidden_states = module(hidden_states, past_cache=past_key_values)
        lm_logits = self.layers[-1](hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss(lm_logits, labels)

        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
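The forward pass above expects each image reference to be embedded directly in input_ids: a span bracketed by image_start_id and image_start_id + 1 that holds the raw UTF-8 byte values of the image path or URL, padded out with image_start_id + 2, and forward() recovers the string with a plain bytes(...).decode call. The sketch below packs and unpacks such a span the same way forward() slices it; the ids mirror config.json and the path is made up.

IMAGE_START_ID = 50470                # "image_start_id" in config.json
IMAGE_END_ID = IMAGE_START_ID + 1     # image_end_id = image_start_id + 1
IMAGE_PAD_ID = IMAGE_START_ID + 2     # image_pad_id = image_start_id + 2
IMG_TOKEN_SPAN = 256                  # fixed width of the image placeholder span

def pack_image_span(path):
    # Store each UTF-8 byte value as a token id, which is what forward() assumes
    # when it calls bytes(image).decode("utf-8") on the span contents.
    payload = list(path.encode("utf-8"))
    payload += [IMAGE_PAD_ID] * (IMG_TOKEN_SPAN - len(payload))
    return [IMAGE_START_ID] + payload + [IMAGE_END_ID]

def unpack_image_span(ids):
    # Mirrors the slicing in forward(): drop the markers, cut at the first pad id.
    a, b = ids.index(IMAGE_START_ID), ids.index(IMAGE_END_ID)
    payload = ids[a + 1: b - 1]
    return bytes(payload[: payload.index(IMAGE_PAD_ID)]).decode("utf-8")

span = pack_image_span("demo.jpg")
assert len(span) == IMG_TOKEN_SPAN + 2
assert unpack_image_span(span) == "demo.jpg"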
tokenization_vitphi.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import requests
12
+ import unicodedata
13
+ from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
14
+
15
+ import tiktoken
16
+ import numpy as np
17
+ from PIL import Image
18
+ from PIL import ImageFont
19
+ from PIL import ImageDraw
20
+ from transformers import PreTrainedTokenizer, AddedToken
21
+ from transformers.utils import try_to_load_from_cache
22
+
23
+ import matplotlib.colors as mcolors
24
+ from matplotlib.font_manager import FontProperties
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.tiktoken", "ttf": "SimSun.ttf"}
29
+ # FONT_PATH = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf")
30
+ FONT_PATH = None
31
+ if FONT_PATH is None:
32
+ if not os.path.exists("SimSun.ttf"):
33
+ ttf = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/SimSun.ttf")
34
+ open("SimSun.ttf", "wb").write(ttf.content)
35
+ FONT_PATH = "SimSun.ttf"
36
+
37
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
38
+ ENDOFTEXT = "<|endoftext|>"
39
+ # <|endoftext|> 50256
40
+ IMSTART = "<|im_start|>"
41
+ IMEND = "<|im_end|>"
42
+ # as the default behavior is changed to allow special tokens in
43
+ # regular texts, the surface forms of special tokens need to be
44
+ # as different as possible to minimize the impact
45
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
46
+ SPECIAL_TOKENS = (
47
+ # ENDOFTEXT,
48
+ IMSTART,
49
+ IMEND,
50
+ ) + EXTRAS
51
+ IMG_TOKEN_SPAN = 256
52
+
53
+
54
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
55
+ with open(tiktoken_bpe_file, "rb") as f:
56
+ contents = f.read()
57
+ return {
58
+ base64.b64decode(token): int(rank)
59
+ for token, rank in (line.split() for line in contents.splitlines() if line)
60
+ }
61
+
62
+
63
+ def _list_find(
64
+ input_list: List[Any],
65
+ candidates: Tuple[Any],
66
+ start: int = 0,
67
+ ):
68
+ for i in range(start, len(input_list)):
69
+ if input_list[i] in candidates:
70
+ return i
71
+ return -1
72
+
73
+
74
+ def _replace_closed_tag(
75
+ input_tokens: List[Any],
76
+ start_tags: Union[Any, Tuple[Any]],
77
+ end_tags: Union[Any, Tuple[Any]],
78
+ inclusive_replace_func: Callable,
79
+ exclusive_replace_func: Callable = lambda x: x,
80
+ ):
81
+ if isinstance(start_tags, (str, int)):
82
+ start_tags = (start_tags,)
83
+ if isinstance(end_tags, (str, int)):
84
+ end_tags = (end_tags,)
85
+ assert len(start_tags) == len(end_tags)
86
+
87
+ output_tokens = []
88
+ end = 0
89
+ while True:
90
+ start = _list_find(input_tokens, start_tags, end)
91
+ if start == -1:
92
+ break
93
+ output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
94
+ tag_idx = start_tags.index(input_tokens[start])
95
+ end = _list_find(input_tokens, (end_tags[tag_idx],), start)
96
+ if end == -1:
97
+ raise ValueError("Unclosed image token")
98
+ output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1]))
99
+ end += 1
100
+ output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
101
+ return output_tokens
102
+
103
+
104
+ class VitPhiTokenizer(PreTrainedTokenizer):
105
+ """VitPhi tokenizer."""
106
+
107
+ vocab_files_names = VOCAB_FILES_NAMES
108
+
109
+ def __init__(
110
+ self,
111
+ vocab_file,
112
+ errors="replace",
113
+ image_start_tag='<img>',
114
+ image_end_tag='</img>',
115
+ image_pad_tag='<imgpad>',
116
+ ref_start_tag='<ref>',
117
+ ref_end_tag='</ref>',
118
+ box_start_tag='<box>',
119
+ box_end_tag='</box>',
120
+ quad_start_tag='<quad>',
121
+ quad_end_tag='</quad>',
122
+ **kwargs,
123
+ ):
124
+ super().__init__(**kwargs)
125
+ self.image_start_tag = image_start_tag
126
+ self.image_end_tag = image_end_tag
127
+ self.image_pad_tag = image_pad_tag
128
+ self.ref_start_tag = ref_start_tag
129
+ self.ref_end_tag = ref_end_tag
130
+ self.box_start_tag = box_start_tag
131
+ self.box_end_tag = box_end_tag
132
+ self.quad_start_tag = quad_start_tag
133
+ self.quad_end_tag = quad_end_tag
134
+ self.IMAGE_ST = (
135
+ ref_start_tag, ref_end_tag,
136
+ box_start_tag, box_end_tag,
137
+ quad_start_tag, quad_end_tag,
138
+ image_start_tag, image_end_tag,
139
+ image_pad_tag
140
+ )
141
+
142
+ self.errors = errors # how to handle errors in decoding
143
+
144
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
145
+ self.special_tokens = {
146
+ token: index
147
+ for index, token in enumerate(
148
+ SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
149
+ )
150
+ }
151
+ self.img_start_id = self.special_tokens[self.image_start_tag]
152
+ self.img_end_id = self.special_tokens[self.image_end_tag]
153
+ self.img_pad_id = self.special_tokens[self.image_pad_tag]
154
+ self.ref_start_id = self.special_tokens[self.ref_start_tag]
155
+ self.ref_end_id = self.special_tokens[self.ref_end_tag]
156
+ self.box_start_id = self.special_tokens[self.box_start_tag]
157
+ self.box_end_id = self.special_tokens[self.box_end_tag]
158
+ self.quad_start_id = self.special_tokens[self.quad_start_tag]
159
+ self.quad_end_id = self.special_tokens[self.quad_end_tag]
160
+
161
+ enc = tiktoken.Encoding(
162
+ "VitPhi",
163
+ pat_str=PAT_STR,
164
+ mergeable_ranks=self.mergeable_ranks,
165
+ special_tokens=self.special_tokens,
166
+ )
167
+ assert (
168
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
169
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
170
+
171
+ self.decoder = {
172
+ v: k for k, v in self.mergeable_ranks.items()
173
+ } # type: dict[int, bytes|str]
174
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
175
+
176
+ self.tokenizer = enc # type: tiktoken.Encoding
177
+
178
+ self.eod_id = self.tokenizer.eot_token
179
+ self.im_start_id = self.special_tokens[IMSTART]
180
+ self.im_end_id = self.special_tokens[IMEND]
181
+
182
+ def __len__(self) -> int:
183
+ return self.tokenizer.n_vocab
184
+
185
+ def get_vocab(self) -> Dict[bytes, int]:
186
+ return self.mergeable_ranks
187
+
188
+ def convert_tokens_to_ids(
189
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
190
+ ) -> List[int]:
191
+ ids = []
192
+ if isinstance(tokens, (str, bytes)):
193
+ if tokens in self.special_tokens:
194
+ return self.special_tokens[tokens]
195
+ else:
196
+ return self.mergeable_ranks.get(tokens)
197
+ for token in tokens:
198
+ if token in self.special_tokens:
199
+ ids.append(self.special_tokens[token])
200
+ else:
201
+ ids.append(self.mergeable_ranks.get(token))
202
+ return ids
203
+
204
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
205
+ if not special_tokens and new_tokens:
206
+ raise ValueError('Adding regular tokens is not supported')
207
+ for token in new_tokens:
208
+ surface_form = token.content if isinstance(token, AddedToken) else token
209
+ if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
210
+ raise ValueError('Adding unknown special tokens is not supported')
211
+ return 0
212
+
213
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
214
+ """
215
+ Save only the vocabulary of the tokenizer (vocabulary).
216
+
217
+ Returns:
218
+ `Tuple(str)`: Paths to the files saved.
219
+ """
220
+ file_path = os.path.join(save_directory, "vocab.tiktoken")
221
+ with open(file_path, "w", encoding="utf8") as w:
222
+ for k, v in self.mergeable_ranks.items():
223
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
224
+ w.write(line)
225
+ return (file_path,)
226
+
227
+ def tokenize(
228
+ self,
229
+ text: str,
230
+ allowed_special: Union[Set, str] = "all",
231
+ disallowed_special: Union[Collection, str] = (),
232
+ **kwargs,
233
+ ) -> List[Union[bytes, str]]:
234
+ """
235
+ Converts a string in a sequence of tokens.
236
+
237
+ Args:
238
+ text (`str`):
239
+ The sequence to be encoded.
240
+ allowed_special (`Literal["all"]` or `set`):
241
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
242
+ Default to "all".
243
+ disallowed_special (`Literal["all"]` or `Collection`):
244
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
245
+ Default to an empty tuple.
246
+
247
+ kwargs (additional keyword arguments, *optional*):
248
+ Will be passed to the underlying model specific encode method.
249
+
250
+ Returns:
251
+ `List[bytes|str]`: The list of tokens.
252
+ """
253
+ tokens = []
254
+ text = unicodedata.normalize("NFC", text)
255
+
256
+ # this implementation takes a detour: text -> token id -> token surface forms
257
+ for t in self.tokenizer.encode(
258
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
259
+ ):
260
+ tokens.append(self.decoder[t])
261
+
262
+ def _encode_imgurl(img_tokens):
263
+ assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
264
+ img_tokens = img_tokens[1:-1]
265
+ img_url = b''.join(img_tokens)
266
+ out_img_tokens = list(map(self.decoder.get, img_url))
267
+ if len(out_img_tokens) > IMG_TOKEN_SPAN:
268
+ raise ValueError("The content in {}..{} is too long".format(
269
+ self.image_start_tag, self.image_end_tag))
270
+ out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
271
+ out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
272
+ return out_img_tokens
273
+
274
+ return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
275
+
276
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
277
+ """
278
+ Converts a sequence of tokens in a single string.
279
+ """
280
+ text = ""
281
+ temp = b""
282
+ for t in tokens:
283
+ if isinstance(t, str):
284
+ if temp:
285
+ text += temp.decode("utf-8", errors=self.errors)
286
+ temp = b""
287
+ text += t
288
+ elif isinstance(t, bytes):
289
+ temp += t
290
+ else:
291
+ raise TypeError("token should only be of type types or str")
292
+ if temp:
293
+ text += temp.decode("utf-8", errors=self.errors)
294
+ return text
295
+
296
+ @property
297
+ def vocab_size(self):
298
+ return self.tokenizer.n_vocab
299
+
300
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
301
+ """Converts an id to a token, special tokens included"""
302
+ if index in self.decoder:
303
+ return self.decoder[index]
304
+ raise ValueError("unknown ids")
305
+
306
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
307
+ """Converts a token to an id using the vocab, special tokens included"""
308
+ if token in self.special_tokens:
309
+ return self.special_tokens[token]
310
+ if token in self.mergeable_ranks:
311
+ return self.mergeable_ranks[token]
312
+ raise ValueError("unknown token")
313
+
314
+ def _tokenize(self, text: str, **kwargs):
315
+ """
316
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
317
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
318
+
319
+ Do NOT take care of added tokens.
320
+ """
321
+ raise NotImplementedError
322
+
323
+ def _decode(
324
+ self,
325
+ token_ids: Union[int, List[int]],
326
+ skip_special_tokens: bool = False,
327
+ errors: str = None,
328
+ **kwargs,
329
+ ) -> str:
330
+ if isinstance(token_ids, int):
331
+ token_ids = [token_ids]
332
+
333
+ def _decode_imgurl(img_token_ids):
334
+ assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
335
+ img_token_ids = img_token_ids[1:-1]
336
+ img_token_ids = img_token_ids[: img_token_ids.index(self.img_pad_id)]
337
+ img_url = bytes(img_token_ids).decode('utf-8')
338
+ return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]
339
+
340
+ token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
341
+
342
+ if skip_special_tokens:
343
+ token_ids = [i for i in token_ids if i < self.eod_id]
344
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
345
+
346
+ def to_list_format(self, text: str):
347
+ text = unicodedata.normalize("NFC", text)
348
+ token_ids = self.tokenizer.encode(
349
+ text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))
350
+
351
+ def _encode_vl_info(tokens):
352
+ if len(tokens) == 0:
353
+ return []
354
+ if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
355
+ key = 'image'
356
+ elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
357
+ key = 'ref'
358
+ elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
359
+ key = 'box'
360
+ elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
361
+ key = 'quad'
362
+ else:
363
+ _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
364
+ return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
365
+ _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
366
+ val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
367
+ return [{key: val}]
368
+
369
+ return _replace_closed_tag(
370
+ token_ids,
371
+ (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
372
+ (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
373
+ _encode_vl_info,
374
+ _encode_vl_info,
375
+ )
376
+
377
+ def from_list_format(self, list_format: List[Dict]):
378
+ text = ''
379
+ num_images = 0
380
+ for ele in list_format:
381
+ if 'image' in ele:
382
+ num_images += 1
383
+ text += f'Picture {num_images}:'
384
+ text += self.image_start_tag + ele['image'] + self.image_end_tag
385
+ text += '\n'
386
+ elif 'text' in ele:
387
+ text += ele['text']
388
+ elif 'box' in ele:
389
+ if 'ref' in ele:
390
+ text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
391
+ for box in ele['box']:
392
+ text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
393
+ else:
394
+ raise ValueError("Unsupport element: " + str(ele))
395
+ return text
396
+
397
+ def _fetch_latest_picture(self, response, history):
398
+ if history is None:
399
+ history = []
400
+ _history = history + [(response, None)]
401
+ for q, r in _history[::-1]:
402
+ for ele in self.to_list_format(q)[::-1]:
403
+ if 'image' in ele:
404
+ return ele['image']
405
+ return None
406
+
407
+ def _fetch_all_box_with_ref(self, text):
408
+ list_format = self.to_list_format(text)
409
+ output = []
410
+ for i, ele in enumerate(list_format):
411
+ if 'box' in ele:
412
+ bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
413
+ assert len(bbox) == 4
414
+ output.append({'box': bbox})
415
+ if i > 0 and 'ref' in list_format[i - 1]:
416
+ output[-1]['ref'] = list_format[i - 1]['ref'].strip()
417
+ return output
418
+
419
+ def draw_bbox_on_latest_picture(
420
+ self,
421
+ response,
422
+ history=None,
423
+ ) -> Optional[Image.Image]:
424
+ image = self._fetch_latest_picture(response, history)
425
+ if image is None:
426
+ return None
427
+ if image.startswith("http://") or image.startswith("https://"):
428
+ image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
429
+ h, w = image.height, image.width
430
+ else:
431
+ image = np.asarray(Image.open(image).convert("RGB"))
432
+ h, w = image.shape[0], image.shape[1]
433
+ visualizer = Visualizer(image)
434
+
435
+ boxes = self._fetch_all_box_with_ref(response)
436
+ if not boxes:
437
+ return None
438
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
439
+ for box in boxes:
440
+ if 'ref' in box: # random new color for new refexps
441
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
442
+ x1, y1, x2, y2 = box['box']
443
+ x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
444
+ visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
445
+ if 'ref' in box:
446
+ visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
447
+ return visualizer.output
448
+
449
+
450
+ import colorsys
451
+ import logging
452
+ import math
453
+ import numpy as np
454
+ import matplotlib as mpl
455
+ import matplotlib.colors as mplc
456
+ import matplotlib.figure as mplfigure
457
+ import torch
458
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
459
+ from PIL import Image
460
+ import random
461
+
462
+ logger = logging.getLogger(__name__)
463
+
464
+
465
+ class VisImage:
466
+ def __init__(self, img, scale=1.0):
467
+ self.img = img
468
+ self.scale = scale
469
+ self.width, self.height = img.shape[1], img.shape[0]
470
+ self._setup_figure(img)
471
+
472
+ def _setup_figure(self, img):
473
+ fig = mplfigure.Figure(frameon=False)
474
+ self.dpi = fig.get_dpi()
475
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
476
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
477
+ fig.set_size_inches(
478
+ (self.width * self.scale + 1e-2) / self.dpi,
479
+ (self.height * self.scale + 1e-2) / self.dpi,
480
+ )
481
+ self.canvas = FigureCanvasAgg(fig)
482
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
483
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
484
+ ax.axis("off")
485
+ self.fig = fig
486
+ self.ax = ax
487
+ self.reset_image(img)
488
+
489
+ def reset_image(self, img):
490
+ img = img.astype("uint8")
491
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
492
+
493
+ def save(self, filepath):
494
+ self.fig.savefig(filepath)
495
+
496
+ def get_image(self):
497
+ canvas = self.canvas
498
+ s, (width, height) = canvas.print_to_buffer()
499
+
500
+ buffer = np.frombuffer(s, dtype="uint8")
501
+
502
+ img_rgba = buffer.reshape(height, width, 4)
503
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
504
+ return rgb.astype("uint8")
505
+
506
+
507
+ class Visualizer:
508
+ def __init__(self, img_rgb, metadata=None, scale=1.0):
509
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
510
+ self.font_path = FONT_PATH
511
+ self.output = VisImage(self.img, scale=scale)
512
+ self.cpu_device = torch.device("cpu")
513
+
514
+ # very small text is unreadable, so clamp the default font size to at least 15 // scale
515
+ self._default_font_size = max(
516
+ np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
517
+ )
518
+
519
+ def draw_text(
520
+ self,
521
+ text,
522
+ position,
523
+ *,
524
+ font_size=None,
525
+ color="g",
526
+ horizontal_alignment="center",
527
+ rotation=0,
528
+ ):
529
+ if not font_size:
530
+ font_size = self._default_font_size
531
+
532
+ # since the text background is dark, we don't want the text to be dark
533
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
534
+ color[np.argmax(color)] = max(0.8, np.max(color))
535
+
536
+ x, y = position
537
+ self.output.ax.text(
538
+ x,
539
+ y,
540
+ text,
541
+ size=font_size * self.output.scale,
542
+ fontproperties=FontProperties(fname=self.font_path),
543
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
544
+ verticalalignment="top",
545
+ horizontalalignment=horizontal_alignment,
546
+ color=color,
547
+ zorder=10,
548
+ rotation=rotation,
549
+ )
550
+ return self.output
551
+
552
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
553
+ x0, y0, x1, y1 = box_coord
554
+ width = x1 - x0
555
+ height = y1 - y0
556
+
557
+ linewidth = max(self._default_font_size / 4, 1)
558
+
559
+ self.output.ax.add_patch(
560
+ mpl.patches.Rectangle(
561
+ (x0, y0),
562
+ width,
563
+ height,
564
+ fill=False,
565
+ edgecolor=edge_color,
566
+ linewidth=linewidth * self.output.scale,
567
+ alpha=alpha,
568
+ linestyle=line_style,
569
+ )
570
+ )
571
+ return self.output
572
+
573
+ def get_output(self):
574
+ return self.output
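
Usage sketch for the tagging and box-drawing helpers above — a minimal, hedged example, not anything shipped in this commit. It assumes a VitPhiTokenizer instance named tokenizer has already been loaded (see the loading sketch after tokenizer_config.json below), that to_list_format (defined earlier in the file) parses the same tags from_list_format emits, and that the model replies with Qwen-VL-style <ref>...</ref><box>(x1,y1),(x2,y2)</box> markup on the 0-1000 normalized grid the drawing code divides by; the image URL, box values, and file names are illustrative only.

# Build a prompt that interleaves one picture with a question.
query = tokenizer.from_list_format([
    {"image": "https://example.com/demo.jpg"},   # hypothetical image URL
    {"text": "Describe the picture."},
])

# Illustrative model output: one referring expression and one box on the
# 0-1000 normalized grid expected by draw_bbox_on_latest_picture.
response = "<ref>a dog</ref><box>(120,150),(600,890)</box>"

vis = tokenizer.draw_bbox_on_latest_picture(response, history=[(query, response)])
if vis is not None:           # a VisImage, or None if no image/box was found
    vis.save("annotated.jpg")
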
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "model_max_length": 2048,
3
+ "tokenizer_class": "VitPhiTokenizer",
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_vitphi.VitPhiTokenizer",
7
+ null
8
+ ]
9
+ }
10
+ }
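
Loading sketch for the config above (hedged; the directory name is a placeholder for a local checkout of this repo). The auto_map entry routes AutoTokenizer to the custom tokenization_vitphi.VitPhiTokenizer shipped alongside it, and the second slot is null, so only the slow tokenizer class is registered; model_max_length caps inputs at 2048 tokens.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/local_checkout", trust_remote_code=True)
print(type(tokenizer).__name__)     # expected: VitPhiTokenizer
print(tokenizer.model_max_length)   # 2048
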
visual.py ADDED
@@ -0,0 +1,428 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ from PIL import Image
12
+ from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+
23
+ def get_abs_pos(abs_pos, tgt_size):
24
+ # abs_pos: L, C
25
+ # tgt_size: M
26
+ # return: M, C
27
+ src_size = int(math.sqrt(abs_pos.size(0)))
28
+ tgt_size = int(math.sqrt(tgt_size))
29
+ dtype = abs_pos.dtype
30
+
31
+ if src_size != tgt_size:
32
+ return F.interpolate(
33
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
34
+ size=(tgt_size, tgt_size),
35
+ mode="bicubic",
36
+ align_corners=False,
37
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
38
+ else:
39
+ return abs_pos
40
+
41
+
42
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
43
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
44
+ """
45
+ grid_size: int of the grid height and width
46
+ return:
47
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
48
+ """
49
+ grid_h = np.arange(grid_size, dtype=np.float32)
50
+ grid_w = np.arange(grid_size, dtype=np.float32)
51
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
52
+ grid = np.stack(grid, axis=0)
53
+
54
+ grid = grid.reshape([2, 1, grid_size, grid_size])
55
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
56
+ if cls_token:
57
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
58
+ return pos_embed
59
+
60
+
61
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
62
+ assert embed_dim % 2 == 0
63
+
64
+ # use half of dimensions to encode grid_h
65
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
66
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
67
+
68
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
69
+ return emb
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position
75
+ pos: a list of positions to be encoded: size (M,)
76
+ out: (M, D)
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
80
+ omega /= embed_dim / 2.
81
+ omega = 1. / 10000 ** omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ class Resampler(nn.Module):
94
+ """
95
+ A 2D perceiver-resampler network with a single cross-attention layer driven by
96
+ (grid_size**2) learnable queries and a 2D sin-cos positional embedding.
97
+ Outputs:
98
+ A tensor with the shape of (grid_size**2, embed_dim)
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ grid_size,
104
+ embed_dim,
105
+ num_heads,
106
+ kv_dim=None,
107
+ norm_layer=nn.LayerNorm
108
+ ):
109
+ super().__init__()
110
+ self.num_queries = grid_size ** 2
111
+ self.embed_dim = embed_dim
112
+ self.num_heads = num_heads
113
+
114
+ self.pos_embed = nn.Parameter(
115
+ torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
116
+ ).requires_grad_(False)
117
+
118
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
119
+ trunc_normal_(self.query, std=.02)
120
+
121
+ if kv_dim is not None and kv_dim != embed_dim:
122
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
123
+ else:
124
+ self.kv_proj = nn.Identity()
125
+
126
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
127
+ self.ln_q = norm_layer(embed_dim)
128
+ self.ln_kv = norm_layer(embed_dim)
129
+
130
+ self.apply(self._init_weights)
131
+
132
+ def _init_weights(self, m):
133
+ if isinstance(m, nn.Linear):
134
+ trunc_normal_(m.weight, std=.02)
135
+ if isinstance(m, nn.Linear) and m.bias is not None:
136
+ nn.init.constant_(m.bias, 0)
137
+ elif isinstance(m, nn.LayerNorm):
138
+ nn.init.constant_(m.bias, 0)
139
+ nn.init.constant_(m.weight, 1.0)
140
+
141
+ def forward(self, x, attn_mask=None):
142
+
143
+ pos_embed = get_abs_pos(self.pos_embed, x.size(1))
144
+
145
+ x = self.kv_proj(x)
146
+ x = self.ln_kv(x).permute(1, 0, 2)
147
+
148
+ N = x.shape[1]
149
+ q = self.ln_q(self.query)
150
+ out = self.attn(
151
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
152
+ x + pos_embed.unsqueeze(1),
153
+ x,
154
+ attn_mask=attn_mask)[0]
155
+ return out.permute(1, 0, 2)
156
+
157
+ def _repeat(self, query, N: int):
158
+ return query.unsqueeze(1).repeat(1, N, 1)
159
+
160
+
161
+ class VisualAttention(nn.Module):
162
+ """self-attention layer class.
163
+
164
+ Self-attention layer takes input with size [s, b, h]
165
+ and returns output of the same size.
166
+ """
167
+
168
+ def __init__(self, embed_dim, num_heads,
169
+ bias=True, kdim=None, vdim=None):
170
+ super(VisualAttention, self).__init__()
171
+ self.embed_dim = embed_dim
172
+ self.kdim = kdim if kdim is not None else embed_dim
173
+ self.vdim = vdim if vdim is not None else embed_dim
174
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
175
+
176
+ self.num_heads = num_heads
177
+
178
+ # Per attention head and per partition values.
179
+ assert embed_dim % num_heads == 0
180
+ self.hidden_size_per_attention_head = embed_dim // num_heads
181
+ self.num_attention_heads_per_partition = num_heads
182
+ self.hidden_size_per_partition = embed_dim
183
+
184
+ # Strided linear layer.
185
+ assert self._qkv_same_embed_dim, 'Only self-attention is supported currently'
186
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
187
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
188
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
189
+
190
+ def forward(self, query, key, value, attn_mask=None):
191
+ # query/key/value: [sq, b, h]
192
+ sq, b, _ = query.size()
193
+
194
+ assert query is key, 'Only self-attention is supported currently'
195
+ sk = sq
196
+ mixed_x_layer = self.in_proj(query)
197
+
198
+ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
199
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
200
+ (self.num_attention_heads_per_partition,
201
+ 3 * self.hidden_size_per_attention_head)
202
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
203
+
204
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
205
+ query_layer, key_layer, value_layer = mixed_x_layer.split(
206
+ self.hidden_size_per_attention_head, dim=-1)
207
+
208
+ # [sq, b, np, hn] -> [sq, b * np, hn]
209
+ query_layer = query_layer.view(sq,
210
+ b * self.num_attention_heads_per_partition,
211
+ self.hidden_size_per_attention_head).transpose(0, 1)
212
+ # [sk, b, np, hn] -> [sk, b * np, hn]
213
+ key_layer = key_layer.view(sk,
214
+ b * self.num_attention_heads_per_partition,
215
+ self.hidden_size_per_attention_head).transpose(0, 1)
216
+
217
+ q_scaled = query_layer / self.norm_factor
218
+ if attn_mask is not None:
219
+ attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
220
+ else:
221
+ attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
222
+ attention_probs = attention_probs.softmax(dim=-1)
223
+
224
+ value_layer = value_layer.view(sk,
225
+ b * self.num_attention_heads_per_partition,
226
+ self.hidden_size_per_attention_head).transpose(0, 1)
227
+
228
+ # matmul: [b * np, sq, hn]
229
+ context_layer = torch.bmm(attention_probs, value_layer)
230
+
231
+ # change view [b, np, sq, hn]
232
+ context_layer = context_layer.view(b,
233
+ self.num_attention_heads_per_partition,
234
+ sq, self.hidden_size_per_attention_head)
235
+
236
+ # [b, np, sq, hn] --> [sq, b, np, hn]
237
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
238
+
239
+ # [sq, b, np, hn] --> [sq, b, hp]
240
+ new_context_layer_shape = context_layer.size()[:-2] + \
241
+ (self.hidden_size_per_partition,)
242
+ context_layer = context_layer.view(*new_context_layer_shape)
243
+
244
+ output = self.out_proj(context_layer)
245
+
246
+ return output
247
+
248
+
249
+ class VisualAttentionBlock(nn.Module):
250
+ def __init__(
251
+ self,
252
+ d_model: int,
253
+ n_head: int,
254
+ mlp_ratio: float = 4.0,
255
+ act_layer: Callable = nn.GELU,
256
+ norm_layer: Callable = nn.LayerNorm,
257
+ is_cross_attention: bool = False,
258
+ ):
259
+ super().__init__()
260
+
261
+ self.ln_1 = norm_layer(d_model)
262
+ if is_cross_attention:
263
+ self.ln_1_kv = norm_layer(d_model)
264
+
265
+ self.ln_2 = norm_layer(d_model)
266
+ mlp_width = int(d_model * mlp_ratio)
267
+ self.attn = VisualAttention(d_model, n_head)
268
+ self.mlp = nn.Sequential(OrderedDict([
269
+ ("c_fc", nn.Linear(d_model, mlp_width)),
270
+ ("gelu", act_layer()),
271
+ ("c_proj", nn.Linear(mlp_width, d_model))
272
+ ]))
273
+
274
+ def attention(
275
+ self,
276
+ q_x: torch.Tensor,
277
+ k_x: Optional[torch.Tensor] = None,
278
+ v_x: Optional[torch.Tensor] = None,
279
+ attn_mask: Optional[torch.Tensor] = None,
280
+ ):
281
+ k_x = k_x if k_x is not None else q_x
282
+ v_x = v_x if v_x is not None else q_x
283
+
284
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
285
+ return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
286
+
287
+ def forward(
288
+ self,
289
+ q_x: torch.Tensor,
290
+ k_x: Optional[torch.Tensor] = None,
291
+ v_x: Optional[torch.Tensor] = None,
292
+ attn_mask: Optional[torch.Tensor] = None,
293
+ ):
294
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
295
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
296
+
297
+ x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
298
+ x = x + self.mlp(self.ln_2(x))
299
+ return x
300
+
301
+
302
+ class TransformerBlock(nn.Module):
303
+ def __init__(
304
+ self,
305
+ width: int,
306
+ layers: int,
307
+ heads: int,
308
+ mlp_ratio: float = 4.0,
309
+ act_layer: Callable = nn.GELU,
310
+ norm_layer: Callable = nn.LayerNorm,
311
+ ):
312
+ super().__init__()
313
+ self.width = width
314
+ self.layers = layers
315
+
316
+ self.resblocks = nn.ModuleList([
317
+ VisualAttentionBlock(
318
+ width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
319
+ for _ in range(layers)
320
+ ])
321
+
322
+ def get_cast_dtype(self) -> torch.dtype:
323
+ return self.resblocks[0].mlp.c_fc.weight.dtype
324
+
325
+ def get_cast_device(self) -> torch.device:
326
+ return self.resblocks[0].mlp.c_fc.weight.device
327
+
328
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
329
+ for r in self.resblocks:
330
+ x = r(x, attn_mask=attn_mask)
331
+ return x
332
+
333
+
334
+ class VisionTransformer(nn.Module):
335
+
336
+ def __init__(
337
+ self,
338
+ image_size: int,
339
+ patch_size: int,
340
+ width: int,
341
+ layers: int,
342
+ heads: int,
343
+ mlp_ratio: float,
344
+ n_queries: int = 256,
345
+ output_dim: int = 512,
346
+ **kwargs
347
+ ):
348
+ super().__init__()
349
+ image_height, image_width = self.image_size = (image_size, image_size)
350
+ patch_height, patch_width = self.patch_size = (patch_size, patch_size)
351
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
352
+ self.output_dim = output_dim
353
+
354
+ mean = (0.48145466, 0.4578275, 0.40821073)
355
+ std = (0.26862954, 0.26130258, 0.27577711)
356
+ self.image_transform = transforms.Compose([
357
+ transforms.Resize(
358
+ (image_size, image_size),
359
+ interpolation=InterpolationMode.BICUBIC
360
+ ),
361
+ transforms.ToTensor(),
362
+ transforms.Normalize(mean=mean, std=std),
363
+ ])
364
+
365
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
366
+
367
+ # class embeddings and positional embeddings
368
+ scale = width ** -0.5
369
+ self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
370
+
371
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
372
+ act_layer = nn.GELU
373
+
374
+ self.ln_pre = norm_layer(width)
375
+ self.transformer = TransformerBlock(
376
+ width,
377
+ layers,
378
+ heads,
379
+ mlp_ratio,
380
+ act_layer=act_layer,
381
+ norm_layer=norm_layer,
382
+ )
383
+
384
+ self.attn_pool = Resampler(
385
+ grid_size=int(math.sqrt(n_queries)),
386
+ embed_dim=output_dim,
387
+ num_heads=output_dim // 128,
388
+ kv_dim=width,
389
+ norm_layer=norm_layer,
390
+ )
391
+ self.ln_post = norm_layer(output_dim)
392
+ self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))
393
+
394
+ def forward(self, x: torch.Tensor):
395
+ x = x.to(
396
+ dtype=self.transformer.get_cast_dtype(),
397
+ device=self.transformer.get_cast_device(),
398
+ )
399
+ # to patches
400
+ x = self.conv1(x) # shape = [*, width, grid, grid]
401
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
402
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
403
+
404
+ x = x + get_abs_pos(self.positional_embedding, x.size(1))
405
+
406
+ x = self.ln_pre(x)
407
+
408
+ x = x.permute(1, 0, 2) # NLD -> LND
409
+ x = self.transformer(x)
410
+ x = x.permute(1, 0, 2) # LND -> NLD
411
+
412
+ x = self.attn_pool(x)
413
+ x = self.ln_post(x)
414
+ x = x @ self.proj
415
+
416
+ return x
417
+
418
+ def encode(self, image_paths: List[str]):
419
+ images = []
420
+ for image_path in image_paths:
421
+ if image_path.startswith("http://") or image_path.startswith("https://"):
422
+ image = Image.open(requests.get(image_path, stream=True).raw)
423
+ else:
424
+ image = Image.open(image_path)
425
+ image = image.convert("RGB")
426
+ images.append(self.image_transform(image))
427
+ images = torch.stack(images, dim=0)
428
+ return self(images)
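
Standalone sketch of the vision tower defined above (hedged: it assumes the script runs from a checkout of this repo so visual.py is importable, and it uses deliberately small, illustrative hyperparameters rather than the values in the repo's config). encode() accepts local paths or http(s) URLs, applies the CLIP-style normalization set up in __init__, and returns a tensor of shape (batch, n_queries, output_dim).

import torch
from visual import VisionTransformer

# Toy-sized tower: 224 / 14 = 16, so the patch grid has 16 * 16 = 256 positions.
vit = VisionTransformer(
    image_size=224, patch_size=14, width=128, layers=2, heads=4,
    mlp_ratio=4.0, n_queries=256, output_dim=512,
)
vit.eval()

with torch.no_grad():
    feats = vit.encode(["/path/to/local_image.jpg"])  # placeholder path
print(feats.shape)  # torch.Size([1, 256, 512]) for these toy settings
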
vocab.tiktoken ADDED
The diff for this file is too large to render. See raw diff
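
A hedged reading sketch for vocab.tiktoken, assuming it follows the common tiktoken dump format of one base64-encoded token and its integer rank per line (as used by Qwen-style tokenizers); if the format differs, the tokenizer implementation in this repo is the authoritative reader.

import base64

def load_tiktoken_ranks(path):
    # Parse "base64(token_bytes) rank" pairs into the mergeable-ranks table
    # that a tiktoken Encoding is built from.
    ranks = {}
    with open(path, "rb") as f:
        for line in f:
            if not line.strip():
                continue
            token_b64, rank = line.split()
            ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks

mergeable_ranks = load_tiktoken_ranks("vocab.tiktoken")
print(len(mergeable_ranks))  # number of base BPE tokens
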