Files changed (10)
  1. README.md +1 -60
  2. README_old.md +0 -62
  3. config.json +4 -4
  4. configuration_imp.py +0 -175
  5. model.safetensors +0 -3
  6. modeling_imp.py +0 -1262
  7. pytorch_model.bin +0 -3
  8. tokenizer.json +0 -0
  9. vision_encoder.py +0 -593
  10. vocab.json +0 -0
README.md CHANGED
@@ -1,62 +1,3 @@
  ---
- license: creativeml-openrail-m
- language:
- - en
- metrics:
- - bleu
- tags:
- - endpoints
- - text-generation-inference
- inference: true
+ license: apache-2.0
  ---
-
- <h3 align='center' style='font-size: 24px;'>Blazzing Fast Tiny Vision Language Model</h3>
-
-
- <p align='center', style='font-size: 16px;' >A Custom 3B parameter Model. Built by <a href="https://www.linkedin.com/in/manishkumarthota/">@Manish</a> The model is released for research purposes only, commercial use is not allowed. </p>
-
- ## How to use
-
-
- **Install dependencies**
- ```bash
- pip install transformers # latest version is ok, but we recommend v4.31.0
- pip install -q pillow accelerate einops
- ```
-
- You can use the following code for model inference. The format of text instruction is similar to [LLaVA](https://github.com/haotian-liu/LLaVA).
-
- ```Python
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from PIL import Image
-
- torch.set_default_device("cuda")
-
- #Create model
- model = AutoModelForCausalLM.from_pretrained(
- "ManishThota/CustomModel",
- torch_dtype=torch.float16,
- device_map="auto",
- trust_remote_code=True)
- tokenizer = AutoTokenizer.from_pretrained("ManishThota/CustomModel", trust_remote_code=True)
-
- #function to generate the answer
- def predict(question, image_path):
- #Set inputs
- text = f"USER: <image>\n{question}? ASSISTANT:"
- image = Image.open(image_path)
-
- input_ids = tokenizer(text, return_tensors='pt').input_ids.to('cuda')
- image_tensor = model.image_preprocess(image)
-
- #Generate the answer
- output_ids = model.generate(
- input_ids,
- max_new_tokens=25,
- images=image_tensor,
- use_cache=True)[0]
-
- return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
-
- ```
README_old.md DELETED
@@ -1,62 +0,0 @@
1
- ---
2
- license: creativeml-openrail-m
3
- language:
4
- - en
5
- metrics:
6
- - bleu
7
- ---
8
- <h1 align='center' style='font-size: 36px; font-weight: bold;'>Sparrow</h1>
9
- <h3 align='center' style='font-size: 24px;'>Blazzing Fast Tiny Vision Language Model</h3>
10
-
11
-
12
- <p align="center">
13
- <img src="https://cdn-uploads.huggingface.co/production/uploads/650c7fbb8ffe1f53bdbe1aec/DTjDSq2yG-5Cqnk6giPFq.jpeg" width="50%" height="auto"/>
14
- </p>
15
-
16
- <p align='center', style='font-size: 16px;' >A Custom 3B parameter Model Enhanced for Educational Contexts: This specialized model integrates slide-text pairs from machine learning classes, leveraging a unique training approach. It connects a frozen pre-trained vision encoder (SigLip) with a frozen language model (Phi-2) through an innovative projector. The model employs attention mechanisms and language modeling loss to deeply understand and generate educational content, specifically tailored to the context of machine learning education. Built by <a href="https://www.linkedin.com/in/manishkumarthota/">@Manish</a> The model is released for research purposes only, commercial use is not allowed. </p>
17
-
18
- ## How to use
19
-
20
-
21
- **Install dependencies**
22
- ```bash
23
- pip install transformers # latest version is ok, but we recommend v4.31.0
24
- pip install -q pillow accelerate einops
25
- ```
26
-
27
- You can use the following code for model inference. The format of text instruction is similar to [LLaVA](https://github.com/haotian-liu/LLaVA).
28
-
29
- ```Python
30
- import torch
31
- from transformers import AutoModelForCausalLM, AutoTokenizer
32
- from PIL import Image
33
-
34
- torch.set_default_device("cuda")
35
-
36
- #Create model
37
- model = AutoModelForCausalLM.from_pretrained(
38
- "ManishThota/Sparrow",
39
- torch_dtype=torch.float16,
40
- device_map="auto",
41
- trust_remote_code=True)
42
- tokenizer = AutoTokenizer.from_pretrained("ManishThota/SparrowVQE", trust_remote_code=True)
43
-
44
- #function to generate the answer
45
- def predict(question, image_path):
46
- #Set inputs
47
- text = f"USER: <image>\n{question}? ASSISTANT:"
48
- image = Image.open(image_path)
49
-
50
- input_ids = tokenizer(text, return_tensors='pt').input_ids.to('cuda')
51
- image_tensor = model.image_preprocess(image)
52
-
53
- #Generate the answer
54
- output_ids = model.generate(
55
- input_ids,
56
- max_new_tokens=25,
57
- images=image_tensor,
58
- use_cache=True)[0]
59
-
60
- return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
61
-
62
- ```
config.json CHANGED
@@ -1,13 +1,13 @@
  {
- "_name_or_path": "ManishThota/Sparrow",
+ "_name_or_path": "MILVLG/imp-v1-3b",
  "activation_function": "gelu_new",
  "architectures": [
  "ImpForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "auto_map": {
- "AutoConfig": "configuration_imp.ImpConfig",
- "AutoModelForCausalLM": "modeling_imp.ImpForCausalLM"
+ "AutoConfig": "MILVLG/imp-v1-3b--configuration_imp.ImpConfig",
+ "AutoModelForCausalLM": "MILVLG/imp-v1-3b--modeling_imp.ImpForCausalLM"
  },
  "embd_pdrop": 0.0,
  "eos_token_id": 50295,
@@ -29,7 +29,7 @@
  "mm_vision_select_feature": "patch",
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "google/siglip-so400m-patch14-384",
- "model_type": "Sparrow",
+ "model_type": "imp",
  "n_embd": 2560,
  "n_head": 32,
  "n_head_kv": null,
configuration_imp.py DELETED
@@ -1,175 +0,0 @@
1
-
2
- # ------------------------------- Phi-2 ---------------------------------------------
3
- # Copyright (c) Microsoft Corporation.
4
- # Licensed under the MIT license.
5
- # https://huggingface.co/google/siglip-so400m-patch14-384
6
- #
7
- # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
8
- # Licensed under the BSD 3-Clause License.
9
- # ------------------------------- SigLIP --------------------------------------------
10
- # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
11
- #
12
- # Licensed under the Apache License, Version 2.0 (the "License");
13
- # you may not use this file except in compliance with the License.
14
- # You may obtain a copy of the License at
15
- #
16
- # http://www.apache.org/licenses/LICENSE-2.0
17
- #
18
- # Unless required by applicable law or agreed to in writing, software
19
- # distributed under the License is distributed on an "AS IS" BASIS,
20
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
- # See the License for the specific language governing permissions and
22
- # limitations under the License.
23
- # ------------------------------- Llava ---------------------------------------------
24
- # Copyright 2023 Haotian Liu
25
- #
26
- # Licensed under the Apache License, Version 2.0 (the "License");
27
- # you may not use this file except in compliance with the License.
28
- # You may obtain a copy of the License at
29
- #
30
- # http://www.apache.org/licenses/LICENSE-2.0
31
- #
32
- # Unless required by applicable law or agreed to in writing, software
33
- # distributed under the License is distributed on an "AS IS" BASIS,
34
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35
- # See the License for the specific language governing permissions and
36
- # limitations under the License.
37
- # -----------------------------------------------------------------------------------
38
-
39
-
40
- import os
41
- import math
42
- from typing import Optional, Union
43
-
44
- from transformers import PretrainedConfig
45
- from transformers.utils import logging
46
-
47
- logger = logging.get_logger(__name__)
48
-
49
-
50
- class PhiConfig(PretrainedConfig):
51
- """Phi configuration."""
52
-
53
- model_type = "phi-msft"
54
- attribute_map = {
55
- "max_position_embeddings": "n_positions",
56
- "hidden_size": "n_embd",
57
- "num_attention_heads": "n_head",
58
- "num_hidden_layers": "n_layer",
59
- }
60
-
61
- def __init__(
62
- self,
63
- vocab_size: int = 50304,
64
- n_positions: int = 2048,
65
- n_embd: int = 1024,
66
- n_layer: int = 20,
67
- n_inner: Optional[int] = None,
68
- n_head: int = 16,
69
- n_head_kv: Optional[int] = None,
70
- rotary_dim: Optional[int] = 32,
71
- activation_function: Optional[str] = "gelu_new",
72
- flash_attn: bool = False,
73
- flash_rotary: bool = False,
74
- fused_dense: bool = False,
75
- attn_pdrop: float = 0.0,
76
- embd_pdrop: float = 0.0,
77
- resid_pdrop: float = 0.0,
78
- layer_norm_epsilon: float = 1e-5,
79
- initializer_range: float = 0.02,
80
- tie_word_embeddings: bool = False,
81
- pad_vocab_size_multiple: int = 64,
82
- **kwargs
83
- ) -> None:
84
- self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
85
- self.n_positions = n_positions
86
- self.n_embd = n_embd
87
- self.n_layer = n_layer
88
- self.n_inner = n_inner
89
- self.n_head = n_head
90
- self.n_head_kv = n_head_kv
91
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
92
- self.activation_function = activation_function
93
- self.flash_attn = flash_attn
94
- self.flash_rotary = flash_rotary
95
- self.fused_dense = fused_dense
96
- self.attn_pdrop = attn_pdrop
97
- self.embd_pdrop = embd_pdrop
98
- self.resid_pdrop = resid_pdrop
99
- self.layer_norm_epsilon = layer_norm_epsilon
100
- self.initializer_range = initializer_range
101
-
102
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
103
-
104
-
105
-
106
- class SiglipVisionConfig(PretrainedConfig):
107
-
108
- model_type = "siglip_vision_model"
109
-
110
- def __init__(
111
- self,
112
- hidden_size=768,
113
- intermediate_size=3072,
114
- num_hidden_layers=12,
115
- num_attention_heads=12,
116
- num_channels=3,
117
- image_size=224,
118
- patch_size=16,
119
- hidden_act="gelu_pytorch_tanh",
120
- layer_norm_eps=1e-6,
121
- attention_dropout=0.0,
122
- **kwargs,
123
- ):
124
- super().__init__(**kwargs)
125
-
126
- self.hidden_size = hidden_size
127
- self.intermediate_size = intermediate_size
128
- self.num_hidden_layers = num_hidden_layers
129
- self.num_attention_heads = num_attention_heads
130
- self.num_channels = num_channels
131
- self.patch_size = patch_size
132
- self.image_size = image_size
133
- self.attention_dropout = attention_dropout
134
- self.layer_norm_eps = layer_norm_eps
135
- self.hidden_act = hidden_act
136
-
137
- @classmethod
138
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
139
- cls._set_token_in_kwargs(kwargs)
140
-
141
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
142
-
143
- # get the vision config dict if we are loading from SiglipConfig
144
- if config_dict.get("model_type") == "siglip":
145
- config_dict = config_dict["vision_config"]
146
-
147
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
148
- logger.warning(
149
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
150
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
151
- )
152
-
153
- return cls.from_dict(config_dict, **kwargs)
154
-
155
-
156
- class ImpConfig(PhiConfig):
157
- model_type = "imp"
158
-
159
- def __init__(self, **kwargs):
160
- super().__init__(**kwargs)
161
- self.image_token_index = getattr(self, "image_token_index", 50296)
162
- self.image_token = getattr(self, "image_token", "<image>")
163
-
164
- if not hasattr(self, "vision_tower_config") and hasattr(self, "mm_vision_tower"):
165
- vision_tower_config = SiglipVisionConfig.from_pretrained(self.mm_vision_tower)
166
- self.vision_tower_config = vision_tower_config.to_diff_dict()
167
-
168
- @property
169
- def vision_tower_cfg(self):
170
- cfg = SiglipVisionConfig.from_dict(self.vision_tower_config)
171
- # imp-v1 only supports `patch` feature for now w/o cls token
172
- # cfg.mm_vision_select_feature = self.mm_vision_select_feature
173
- cfg.mm_vision_select_layer = self.mm_vision_select_layer
174
- cfg.mm_vision_tower = self.mm_vision_tower
175
- return cfg
 
model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f22e7b5e04ac6d134a269cbb2d6c724aafd81bb4446b3ad567225fb93b757e75
3
- size 6373981888
modeling_imp.py DELETED
@@ -1,1262 +0,0 @@
1
- # Copyright (c) MILVLG team.
2
- # Licensed under the Apache 2.0 license.
3
- #
4
- # Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
5
- # SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
6
- # and Llava (https://github.com/haotian-liu/LLaVA), and modified by
7
- # Zhenwei Shao (shaozw@hdu.edu.cn) @ MILVLG. We thank them for their great works.
8
- # And their original licenses and copyright should be inherited (see the statements
9
- # in `configuration_imp.py` for more details).
10
-
11
-
12
- # Be careful: The way how `past_key_values.seqlen_offset` is updated is modified from
13
- # the implementation of original Phi-2. See the comments below for details.
14
-
15
- from __future__ import annotations
16
- import os
17
- import math
18
- import re
19
- from dataclasses import dataclass, field
20
- from typing import Any, Dict, Optional, Tuple, Union, List
21
- from abc import ABC, abstractmethod
22
-
23
- import torch
24
- import torch.nn as nn
25
- from einops import rearrange, repeat
26
- from transformers import (
27
- PretrainedConfig,
28
- PreTrainedModel,
29
- AutoConfig,
30
- AutoModelForCausalLM
31
- )
32
- from transformers.activations import ACT2FN
33
- from transformers.modeling_outputs import CausalLMOutputWithPast
34
- import sys
35
- from .configuration_imp import PhiConfig, ImpConfig
36
- from .vision_encoder import VisionTower
37
-
38
- try:
39
- from flash_attn.bert_padding import pad_input, unpad_input
40
- from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
41
- from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
42
- from flash_attn.ops.fused_dense import FusedDense
43
- except:
44
- pad_input, unpad_input = None, None
45
- FlashRotaryEmbedding = None
46
- FlashSelfAttention, FlashCrossAttention = None, None
47
- FusedDense = None
48
-
49
-
50
- @dataclass
51
- class InferenceParams:
52
- """Inference parameters passed to model to efficiently calculate
53
- and store context during inference.
54
-
55
- Reference:
56
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
57
-
58
- Args:
59
- max_seqlen: Maximum sequence length.
60
- max_batch_size: Maximum batch size.
61
- seqlen_offset: Sequence length offset.
62
- batch_size_offset: Batch size offset.
63
- key_value_memory_dict: Key value memory dictionary.
64
- lengths_per_sample: Lengths per sample.
65
-
66
- """
67
-
68
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
69
-
70
- max_batch_size: int = field(metadata={"help": "Maximum batch size."})
71
-
72
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
73
-
74
- batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
75
-
76
- key_value_memory_dict: Dict[str, Any] = field(
77
- default_factory=dict, metadata={"help": "Key value memory dictionary."}
78
- )
79
-
80
- lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
81
-
82
-
83
- class Embedding(nn.Module):
84
- """Token embedding with dropout."""
85
-
86
- def __init__(self, config: PretrainedConfig) -> None:
87
- super().__init__()
88
-
89
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
90
- self.drop = nn.Dropout(config.embd_pdrop)
91
-
92
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
93
- input_shape = input_ids.size()
94
- input_ids = input_ids.view(-1, input_shape[-1])
95
-
96
- hidden_states = self.wte(input_ids)
97
- hidden_states = self.drop(hidden_states)
98
-
99
- return hidden_states
100
-
101
-
102
-
103
- def _apply_rotary_emb(
104
- x: torch.FloatTensor,
105
- cos: torch.FloatTensor,
106
- sin: torch.FloatTensor,
107
- ) -> torch.FloatTensor:
108
- _, seqlen, _, _ = x.shape
109
- _, rotary_dim = cos.shape
110
- rotary_dim *= 2
111
-
112
- x_rot = x[:, :, :, :rotary_dim]
113
- x_pass = x[:, :, :, rotary_dim:]
114
-
115
- x1, x2 = x_rot.chunk(2, dim=-1)
116
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
117
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
118
-
119
- x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
120
-
121
- return torch.cat([x_rot, x_pass], axis=-1)
122
-
123
-
124
- def _apply_rotary_emb_kv(
125
- kv: torch.FloatTensor,
126
- cos: torch.FloatTensor,
127
- sin: torch.FloatTensor,
128
- cos_k: Optional[torch.FloatTensor] = None,
129
- sin_k: Optional[torch.FloatTensor] = None,
130
- ) -> torch.FloatTensor:
131
- _, seqlen, _, _, _ = kv.shape
132
- _, rotary_dim = cos.shape
133
- rotary_dim *= 2
134
-
135
- k_rot = kv[:, :, 0, :, :rotary_dim]
136
- k_pass = kv[:, :, 0, :, rotary_dim:]
137
-
138
- k1, k2 = k_rot.chunk(2, dim=-1)
139
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
140
- k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
141
-
142
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
143
-
144
- return torch.cat(
145
- [
146
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
147
- kv[:, :, 1:2, :, :],
148
- ],
149
- axis=2,
150
- )
151
-
152
-
153
- def _apply_rotary_emb_qkv(
154
- qkv: torch.FloatTensor,
155
- cos: torch.FloatTensor,
156
- sin: torch.FloatTensor,
157
- cos_k: Optional[torch.FloatTensor] = None,
158
- sin_k: Optional[torch.FloatTensor] = None,
159
- ) -> torch.FloatTensor:
160
- _, seqlen, _, _, _ = qkv.shape
161
- _, rotary_dim = cos.shape
162
- rotary_dim *= 2
163
-
164
- q_rot = qkv[:, :, 0, :, :rotary_dim]
165
- q_pass = qkv[:, :, 0, :, rotary_dim:]
166
-
167
- k_rot = qkv[:, :, 1, :, :rotary_dim]
168
- k_pass = qkv[:, :, 1, :, rotary_dim:]
169
-
170
- q1, q2 = q_rot.chunk(2, dim=-1)
171
- k1, k2 = k_rot.chunk(2, dim=-1)
172
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
173
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
174
-
175
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
176
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
177
-
178
- return torch.cat(
179
- [
180
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
181
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
182
- qkv[:, :, 2:3, :, :],
183
- ],
184
- axis=2,
185
- )
186
-
187
-
188
- class RotaryEmbedding(nn.Module):
189
- """Rotary positional embedding (RoPE).
190
-
191
- Reference:
192
- RoFormer: Enhanced Transformer with Rotary Position Embedding.
193
- https://arxiv.org/pdf/2104.09864.pdf.
194
-
195
- """
196
-
197
- def __init__(
198
- self,
199
- dim: int,
200
- base: int = 10000,
201
- scale_base: Optional[float] = None,
202
- pos_idx_in_fp32: bool = True,
203
- max_position_embeddings: int = 2048,
204
- device: Optional[str] = None,
205
- **kwargs,
206
- ) -> None:
207
- super().__init__()
208
-
209
- if scale_base is not None:
210
- raise NotImplementedError
211
-
212
- self.dim = dim
213
- self.base = float(base)
214
- self.scale_base = scale_base
215
- self.pos_idx_in_fp32 = pos_idx_in_fp32
216
- self.max_position_embeddings = max_position_embeddings
217
- self.device = device
218
-
219
- # Generate and save the inverse frequency buffer (non-trainable)
220
- inv_freq = self._compute_inv_freq(device)
221
- self.register_buffer("inv_freq", inv_freq, persistent=False)
222
-
223
- # Generate and save the scale buffer (non-trainable)
224
- scale = (
225
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
226
- if scale_base is not None
227
- else None
228
- )
229
- self.register_buffer("scale", scale, persistent=False)
230
-
231
- # Initialize cached attributes since ONNX can't rely on dynamic initialization
232
- self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
233
-
234
- def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
235
- return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
236
-
237
- def _update_cos_sin_cache(
238
- self,
239
- seqlen: int,
240
- device: Optional[str] = None,
241
- dtype: Optional[torch.dtype] = None,
242
- ) -> None:
243
- self._seq_len_cached = seqlen
244
-
245
- # fp32 is preferred since the output of `torch.arange` can be quite large
246
- # and bf16 would lose a lot of precision
247
- if self.pos_idx_in_fp32:
248
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
249
- if self.inv_freq.dtype != torch.float32:
250
- inv_freq = self._compute_inv_freq(device=device)
251
- else:
252
- inv_freq = self.inv_freq
253
- else:
254
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
255
- inv_freq = self.inv_freq
256
-
257
- # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
258
- freqs = torch.outer(t, inv_freq)
259
- if self.scale is None:
260
- self._cos_cached = torch.cos(freqs).to(dtype)
261
- self._sin_cached = torch.sin(freqs).to(dtype)
262
- else:
263
- power = (
264
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
265
- ) / self.scale_base
266
- scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
267
-
268
- # Force the scale multiplication to happen in fp32
269
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
270
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
271
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
272
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
273
-
274
- def forward(
275
- self,
276
- qkv: torch.Tensor,
277
- kv: Optional[torch.Tensor] = None,
278
- seqlen_offset: int = 0,
279
- **kwargs,
280
- ) -> Tuple[torch.Tensor, torch.Tensor]:
281
- if (
282
- self._seq_len_cached < qkv.shape[1] + seqlen_offset
283
- or self._cos_cached.device != qkv.device
284
- or self._cos_cached.dtype != qkv.dtype
285
- or (self.training and self._cos_cached.is_inference())
286
- ):
287
- self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
288
-
289
- if kv is None:
290
- return _apply_rotary_emb_qkv(
291
- qkv,
292
- self._cos_cached[seqlen_offset:],
293
- self._sin_cached[seqlen_offset:],
294
- )
295
- else:
296
- q = _apply_rotary_emb(
297
- qkv,
298
- self._cos_cached[seqlen_offset:],
299
- self._sin_cached[seqlen_offset:],
300
- )
301
- kv = _apply_rotary_emb_kv(
302
- kv,
303
- self._cos_cached[seqlen_offset:],
304
- self._sin_cached[seqlen_offset:],
305
- )
306
-
307
- return q, kv
308
-
309
-
310
- class MLP(nn.Module):
311
- """Multi-Layer Perceptron.
312
-
313
- Reference:
314
- Attention Is All You Need.
315
- https://arxiv.org/pdf/1706.03762.pdf.
316
-
317
- """
318
-
319
- def __init__(
320
- self,
321
- config: PretrainedConfig,
322
- n_inner: Optional[int] = None,
323
- act_fn: Optional[str] = None,
324
- ) -> None:
325
- super().__init__()
326
-
327
- act_fn = config.activation_function if act_fn is None else act_fn
328
-
329
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
330
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
331
-
332
- self.fc1 = nn.Linear(config.n_embd, n_inner)
333
- self.fc2 = nn.Linear(n_inner, config.n_embd)
334
- self.act = ACT2FN[act_fn]
335
-
336
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
337
- hidden_states = self.fc1(hidden_states)
338
- hidden_states = self.act(hidden_states)
339
- hidden_states = self.fc2(hidden_states)
340
-
341
- return hidden_states
342
-
343
-
344
- class SelfAttention(nn.Module):
345
- """Self-attention layer (compatible with PyTorch).
346
-
347
- Reference:
348
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
349
-
350
- """
351
-
352
- def __init__(
353
- self,
354
- causal: bool = True,
355
- softmax_scale: Optional[float] = None,
356
- attention_dropout: float = 0.0,
357
- ) -> None:
358
- super().__init__()
359
-
360
- self.causal = causal
361
- self.softmax_scale = softmax_scale
362
- self.drop = nn.Dropout(attention_dropout)
363
-
364
- @torch.autocast("cpu", enabled=False)
365
- @torch.autocast("cuda", enabled=False)
366
- def forward(
367
- self,
368
- qkv: torch.FloatTensor,
369
- causal: bool = None,
370
- key_padding_mask: Optional[torch.BoolTensor] = None,
371
- **kwargs,
372
- ) -> torch.FloatTensor:
373
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
374
- q, k, v = qkv.unbind(dim=2)
375
-
376
- q = q.to(torch.float32)
377
- k = k.to(torch.float32)
378
-
379
- causal = self.causal if causal is None else causal
380
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
381
-
382
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
383
- # using float16, which might lead to overflow
384
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
385
-
386
- if key_padding_mask is not None:
387
- padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
388
- padding_mask.masked_fill_(key_padding_mask, 0.0)
389
-
390
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
391
-
392
- if causal:
393
- causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
394
- scores = scores + causal_mask.to(dtype=scores.dtype)
395
-
396
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
397
- attention = self.drop(attention)
398
-
399
- output = torch.einsum("bhts,bshd->bthd", attention, v)
400
-
401
- return output
402
-
403
-
404
- class CrossAttention(nn.Module):
405
- """Cross-attention layer (compatible with PyTorch).
406
-
407
- Reference:
408
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
409
-
410
- """
411
-
412
- def __init__(
413
- self,
414
- causal: bool = True,
415
- softmax_scale: Optional[float] = None,
416
- attention_dropout: float = 0.0,
417
- ) -> None:
418
- super().__init__()
419
-
420
- self.causal = causal
421
- self.softmax_scale = softmax_scale
422
- self.drop = nn.Dropout(attention_dropout)
423
-
424
- @torch.autocast("cpu", enabled=False)
425
- @torch.autocast("cuda", enabled=False)
426
- def forward(
427
- self,
428
- q: torch.FloatTensor,
429
- kv: torch.FloatTensor,
430
- causal: bool = None,
431
- key_padding_mask: Optional[torch.BoolTensor] = None,
432
- **kwargs,
433
- ) -> torch.FloatTensor:
434
- batch_size, seqlen_q = q.shape[0], q.shape[1]
435
- seqlen_k = kv.shape[1]
436
-
437
- if kv.shape[3] != q.shape[2]:
438
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
439
- k, v = kv.unbind(dim=2)
440
-
441
- q = q.to(torch.float32)
442
- k = k.to(torch.float32)
443
-
444
- causal = self.causal if causal is None else causal
445
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
446
-
447
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
448
- # using float16, which might lead to overflow
449
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
450
-
451
- if key_padding_mask is not None:
452
- padding_mask = torch.full(
453
- (batch_size, seqlen_k),
454
- -10000.0,
455
- dtype=scores.dtype,
456
- device=scores.device,
457
- )
458
- padding_mask.masked_fill_(key_padding_mask, 0.0)
459
-
460
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
461
-
462
- if causal:
463
- rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
464
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
465
- causal_mask = cols > rows + seqlen_k - seqlen_q
466
-
467
- scores = scores.masked_fill(causal_mask, -10000.0)
468
-
469
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
470
- attention = self.drop(attention)
471
-
472
- output = torch.einsum("bhts,bshd->bthd", attention, v)
473
-
474
- return output
475
-
476
-
477
- def _find_mha_dims(
478
- config: PretrainedConfig,
479
- n_head: Optional[int] = None,
480
- n_head_kv: Optional[int] = None,
481
- head_dim: Optional[int] = None,
482
- ) -> Tuple[int, int]:
483
- if n_head is None and head_dim is None:
484
- head_dim = config.n_embd // config.n_head
485
- n_head = config.n_head
486
- elif n_head is None or head_dim is None:
487
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
488
-
489
- if n_head_kv is None:
490
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
491
-
492
- return n_head, n_head_kv, head_dim
493
-
494
-
495
- def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
496
- num_heads, head_dim = kv.shape[-2:]
497
-
498
- if layer_idx not in inference_params.key_value_memory_dict:
499
- inference_params.key_value_memory_dict[layer_idx] = torch.empty(
500
- inference_params.max_batch_size,
501
- inference_params.max_seqlen,
502
- 2,
503
- num_heads,
504
- head_dim,
505
- dtype=kv.dtype,
506
- device=kv.device,
507
- )
508
-
509
- batch_start = inference_params.batch_size_offset
510
- batch_end = batch_start + kv.shape[0]
511
-
512
- sequence_start = inference_params.seqlen_offset
513
- sequence_end = sequence_start + kv.shape[1]
514
-
515
- # When the current sequence length is equal to or larger than the maximum sequence length,
516
- # we need to concatenate the current `kv` with the cached `kv` to expand its length
517
- if sequence_end >= inference_params.max_seqlen:
518
- inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
519
-
520
- inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
521
- kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
522
-
523
- return kv
524
-
525
-
526
- class MHA(nn.Module):
527
- """Multi-head attention layer."""
528
-
529
- def __init__(
530
- self,
531
- config: PretrainedConfig,
532
- dtype: Optional[torch.dtype] = None,
533
- device: Optional[str] = None,
534
- rotary_dim: Optional[int] = None,
535
- rotary_base: float = 10000.0,
536
- rotary_scale_base: Optional[float] = None,
537
- n_head: Optional[int] = None,
538
- n_head_kv: Optional[int] = None,
539
- head_dim: Optional[int] = None,
540
- bias: bool = True,
541
- causal: bool = True,
542
- softmax_scale: Optional[float] = None,
543
- layer_idx: Optional[int] = None,
544
- return_residual: bool = False,
545
- checkpointing: bool = False,
546
- ) -> None:
547
- super().__init__()
548
-
549
- # Rotary embedding
550
- self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
551
- if self.rotary_dim > 0:
552
- rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
553
- if rotary_cls is None:
554
- rotary_cls = RotaryEmbedding
555
-
556
- rotary_kwargs = {}
557
- if rotary_cls is RotaryEmbedding:
558
- rotary_kwargs["max_position_embeddings"] = config.n_positions
559
-
560
- self.rotary_emb = rotary_cls(
561
- self.rotary_dim,
562
- base=rotary_base,
563
- scale_base=rotary_scale_base,
564
- device=device,
565
- **rotary_kwargs,
566
- )
567
-
568
- # MLP
569
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
570
- config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
571
- )
572
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
573
- hidden_size = config.n_embd
574
-
575
- linear_cls = FusedDense if config.fused_dense else nn.Linear
576
- if linear_cls is None:
577
- linear_cls = nn.Linear
578
-
579
- self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
580
- self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
581
-
582
- # Attention
583
- attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
584
- if attn_cls is None:
585
- attn_cls = SelfAttention
586
-
587
- cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
588
- if cross_attn_cls is None:
589
- cross_attn_cls = CrossAttention
590
-
591
- self.inner_attn = attn_cls(
592
- causal=causal,
593
- softmax_scale=softmax_scale,
594
- attention_dropout=config.attn_pdrop,
595
- )
596
- self.inner_cross_attn = cross_attn_cls(
597
- causal=causal,
598
- softmax_scale=softmax_scale,
599
- attention_dropout=config.attn_pdrop,
600
- )
601
-
602
- self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
603
- self.layer_idx = layer_idx
604
- self.return_residual = return_residual
605
- self.checkpointing = checkpointing
606
-
607
- def _forward_self_attn(
608
- self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
609
- ) -> torch.FloatTensor:
610
- qkv = self.Wqkv(x)
611
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
612
-
613
- if self.rotary_dim > 0:
614
- qkv = self.rotary_emb(qkv)
615
-
616
- if self.flash_attn:
617
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
618
-
619
- cu_seqlens, max_seqlen = None, None
620
- if key_padding_mask is not None:
621
- # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
622
- # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
623
- qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
624
-
625
- if self.checkpointing:
626
- attn_output = torch.utils.checkpoint.checkpoint(
627
- self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
628
- )
629
- else:
630
- attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
631
-
632
- # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
633
- return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
634
-
635
- if self.checkpointing:
636
- return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
637
-
638
- return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
639
-
640
- def _forward_cross_attn(
641
- self,
642
- x: torch.FloatTensor,
643
- past_key_values: Optional[InferenceParams],
644
- key_padding_mask: Optional[torch.BoolTensor],
645
- ) -> torch.FloatTensor:
646
- batch_size = x.shape[0]
647
-
648
- qkv = self.Wqkv(x)
649
-
650
- q = qkv[..., : self.n_head * self.head_dim]
651
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
652
-
653
- kv = qkv[..., self.n_head * self.head_dim :]
654
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
655
-
656
- seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
657
- causal = None if seqlen_offset == 0 else False
658
- if self.rotary_dim > 0:
659
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
660
-
661
- if past_key_values is not None:
662
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
663
-
664
- if self.flash_attn:
665
- batch_size, seqlen_q = q.shape[0], q.shape[1]
666
- seqlen_k = kv.shape[1]
667
-
668
- cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
669
- None,
670
- None,
671
- None,
672
- None,
673
- )
674
- if key_padding_mask is not None:
675
- kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
676
-
677
- if seqlen_q == 1:
678
- key_padding_mask = torch.ones(batch_size, 1, device=q.device)
679
- elif seqlen_q != seqlen_k:
680
- key_padding_mask = key_padding_mask[:, -seqlen_q:]
681
-
682
- q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
683
-
684
- if self.checkpointing:
685
- attn_output = torch.utils.checkpoint.checkpoint(
686
- self.inner_cross_attn,
687
- q,
688
- kv,
689
- causal=causal,
690
- cu_seqlens=cu_seqlens_q,
691
- max_seqlen=max_seqlen_q,
692
- cu_seqlens_k=cu_seqlens_k,
693
- max_seqlen_k=max_seqlen_k,
694
- )
695
- else:
696
- attn_output = self.inner_cross_attn(
697
- q,
698
- kv,
699
- causal=causal,
700
- cu_seqlens=cu_seqlens_q,
701
- max_seqlen=max_seqlen_q,
702
- cu_seqlens_k=cu_seqlens_k,
703
- max_seqlen_k=max_seqlen_k,
704
- )
705
-
706
- return (
707
- pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
708
- if key_padding_mask is not None
709
- else attn_output
710
- )
711
-
712
- if self.checkpointing:
713
- return torch.utils.checkpoint.checkpoint(
714
- self.inner_cross_attn,
715
- q,
716
- kv,
717
- key_padding_mask=key_padding_mask,
718
- causal=causal,
719
- )
720
-
721
- return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
722
-
723
- def forward(
724
- self,
725
- x: torch.FloatTensor,
726
- past_key_values: Optional[InferenceParams] = None,
727
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
728
- **kwargs,
729
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
730
- if attention_mask is not None:
731
- attention_mask = attention_mask.bool()
732
- else:
733
- attention_mask = None
734
-
735
- # MHA
736
- if self.n_head == self.n_head_kv:
737
- if past_key_values is None:
738
- # If `past_key_values` are not supplied, we run self-attention
739
- attn_output = self._forward_self_attn(x, attention_mask)
740
- else:
741
- # If `past_key_values` are supplied, it means that we might have cached values and
742
- # could take advantage of cross-attention
743
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
744
- # MQA / GQA
745
- else:
746
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
747
- # because `q` and `kv` lengths might be different
748
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
749
-
750
- output = rearrange(attn_output, "... h d -> ... (h d)")
751
- output = self.out_proj(output)
752
-
753
- return output if not self.return_residual else (output, x)
754
-
755
-
756
- class ParallelBlock(nn.Module):
757
- """Parallel block.
758
-
759
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
760
-
761
- """
762
-
763
- def __init__(
764
- self,
765
- config: PretrainedConfig,
766
- block_idx: Optional[int] = None,
767
- ) -> None:
768
- super().__init__()
769
-
770
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
771
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
772
- self.block_idx = block_idx
773
-
774
- self.mixer = MHA(config, layer_idx=block_idx)
775
- self.mlp = MLP(config)
776
-
777
- def forward(
778
- self,
779
- hidden_states: torch.FloatTensor,
780
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
781
- attention_mask: Optional[torch.BoolTensor] = None,
782
- **kwargs,
783
- ) -> torch.FloatTensor:
784
- residual = hidden_states
785
- hidden_states = self.ln(hidden_states)
786
-
787
- attn_outputs = self.mixer(
788
- hidden_states,
789
- past_key_values=past_key_values,
790
- attention_mask=attention_mask,
791
- )
792
- if isinstance(attn_outputs, tuple):
793
- attn_outputs = attn_outputs[0]
794
-
795
- attn_outputs = self.resid_dropout(attn_outputs)
796
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
797
-
798
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
799
-
800
- return hidden_states
801
-
802
-
803
- class CausalLMHead(nn.Module):
804
- """Causal Language Modeling head.
805
-
806
- Reference:
807
- Improving Language Understanding by Generative Pre-Training.
808
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
809
-
810
- """
811
-
812
- def __init__(self, config: PretrainedConfig) -> None:
813
- super().__init__()
814
-
815
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
816
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
817
-
818
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
819
- hidden_states = self.ln(hidden_states)
820
- logits = self.linear(hidden_states).to(torch.float32)
821
-
822
- return logits
823
-
824
-
825
- class PhiPreTrainedModel(PreTrainedModel):
826
- """Phi pre-trained model."""
827
-
828
- config_class = PhiConfig
829
- base_model_prefix = "transformer"
830
- supports_gradient_checkpointing = True
831
- _no_split_modules = ["ParallelBlock", "CLIPEncoderLayer", "Block"]
832
-
833
- def __init__(self, *inputs, **kwargs) -> None:
834
- super().__init__(*inputs, **kwargs)
835
-
836
- def _init_weights(self, module: nn.Module) -> None:
837
- if isinstance(module, (nn.Linear,)):
838
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
839
- if module.bias is not None:
840
- module.bias.data.zero_()
841
- elif isinstance(module, nn.Embedding):
842
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
843
- if module.padding_idx is not None:
844
- module.weight.data[module.padding_idx].zero_()
845
- elif isinstance(module, nn.LayerNorm):
846
- if module.bias is not None:
847
- module.bias.data.zero_()
848
- module.weight.data.fill_(1.0)
849
-
850
- def prepare_inputs_for_generation(
851
- self,
852
- input_ids: torch.LongTensor,
853
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
854
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
855
- **kwargs,
856
- ) -> Dict[str, Any]:
857
- if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
858
- past_key_values = InferenceParams(
859
- max_seqlen=self.config.n_positions,
860
- max_batch_size=input_ids.shape[0],
861
- seqlen_offset=0,
862
- batch_size_offset=0,
863
- key_value_memory_dict={},
864
- lengths_per_sample=None,
865
- )
866
- else:
867
- # ======================================================================
868
- # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
869
- # inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
870
- # past_key_values.seqlen_offset = input_ids.shape[1] - 1
871
- # ======================================================================
872
- # I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
873
- # [Edited by zhenwei - 2024-01-20 21:15]
874
- input_ids = input_ids[:, -1].unsqueeze(-1)
875
-
876
- return {
877
- "input_ids": input_ids,
878
- "past_key_values": past_key_values,
879
- "attention_mask": attention_mask,
880
- }
881
-
882
-
883
- class LlavaMetaModel(ABC):
884
- """
885
- Define the APIs for building components that are related to image perceiving.
886
- This implementation is based on the implementation from the Llave project.
887
- """
888
-
889
- def get_vision_tower(self):
890
- vision_tower = getattr(self, 'vision_tower', None)
891
- if type(vision_tower) is list:
892
- vision_tower = vision_tower[0]
893
- return vision_tower
894
-
895
- def build_vision_tower(self, config):
896
- self.vision_tower = VisionTower(config.vision_tower_cfg)
897
-
898
- def build_vision_projector(self, config):
899
- projector_type = getattr(config, 'mm_projector_type', 'linear')
900
-
901
- if projector_type == 'linear':
902
- self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
903
- return
904
-
905
- mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
906
- if mlp_gelu_match:
907
- mlp_depth = int(mlp_gelu_match.group(1))
908
- modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
909
- for _ in range(1, mlp_depth):
910
- modules.append(nn.GELU())
911
- modules.append(nn.Linear(config.hidden_size, config.hidden_size))
912
- self.mm_projector = nn.Sequential(*modules)
913
- return
914
-
915
- if projector_type == 'identity':
916
- self.mm_projector = nn.Identity()
917
- return
918
-
919
- raise ValueError(f'Unknown projector type: {projector_type}')
920
-
921
-
922
- class ImpModel(PhiPreTrainedModel, LlavaMetaModel):
923
- """Imp model. This implementation is modified from the implementation of Phi-2"""
924
-
925
- config_class = ImpConfig
926
- # _keys_to_ignore_on_load_missing = [""]
927
- # _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
928
-
929
- def __init__(self, config: ImpConfig) -> None:
930
- super().__init__(config)
931
-
932
- self.embd = Embedding(config)
933
- self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
934
- self.gradient_checkpointing = False
935
-
936
- if hasattr(config, "mm_vision_tower"):
937
- self.build_vision_tower(config)
938
- self.build_vision_projector(config)
939
-
940
- self.post_init()
941
-
942
- def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
943
- return self.embd(input_ids)[0]
944
-
945
- def get_input_embeddings(self) -> nn.Embedding:
946
- return self.embd.wte
947
-
948
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
949
- self.embd.wte = new_embeddings
950
-
951
- def forward(
952
- self,
953
- input_ids: torch.LongTensor,
954
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
955
- attention_mask: Optional[torch.BoolTensor] = None,
956
- inputs_embeds: Optional[torch.FloatTensor] = None
957
- ) -> torch.FloatTensor:
958
-
959
- if inputs_embeds is None:
960
- hidden_states = self.embd(input_ids)
961
- else:
962
- hidden_states = inputs_embeds
963
-
964
- for layer in self.h:
965
- if self.gradient_checkpointing and self.training:
966
-
967
- def create_custom_forward(module):
968
- def custom_forward(*inputs):
969
- # None for past_key_value
970
- return module(*inputs)
971
-
972
- return custom_forward
973
-
974
- hidden_states = torch.utils.checkpoint.checkpoint(
975
- create_custom_forward(layer),
976
- hidden_states,
977
- None,
978
- attention_mask,
979
- )
980
- else:
981
- hidden_states = layer(
982
- hidden_states,
983
- past_key_values=past_key_values,
984
- attention_mask=attention_mask,
985
- )
986
-
987
- # I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
988
- # [Edited by zhenwei - 2024-01-20 21:15]
989
- if past_key_values is not None: # FIXME: when multi-batch inference, it is a bug
990
- past_key_values.seqlen_offset += hidden_states.shape[1]
991
-
992
- return hidden_states
993
-
994
-
995
- class LlavaMetaForCausalLM(ABC):
996
- """This implementation is based on the implementation from the Llave project."""
997
-
998
- def init_constants(self, config):
999
- self.IGNORE_INDEX = getattr(config, 'ignore_index', -100)
1000
- self.IMAGE_TOKEN_INDEX = getattr(config, 'image_token_index', 50296)
1001
- self.DEFAULT_IMAGE_TOKEN = getattr(config, 'image_token', "<image>")
1002
-
1003
- @abstractmethod
1004
- def get_model(self):
1005
- pass
1006
-
1007
- def get_vision_tower(self):
1008
- return self.get_model().get_vision_tower()
1009
-
1010
- def encode_images(self, images):
1011
- image_features = self.get_model().get_vision_tower()(images)
1012
- image_features = self.get_model().mm_projector(image_features)
1013
- return image_features
1014
-
1015
- def prepare_inputs_labels_for_multimodal(
1016
- self, input_ids, position_ids, attention_mask, past_key_values, labels, images
1017
- ):
1018
- vision_tower = self.get_vision_tower()
1019
- # if vision_tower is None or images is None or past_key_values.seqlen_offset != 0:
1020
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
1021
- if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
1022
- target_shape = past_key_values.seqlen_offset + 1
1023
- # inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
1024
- attention_mask = torch.cat((attention_mask, torch.ones(
1025
- (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
1026
- dtype=attention_mask.dtype,
1027
- device=attention_mask.device
1028
- )), dim=1)
1029
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1030
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
1031
-
1032
- if type(images) is list or images.ndim == 5:
1033
- concat_images = torch.cat([image for image in images], dim=0)
1034
- concat_images = concat_images.to(device=self.device, dtype=vision_tower.dtype)
1035
- image_features = self.encode_images(concat_images)
1036
- split_sizes = [image.shape[0] for image in images]
1037
- image_features = torch.split(image_features, split_sizes, dim=0)
1038
- image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
1039
- else:
1040
- images = images.to(device=self.device, dtype=vision_tower.dtype)
1041
- image_features = self.encode_images(images).to(self.device)
1042
-
1043
- # TODO: image start / end is not implemented here to support pretraining.
1044
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
1045
- raise NotImplementedError
1046
-
1047
- # Let's just add dummy tensors if they do not exist,
1048
- # it is a headache to deal with None all the time.
1049
- # But it is not ideal, and if you have a better idea,
1050
- # please open an issue / submit a PR, thanks.
1051
- _labels = labels
1052
- _position_ids = position_ids
1053
- _attention_mask = attention_mask
1054
- if attention_mask is None:
1055
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1056
- else:
1057
- attention_mask = attention_mask.bool()
1058
- if position_ids is None:
1059
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
1060
- if labels is None:
1061
- labels = torch.full_like(input_ids, self.IGNORE_INDEX)
1062
-
1063
- # remove the padding using attention_mask -- TODO: double check
1064
- input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
1065
- labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
1066
-
1067
- new_input_embeds = []
1068
- new_labels = []
1069
- cur_image_idx = 0
1070
- for batch_idx, cur_input_ids in enumerate(input_ids):
1071
- num_images = (cur_input_ids == self.IMAGE_TOKEN_INDEX).sum()
1072
- if num_images == 0:
1073
- cur_image_features = image_features[cur_image_idx]
1074
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1075
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
1076
- new_input_embeds.append(cur_input_embeds)
1077
- new_labels.append(labels[batch_idx])
1078
- cur_image_idx += 1
1079
- continue
1080
-
1081
- image_token_indices = [-1] + torch.where(cur_input_ids == self.IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
1082
- cur_input_ids_noim = []
1083
- cur_labels = labels[batch_idx]
1084
- cur_labels_noim = []
1085
- for i in range(len(image_token_indices) - 1):
1086
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
1087
- cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
1088
- split_sizes = [x.shape[0] for x in cur_labels_noim]
1089
- cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
1090
- # print(cur_input_embeds.shape)
1091
- cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1092
- cur_new_input_embeds = []
1093
- cur_new_labels = []
1094
-
1095
- for i in range(num_images + 1):
1096
- cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1097
- cur_new_labels.append(cur_labels_noim[i])
1098
- if i < num_images:
1099
- cur_image_features = image_features[cur_image_idx]
1100
- cur_image_idx += 1
1101
- cur_new_input_embeds.append(cur_image_features)
1102
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), self.IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
1103
-
1104
- cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1105
- cur_new_labels = torch.cat(cur_new_labels)
1106
-
1107
- new_input_embeds.append(cur_new_input_embeds)
1108
- new_labels.append(cur_new_labels)
1109
-
1110
- # Truncate sequences to max length as image embeddings can make the sequence longer
1111
- tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
1112
- if tokenizer_model_max_length is not None:
1113
- new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
1114
- new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1115
-
1116
- # Combine them
1117
- max_len = max(x.shape[0] for x in new_input_embeds)
1118
- batch_size = len(new_input_embeds)
1119
-
1120
- new_input_embeds_padded = []
1121
- new_labels_padded = torch.full((batch_size, max_len), self.IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
1122
- attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
1123
- position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
1124
-
1125
- for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
1126
- cur_len = cur_new_embed.shape[0]
1127
- if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
1128
- new_input_embeds_padded.append(torch.cat((
1129
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
1130
- cur_new_embed
1131
- ), dim=0))
1132
- if cur_len > 0:
1133
- new_labels_padded[i, -cur_len:] = cur_new_labels
1134
- attention_mask[i, -cur_len:] = True
1135
- position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
1136
- else:
1137
- new_input_embeds_padded.append(torch.cat((
1138
- cur_new_embed,
1139
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
1140
- ), dim=0))
1141
- if cur_len > 0:
1142
- new_labels_padded[i, :cur_len] = cur_new_labels
1143
- attention_mask[i, :cur_len] = True
1144
- position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
1145
-
1146
- new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1147
-
1148
- if _labels is None:
1149
- new_labels = None
1150
- else:
1151
- new_labels = new_labels_padded
1152
-
1153
- if _attention_mask is None:
1154
- attention_mask = None
1155
- else:
1156
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1157
-
1158
- if _position_ids is None:
1159
- position_ids = None
1160
-
1161
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
1162
-
1163
-
1164
- class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
1165
- """Imp for Causal Language Modeling."""
1166
-
1167
- # _keys_to_ignore_on_load_missing = [""]
1168
- # _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
1169
- config_class = ImpConfig
1170
-
1171
- def __init__(self, config: ImpConfig) -> None:
1172
- super().__init__(config)
1173
-
1174
- self.transformer = ImpModel(config)
1175
- self.lm_head = CausalLMHead(config)
1176
-
1177
- self.post_init()
1178
- self.init_constants(config)
1179
-
1180
- def get_output_embeddings(self) -> nn.Linear:
1181
- return self.lm_head.linear
1182
-
1183
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1184
- self.lm_head.linear = new_embeddings
1185
-
1186
- def get_model(self):
1187
- return self.transformer
1188
-
1189
- def image_preprocess(self, images):
1190
- return self.get_vision_tower().image_processor(images)['pixel_values']
1191
-
1192
- def backbone_forward(
1193
- self,
1194
- input_ids: torch.LongTensor,
1195
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1196
- attention_mask: Optional[torch.BoolTensor] = None,
1197
- labels: Optional[torch.LongTensor] = None,
1198
- inputs_embeds: Optional[torch.FloatTensor] = None,
1199
- **kwargs,
1200
- ) -> CausalLMOutputWithPast:
1201
- hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
1202
- lm_logits = self.lm_head(hidden_states)
1203
-
1204
- return CausalLMOutputWithPast(loss=None, logits=lm_logits, past_key_values=past_key_values)
1205
-
1206
- def forward(
1207
- self,
1208
- input_ids: torch.LongTensor = None,
1209
- attention_mask: Optional[torch.Tensor] = None,
1210
- position_ids: Optional[torch.LongTensor] = None,
1211
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1212
- inputs_embeds: Optional[torch.FloatTensor] = None,
1213
- labels: Optional[torch.LongTensor] = None,
1214
- use_cache: Optional[bool] = None,
1215
- output_attentions: Optional[bool] = None,
1216
- output_hidden_states: Optional[bool] = None,
1217
- images: Optional[torch.FloatTensor] = None,
1218
- return_dict: Optional[bool] = None,
1219
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1220
-
1221
- if inputs_embeds is None:
1222
- (
1223
- input_ids,
1224
- position_ids,
1225
- attention_mask,
1226
- past_key_values,
1227
- inputs_embeds,
1228
- labels
1229
- ) = self.prepare_inputs_labels_for_multimodal(
1230
- input_ids,
1231
- position_ids,
1232
- attention_mask,
1233
- past_key_values,
1234
- labels,
1235
- images
1236
- )
1237
-
1238
- return self.backbone_forward(
1239
- input_ids=input_ids,
1240
- attention_mask=attention_mask,
1241
- position_ids=position_ids,
1242
- past_key_values=past_key_values,
1243
- inputs_embeds=inputs_embeds,
1244
- labels=labels,
1245
- use_cache=use_cache,
1246
- output_attentions=output_attentions,
1247
- output_hidden_states=output_hidden_states,
1248
- return_dict=return_dict
1249
- )
1250
-
1251
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
1252
- images = kwargs.pop("images", None)
1253
- _inputs = super().prepare_inputs_for_generation(
1254
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
1255
- )
1256
- if images is not None:
1257
- _inputs['images'] = images
1258
- return _inputs
1259
-
1260
-
1261
- AutoConfig.register("imp", ImpConfig)
1262
- AutoModelForCausalLM.register(ImpConfig, ImpForCausalLM)
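Note on the removed modeling code: `prepare_inputs_labels_for_multimodal` replaces the `<image>` placeholder token with projected vision features, masks those positions out of the language-modeling loss, then pads and truncates the batch. The sketch below illustrates only the splicing step for a single sequence; the function name, the 2560-dim hidden size, and the `IGNORE_INDEX = -100` constant are illustrative assumptions, not the repository's exact implementation.

```python
import torch

IGNORE_INDEX = -100  # assumed label-masking value (common LLaVA-style convention)

def splice_image_features(text_embeds, text_labels, image_feats, image_token_pos):
    """Replace the <image> placeholder at image_token_pos with image patch features.

    text_embeds: (seq_len, hidden) token embeddings containing one placeholder
    text_labels: (seq_len,) labels aligned with text_embeds
    image_feats: (num_patches, hidden) projected vision features
    """
    # Splice the patch embeddings in place of the single placeholder token.
    new_embeds = torch.cat(
        [text_embeds[:image_token_pos], image_feats, text_embeds[image_token_pos + 1:]], dim=0
    )
    # Image positions never contribute to the LM loss.
    image_labels = torch.full((image_feats.shape[0],), IGNORE_INDEX, dtype=text_labels.dtype)
    new_labels = torch.cat(
        [text_labels[:image_token_pos], image_labels, text_labels[image_token_pos + 1:]], dim=0
    )
    return new_embeds, new_labels

if __name__ == "__main__":
    embeds, labels = splice_image_features(
        torch.randn(8, 2560),         # 8 text tokens, Phi-2-sized hidden dim (assumed)
        torch.randint(0, 100, (8,)),  # dummy labels
        torch.randn(729, 2560),       # 729 projected image patches
        image_token_pos=3,
    )
    print(embeds.shape, labels.shape)  # torch.Size([736, 2560]) torch.Size([736])
```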
 
pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c162cd9d0a121183d6c71232d2e8bfbcbd293e9d37b9ecfea8534800a5350efd
3
- size 6374152890
 
tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
vision_encoder.py DELETED
@@ -1,593 +0,0 @@
1
- # Copyright (c) MILVLG team.
2
- # Licensed under the Apache 2.0 license.
3
- #
4
- # Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
5
- # SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
6
- # and Llava (https://github.com/haotian-liu/LLaVA), and modified by
7
- # Zhenwei Shao (shaozw@hdu.edu.cn) @ MILVLG. We thank them for their great works.
8
- # And their original licenses and copyright should be inherited (see the statements
9
- # in `configuration_imp.py` for more details).
10
-
11
-
12
- from typing import Any, Optional, Tuple, Union, List, Dict
13
- from dataclasses import dataclass
14
- import math
15
- import warnings
16
- from functools import partial, reduce
17
-
18
-
19
- import numpy as np
20
- from PIL import Image
21
- import torch
22
- import torch.utils.checkpoint
23
- from torch import nn
24
-
25
- from transformers.image_processing_utils import BatchFeature
26
- from transformers.image_transforms import (
27
- convert_to_rgb,
28
- normalize,
29
- rescale,
30
- resize,
31
- to_channel_dimension_format,
32
- )
33
- from transformers.image_utils import (
34
- ChannelDimension,
35
- PILImageResampling,
36
- to_numpy_array,
37
- )
38
- from transformers.activations import ACT2FN
39
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
40
- from transformers.modeling_utils import PreTrainedModel
41
- from transformers.utils import ModelOutput
42
-
43
- from .configuration_imp import SiglipVisionConfig
44
-
45
-
46
- # ============================================================================
47
- # A simple image preprocessor for SigLIP models.
48
- # ============================================================================
49
-
50
- def simple_image_processor(
51
- images,
52
- image_mean=(0.5, 0.5, 0.5),
53
- image_std=(0.5, 0.5, 0.5),
54
- size=(384, 384),
55
- resample=PILImageResampling.BICUBIC,
56
- rescale_factor=1 / 255,
57
- data_format=ChannelDimension.FIRST,
58
- return_tensors="pt"
59
- ):
60
-
61
- if isinstance(images, Image.Image):
62
- images = [images]
63
- else:
64
- assert isinstance(images, list)
65
-
66
- transforms = [
67
- convert_to_rgb,
68
- to_numpy_array,
69
- partial(resize, size=size, resample=resample, data_format=data_format),
70
- partial(rescale, scale=rescale_factor, data_format=data_format),
71
- partial(normalize, mean=image_mean, std=image_std, data_format=data_format),
72
- partial(to_channel_dimension_format, channel_dim=data_format, input_channel_dim=data_format),
73
- ]
74
-
75
- images = reduce(lambda x, f: [*map(f, x)], transforms, images)
76
- data = {"pixel_values": images}
77
-
78
- return BatchFeature(data=data, tensor_type=return_tensors)
79
-
80
- # ============================================================================
81
- # Definitions for SigLIP models.
82
- # ============================================================================
83
-
84
- @dataclass
85
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
86
- class SiglipVisionModelOutput(ModelOutput):
87
- """
88
- Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
89
-
90
- Args:
91
- image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
92
- The image embeddings obtained by applying the projection layer to the pooler_output.
93
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
94
- Sequence of hidden-states at the output of the last layer of the model.
95
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
96
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
97
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
98
-
99
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
100
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
101
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
102
- sequence_length)`.
103
-
104
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
105
- heads.
106
- """
107
-
108
- image_embeds: Optional[torch.FloatTensor] = None
109
- last_hidden_state: torch.FloatTensor = None
110
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
- attentions: Optional[Tuple[torch.FloatTensor]] = None
112
-
113
-
114
- class SiglipVisionEmbeddings(nn.Module):
115
- def __init__(self, config: SiglipVisionConfig):
116
- super().__init__()
117
- self.config = config
118
- self.embed_dim = config.hidden_size
119
- self.image_size = config.image_size
120
- self.patch_size = config.patch_size
121
-
122
- self.patch_embedding = nn.Conv2d(
123
- in_channels=config.num_channels,
124
- out_channels=self.embed_dim,
125
- kernel_size=self.patch_size,
126
- stride=self.patch_size,
127
- padding="valid",
128
- )
129
-
130
- self.num_patches = (self.image_size // self.patch_size) ** 2
131
- self.num_positions = self.num_patches
132
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
133
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
134
-
135
- def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
136
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
137
- embeddings = patch_embeds.flatten(2).transpose(1, 2)
138
-
139
- embeddings = embeddings + self.position_embedding(self.position_ids)
140
- return embeddings
141
-
142
-
143
-
144
- class SiglipAttention(nn.Module):
145
- """Multi-headed attention from 'Attention Is All You Need' paper"""
146
-
147
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
148
- def __init__(self, config):
149
- super().__init__()
150
- self.config = config
151
- self.embed_dim = config.hidden_size
152
- self.num_heads = config.num_attention_heads
153
- self.head_dim = self.embed_dim // self.num_heads
154
- if self.head_dim * self.num_heads != self.embed_dim:
155
- raise ValueError(
156
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
157
- f" {self.num_heads})."
158
- )
159
- self.scale = self.head_dim**-0.5
160
- self.dropout = config.attention_dropout
161
-
162
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
163
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
164
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
165
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
166
-
167
- def forward(
168
- self,
169
- hidden_states: torch.Tensor,
170
- attention_mask: Optional[torch.Tensor] = None,
171
- output_attentions: Optional[bool] = False,
172
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
173
- """Input shape: Batch x Time x Channel"""
174
-
175
- batch_size, q_len, _ = hidden_states.size()
176
-
177
- query_states = self.q_proj(hidden_states)
178
- key_states = self.k_proj(hidden_states)
179
- value_states = self.v_proj(hidden_states)
180
-
181
- query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
- key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
183
- value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
184
-
185
- k_v_seq_len = key_states.shape[-2]
186
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
187
-
188
- if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
189
- raise ValueError(
190
- f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
191
- f" {attn_weights.size()}"
192
- )
193
-
194
- if attention_mask is not None:
195
- if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
196
- raise ValueError(
197
- f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
198
- )
199
- attn_weights = attn_weights + attention_mask
200
-
201
- # upcast attention to fp32
202
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
203
- attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
204
- attn_output = torch.matmul(attn_weights, value_states)
205
-
206
- if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
207
- raise ValueError(
208
- f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
209
- f" {attn_output.size()}"
210
- )
211
-
212
- attn_output = attn_output.transpose(1, 2).contiguous()
213
- attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
214
-
215
- attn_output = self.out_proj(attn_output)
216
-
217
- return attn_output, attn_weights
218
-
219
-
220
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
221
- class SiglipMLP(nn.Module):
222
- def __init__(self, config):
223
- super().__init__()
224
- self.config = config
225
- self.activation_fn = ACT2FN[config.hidden_act]
226
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
227
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
228
-
229
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
230
- hidden_states = self.fc1(hidden_states)
231
- hidden_states = self.activation_fn(hidden_states)
232
- hidden_states = self.fc2(hidden_states)
233
- return hidden_states
234
-
235
-
236
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
237
- class SiglipEncoderLayer(nn.Module):
238
- def __init__(self, config: SiglipVisionConfig):
239
- super().__init__()
240
- self.embed_dim = config.hidden_size
241
- self.self_attn = SiglipAttention(config)
242
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
243
- self.mlp = SiglipMLP(config)
244
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
245
-
246
- # Ignore copy
247
- def forward(
248
- self,
249
- hidden_states: torch.Tensor,
250
- attention_mask: torch.Tensor,
251
- output_attentions: Optional[bool] = False,
252
- ) -> Tuple[torch.FloatTensor]:
253
- """
254
- Args:
255
- hidden_states (`torch.FloatTensor`):
256
- Input to the layer of shape `(batch, seq_len, embed_dim)`.
257
- attention_mask (`torch.FloatTensor`):
258
- Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
259
- output_attentions (`bool`, *optional*, defaults to `False`):
260
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
261
- returned tensors for more detail.
262
- """
263
- residual = hidden_states
264
-
265
- hidden_states = self.layer_norm1(hidden_states)
266
- hidden_states, attn_weights = self.self_attn(
267
- hidden_states=hidden_states,
268
- attention_mask=attention_mask,
269
- output_attentions=output_attentions,
270
- )
271
- hidden_states = residual + hidden_states
272
-
273
- residual = hidden_states
274
- hidden_states = self.layer_norm2(hidden_states)
275
- hidden_states = self.mlp(hidden_states)
276
- hidden_states = residual + hidden_states
277
-
278
- outputs = (hidden_states,)
279
-
280
- if output_attentions:
281
- outputs += (attn_weights,)
282
-
283
- return outputs
284
-
285
-
286
- class SiglipPreTrainedModel(PreTrainedModel):
287
- """
288
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
289
- models.
290
- """
291
-
292
- config_class = SiglipVisionConfig
293
- base_model_prefix = "siglip"
294
- supports_gradient_checkpointing = True
295
-
296
- def _init_weights(self, module):
297
- """Initialize the weights"""
298
- pass
299
-
300
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
301
- class SiglipEncoder(nn.Module):
302
- """
303
- Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
304
- [`SiglipEncoderLayer`].
305
-
306
- Args:
307
- config: SiglipVisionConfig
308
- """
309
-
310
- def __init__(self, config: SiglipVisionConfig):
311
- super().__init__()
312
- self.config = config
313
- self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
314
- self.gradient_checkpointing = False
315
-
316
- # Ignore copy
317
- def forward(
318
- self,
319
- inputs_embeds,
320
- attention_mask: Optional[torch.Tensor] = None,
321
- output_attentions: Optional[bool] = None,
322
- output_hidden_states: Optional[bool] = None,
323
- return_dict: Optional[bool] = None,
324
- ) -> Union[Tuple, BaseModelOutput]:
325
- r"""
326
- Args:
327
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
328
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
329
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
330
- than the model's internal embedding lookup matrix.
331
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
332
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
333
-
334
- - 1 for tokens that are **not masked**,
335
- - 0 for tokens that are **masked**.
336
-
337
- [What are attention masks?](../glossary#attention-mask)
338
- output_attentions (`bool`, *optional*):
339
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
340
- returned tensors for more detail.
341
- output_hidden_states (`bool`, *optional*):
342
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
343
- for more detail.
344
- return_dict (`bool`, *optional*):
345
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
346
- """
347
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
348
- output_hidden_states = (
349
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
350
- )
351
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
-
353
- encoder_states = () if output_hidden_states else None
354
- all_attentions = () if output_attentions else None
355
-
356
- hidden_states = inputs_embeds
357
- for encoder_layer in self.layers:
358
- if output_hidden_states:
359
- encoder_states = encoder_states + (hidden_states,)
360
- if self.gradient_checkpointing and self.training:
361
- layer_outputs = self._gradient_checkpointing_func(
362
- encoder_layer.__call__,
363
- hidden_states,
364
- attention_mask,
365
- output_attentions,
366
- )
367
- else:
368
- layer_outputs = encoder_layer(
369
- hidden_states,
370
- attention_mask,
371
- output_attentions=output_attentions,
372
- )
373
-
374
- hidden_states = layer_outputs[0]
375
-
376
- if output_attentions:
377
- all_attentions = all_attentions + (layer_outputs[1],)
378
-
379
- if output_hidden_states:
380
- encoder_states = encoder_states + (hidden_states,)
381
-
382
- if not return_dict:
383
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
384
- return BaseModelOutput(
385
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
386
- )
387
-
388
-
389
- class SiglipVisionTransformer(nn.Module):
390
- def __init__(self, config: SiglipVisionConfig):
391
- super().__init__()
392
- self.config = config
393
- embed_dim = config.hidden_size
394
-
395
- self.embeddings = SiglipVisionEmbeddings(config)
396
- self.encoder = SiglipEncoder(config)
397
- self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
398
- self.head = SiglipMultiheadAttentionPoolingHead(config)
399
-
400
- def forward(
401
- self,
402
- pixel_values,
403
- output_attentions: Optional[bool] = None,
404
- output_hidden_states: Optional[bool] = None,
405
- return_dict: Optional[bool] = None,
406
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
407
- r"""
408
- Returns:
409
-
410
- """
411
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
412
- output_hidden_states = (
413
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
414
- )
415
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
416
-
417
- hidden_states = self.embeddings(pixel_values)
418
-
419
- encoder_outputs = self.encoder(
420
- inputs_embeds=hidden_states,
421
- output_attentions=output_attentions,
422
- output_hidden_states=output_hidden_states,
423
- return_dict=return_dict,
424
- )
425
-
426
- last_hidden_state = encoder_outputs[0]
427
- last_hidden_state = self.post_layernorm(last_hidden_state)
428
-
429
- pooled_output = self.head(last_hidden_state)
430
-
431
- if not return_dict:
432
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
433
-
434
- return BaseModelOutputWithPooling(
435
- last_hidden_state=last_hidden_state,
436
- pooler_output=pooled_output,
437
- hidden_states=encoder_outputs.hidden_states,
438
- attentions=encoder_outputs.attentions,
439
- )
440
-
441
-
442
- class SiglipMultiheadAttentionPoolingHead(nn.Module):
443
- """Multihead Attention Pooling."""
444
-
445
- def __init__(self, config: SiglipVisionConfig):
446
- super().__init__()
447
-
448
- self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
449
- self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
450
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
451
- self.mlp = SiglipMLP(config)
452
-
453
- def forward(self, hidden_state):
454
- batch_size = hidden_state.shape[0]
455
- probe = self.probe.repeat(batch_size, 1, 1)
456
-
457
- hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
458
-
459
- residual = hidden_state
460
- hidden_state = self.layernorm(hidden_state)
461
- hidden_state = residual + self.mlp(hidden_state)
462
-
463
- return hidden_state[:, 0]
464
-
465
-
466
- class SiglipVisionModel(SiglipPreTrainedModel):
467
- config_class = SiglipVisionConfig
468
- main_input_name = "pixel_values"
469
- _no_split_modules = ["SiglipEncoderLayer"]
470
-
471
- def __init__(self, config: SiglipVisionConfig):
472
- super().__init__(config)
473
-
474
- self.vision_model = SiglipVisionTransformer(config)
475
-
476
- # Initialize weights and apply final processing
477
- self.post_init()
478
-
479
- def get_input_embeddings(self) -> nn.Module:
480
- return self.vision_model.embeddings.patch_embedding
481
-
482
- def forward(
483
- self,
484
- pixel_values,
485
- output_attentions: Optional[bool] = None,
486
- output_hidden_states: Optional[bool] = None,
487
- return_dict: Optional[bool] = None,
488
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
489
- r"""
490
- Returns:
491
-
492
- Examples:
493
-
494
- ```python
495
- >>> from PIL import Image
496
- >>> import requests
497
- >>> from transformers import AutoProcessor, SiglipVisionModel
498
-
499
- >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
500
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
501
-
502
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
503
- >>> image = Image.open(requests.get(url, stream=True).raw)
504
-
505
- >>> inputs = processor(images=image, return_tensors="pt")
506
-
507
- >>> outputs = model(**inputs)
508
- >>> last_hidden_state = outputs.last_hidden_state
509
- >>> pooled_output = outputs.pooler_output # pooled features
510
- ```"""
511
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
512
-
513
- return self.vision_model(
514
- pixel_values=pixel_values,
515
- output_attentions=output_attentions,
516
- output_hidden_states=output_hidden_states,
517
- return_dict=return_dict,
518
- )
519
-
520
-
521
- # ============================================================================
522
- # VisionTower module for Imp
523
- # ============================================================================
524
-
525
- class VisionTower(nn.Module):
526
- def __init__(self, vision_tower_cfg, delay_load=False):
527
- super().__init__()
528
-
529
- self.is_loaded = False
530
-
531
- self.config = vision_tower_cfg
532
- self.vision_tower_name = vision_tower_cfg.mm_vision_tower
533
- self.select_layer = vision_tower_cfg.mm_vision_select_layer
534
- # self.select_feature = getattr(vision_tower_cfg, 'mm_vision_select_feature', 'patch')
535
-
536
- self.image_processor = simple_image_processor
537
-
538
- if not delay_load:
539
- self.load_model()
540
- else:
541
- raise NotImplementedError("delay load is not implemented yet.")
542
-
543
- def load_model(self):
544
- if self.is_loaded:
545
- return
546
-
547
- # "google/siglip-so400m-patch14-384"
548
- # self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
549
- self.vision_tower = SiglipVisionModel(self.config)
550
- del self.vision_tower.vision_model.encoder.layers[(self.select_layer + 1):]
551
- self.vision_tower.vision_model.head = nn.Identity()
552
- self.vision_tower.requires_grad_(False)
553
- self.vision_tower.eval()
554
-
555
- self.is_loaded = True
556
-
557
- @torch.no_grad()
558
- def forward(self, images):
559
- if type(images) is list:
560
- image_features = []
561
- for image in images:
562
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
563
- image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
564
- assert image_feature.shape[-2] == 729
565
- image_features.append(image_feature)
566
- else:
567
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
568
- image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
569
- assert image_features.shape[-2] == 729
570
-
571
- return image_features
572
-
573
- @property
574
- def dummy_feature(self):
575
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
576
-
577
- @property
578
- def dtype(self):
579
- for p in self.vision_tower.parameters():
580
- return p.dtype
581
-
582
- @property
583
- def device(self):
584
- for p in self.vision_tower.parameters():
585
- return p.device
586
-
587
- @property
588
- def hidden_size(self):
589
- return self.config.hidden_size
590
-
591
- @property
592
- def num_patches(self):
593
- return (self.config.image_size // self.config.patch_size) ** 2
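For reference, the removed `VisionTower.forward` asserts that every image produces 729 patch features: the preprocessor above resizes images to 384×384 and SigLIP-SO400M uses a patch size of 14, so (384 // 14)² = 27² = 729 image tokens per image. A minimal sketch of that bookkeeping, using sizes assumed from `google/siglip-so400m-patch14-384` rather than read from the repository's config:

```python
import torch

# Assumed sizes matching google/siglip-so400m-patch14-384 (not the repo's config object).
image_size, patch_size = 384, 14

num_patches = (image_size // patch_size) ** 2  # 27 * 27 = 729

# The preprocessor emits pixel_values of shape (batch, 3, 384, 384); unfolding into
# non-overlapping 14x14 patches reproduces the 729 positions the assertions check.
pixel_values = torch.randn(1, 3, image_size, image_size)
patches = torch.nn.functional.unfold(pixel_values, kernel_size=patch_size, stride=patch_size)

assert patches.shape[-1] == num_patches == 729
print(f"{num_patches} image tokens per image")
```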
 
vocab.json CHANGED
The diff for this file is too large to render. See raw diff