Tom Aarsen committed
Commit f741574
1 Parent(s): a2095b8

Use remote implementation instead

Files changed (11)
  1. config.json +2 -2
  2. configuration_clip.py +0 -304
  3. custom_st.py +0 -174
  4. custom_st_2.py +0 -3
  5. eva_model.py +0 -764
  6. hf_model.py +0 -297
  7. modeling_clip.py +0 -570
  8. modules.json +1 -1
  9. processing_clip.py +0 -88
  10. rope_embeddings.py +0 -165
  11. transform.py +0 -458
config.json CHANGED
@@ -6,8 +6,8 @@
     "JinaCLIPModel"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_clip.JinaCLIPConfig",
-    "AutoModel": "modeling_clip.JinaCLIPModel"
+    "AutoConfig": "tomaarsen/jina-clip-implementation-st--configuration_clip.JinaCLIPConfig",
+    "AutoModel": "tomaarsen/jina-clip-implementation-st--modeling_clip.JinaCLIPModel"
   },
   "initializer_factor": 1.0,
   "logit_scale_init_value": 2.6592,
configuration_clip.py DELETED
@@ -1,304 +0,0 @@
1
- # coding=utf-8
2
- #
3
- # Code mainly copied from:
4
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/configuration_clip.py
5
- # and adjusted for Jina CLIP
6
-
7
- import os
8
- from copy import deepcopy
9
- from typing import Any, Dict, Optional, Union
10
-
11
- from transformers import PretrainedConfig, logging
12
-
13
- logger = logging.get_logger(__name__)
14
-
15
-
16
- """ Jina CLIP model configuration """
17
-
18
-
19
- class JinaCLIPTextConfig(PretrainedConfig):
20
- model_type = 'jina_clip_text'
21
-
22
- def __init__(
23
- self,
24
- embed_dim: int = 768,
25
- hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
26
- hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
27
- pooler_type: Optional[str] = None,
28
- proj_type: Optional[str] = None,
29
- proj_bias: bool = False,
30
- **kwargs,
31
- ):
32
- super().__init__(**kwargs)
33
-
34
- self.embed_dim = embed_dim
35
- self.hf_model_name_or_path = hf_model_name_or_path
36
- self.hf_model_config_kwargs = hf_model_config_kwargs or {}
37
- self.pooler_type = pooler_type
38
- self.proj_type = proj_type
39
- self.proj_bias = proj_bias
40
-
41
- @classmethod
42
- def from_pretrained(
43
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
44
- ) -> 'PretrainedConfig':
45
- cls._set_token_in_kwargs(kwargs)
46
-
47
- configdict, kwargs = cls.get_config_dict(
48
- pretrained_model_name_or_path, **kwargs
49
- )
50
-
51
- # get the text config dict if we are loading from JinaCLIPConfig
52
- if configdict.get('model_type') == 'jina_clip':
53
- configdict = configdict['text_config']
54
-
55
- if (
56
- 'model_type' in configdict
57
- and hasattr(cls, 'model_type')
58
- and configdict['model_type'] != cls.model_type
59
- ):
60
- logger.warning(
61
- f'You are using a model of type {configdict["model_type"]} to '
62
- f'instantiate a model of type {cls.model_type}. This is not supported '
63
- 'for all configurations of models and can yield errors.'
64
- )
65
-
66
- return cls.from_dict(configdict, **kwargs)
67
-
68
-
69
- class JinaCLIPVisionConfig(PretrainedConfig):
70
- model_type = 'jina_clip_vision'
71
-
72
- def __init__(
73
- self,
74
- embed_dim: int = 768,
75
- width: int = 768,
76
- image_size: int = 224,
77
- patch_size: int = 16,
78
- layers: int = 12,
79
- head_width: int = 64,
80
- mlp_ratio: float = 4.0,
81
- ls_init_value: Optional[float] = None,
82
- patch_dropout: float = 0.0,
83
- qkv_bias: bool = True,
84
- fused_layer_norm: bool = False,
85
- x_attention: bool = False,
86
- post_norm: bool = False,
87
- rope_embeddings: bool = False,
88
- pt_hw_seq_len: int = 16,
89
- intp_freq: bool = False,
90
- naive_swiglu: bool = False,
91
- subln: bool = False,
92
- drop_path_rate: float = 0.0,
93
- proj_type: Optional[str] = None,
94
- **kwargs,
95
- ):
96
- super().__init__(**kwargs)
97
-
98
- self.layers = layers
99
- self.embed_dim = embed_dim
100
- self.width = width
101
- self.head_width = head_width
102
- self.mlp_ratio = mlp_ratio
103
- self.image_size = image_size
104
- self.patch_size = patch_size
105
- self.ls_init_value = ls_init_value
106
- self.patch_dropout = patch_dropout
107
- self.qkv_bias = qkv_bias
108
- self.fused_layer_norm = fused_layer_norm
109
- self.x_attention = x_attention
110
- self.post_norm = post_norm
111
- self.rope_embeddings = rope_embeddings
112
- self.pt_hw_seq_len = pt_hw_seq_len
113
- self.intp_freq = intp_freq
114
- self.naive_swiglu = naive_swiglu
115
- self.subln = subln
116
- self.drop_path_rate = drop_path_rate
117
- self.proj_type = proj_type
118
-
119
- @classmethod
120
- def from_pretrained(
121
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
122
- ) -> 'PretrainedConfig':
123
- cls._set_token_in_kwargs(kwargs)
124
-
125
- configdict, kwargs = cls.get_config_dict(
126
- pretrained_model_name_or_path, **kwargs
127
- )
128
-
129
- # get the vision config dict if we are loading from JinaCLIPConfig
130
- if configdict.get('model_type') == 'jina_clip':
131
- configdict = configdict['vision_config']
132
-
133
- if (
134
- 'model_type' in configdict
135
- and hasattr(cls, 'model_type')
136
- and configdict['model_type'] != cls.model_type
137
- ):
138
- logger.warning(
139
- f'You are using a model of type {configdict["model_type"]} to '
140
- f'instantiate a model of type {cls.model_type}. This is not supported '
141
- 'for all configurations of models and can yield errors.'
142
- )
143
-
144
- return cls.from_dict(configdict, **kwargs)
145
-
146
-
147
- class JinaCLIPConfig(PretrainedConfig):
148
- model_type = 'jina_clip'
149
- is_composition = True
150
-
151
- def __init__(
152
- self,
153
- text_config: Optional[Dict] = None,
154
- vision_config: Optional[Dict] = None,
155
- add_projections: bool = False,
156
- projection_dim: int = 768,
157
- logit_scale_init_value: float = 2.6592,
158
- use_text_flash_attn: Optional[bool] = None,
159
- use_vision_xformers: Optional[bool] = None,
160
- **kwargs,
161
- ):
162
- # If `_config_dict` exist, we use them for the backward compatibility.
163
- # We pop out these 2 attributes before calling `super().__init__` to avoid
164
- # them being saved (which causes a lot of confusion!).
165
-
166
- text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
167
- vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
168
- self.use_text_flash_attn = use_text_flash_attn
169
- self.use_vision_xformers = use_vision_xformers
170
-
171
- super().__init__(**kwargs)
172
-
173
- if text_config_dict is not None:
174
- if text_config is None:
175
- text_config = {}
176
-
177
- # This is the complete result when using `text_config_dict`.
178
- _text_config_dict = JinaCLIPTextConfig(**text_config_dict).to_dict()
179
-
180
- # Give a warning if the values exist in both `_text_config_dict` and
181
- # `text_config` but being different.
182
- for key, value in _text_config_dict.items():
183
- if (
184
- key in text_config
185
- and value != text_config[key]
186
- and key not in ['transformers_version']
187
- ):
188
- # If specified in `text_config_dict`
189
- if key in text_config_dict:
190
- message = (
191
- f'`{key}` is found in both `text_config_dict` and '
192
- f'`text_config` but with different values. '
193
- f'The value `text_config_dict["{key}"]` will be used '
194
- f'instead.'
195
- )
196
- # If inferred from default argument values (
197
- # just to be super careful)
198
- else:
199
- message = (
200
- f'`text_config_dict` is provided which will be used to '
201
- f'initialize `JinaCLIPTextConfig`. The '
202
- f'value `text_config["{key}"]` will be overriden.'
203
- )
204
- logger.info(message)
205
-
206
- # Update all values in `text_config` with the ones in `_text_config_dict`.
207
- text_config.update(_text_config_dict)
208
-
209
- if vision_config_dict is not None:
210
- if vision_config is None:
211
- vision_config = {}
212
-
213
- # This is the complete result when using `vision_config_dict`.
214
- _vision_config_dict = JinaCLIPVisionConfig(**vision_config_dict).to_dict()
215
- # convert keys to string instead of integer
216
- if 'id2label' in _vision_config_dict:
217
- _vision_config_dict['id2label'] = {
218
- str(key): value
219
- for key, value in _vision_config_dict['id2label'].items()
220
- }
221
-
222
- # Give a warning if the values exist in both `_vision_config_dict`
223
- # and `vision_config` but being different.
224
- for key, value in _vision_config_dict.items():
225
- if (
226
- key in vision_config
227
- and value != vision_config[key]
228
- and key not in ['transformers_version']
229
- ):
230
- # If specified in `vision_config_dict`
231
- if key in vision_config_dict:
232
- message = (
233
- f'`{key}` is found in both `vision_config_dict` and '
234
- f'`vision_config` but with different '
235
- f'values. The value `vision_config_dict["{key}"]` will '
236
- f'be used instead.'
237
- )
238
- # If inferred from default argument values
239
- # (just to be super careful)
240
- else:
241
- message = (
242
- f'`vision_config_dict` is provided which will be used to '
243
- f'initialize `JinaCLIPVisionConfig`. '
244
- f'The value `vision_config["{key}"]` will be overriden.'
245
- )
246
- logger.info(message)
247
-
248
- # Update all values in `vision_config` with the ones in
249
- # `_vision_config_dict`.
250
- vision_config.update(_vision_config_dict)
251
-
252
- if text_config is None:
253
- text_config = {}
254
- logger.info(
255
- '`text_config` is `None`. Initializing the `JinaCLIPTextConfig` with '
256
- 'default values.'
257
- )
258
-
259
- if vision_config is None:
260
- vision_config = {}
261
- logger.info(
262
- '`vision_config` is `None`. initializing the `JinaCLIPVisionConfig` '
263
- 'with default values.'
264
- )
265
-
266
- self.text_config = JinaCLIPTextConfig(**text_config)
267
- self.vision_config = JinaCLIPVisionConfig(**vision_config)
268
-
269
- self.add_projections = add_projections
270
- self.projection_dim = projection_dim
271
- self.logit_scale_init_value = logit_scale_init_value
272
- self.initializer_factor = 1.0
273
-
274
- if not self.add_projections:
275
- if self.text_config.embed_dim != self.vision_config.embed_dim:
276
- raise ValueError(
277
- 'When projections are disabled (`add_projections=False`), text '
278
- 'and vision towers need to have the same embedding dimensionality. '
279
- f'Currently text embedding dim is {self.text_config.embed_dim} != '
280
- f'{self.vision_config.embed_dim} of the vision tower. '
281
- 'Either set the same output dim for both towers, or enable '
282
- 'projections with `add_projections=True`.'
283
- )
284
-
285
- @classmethod
286
- def from_text_vision_configs(
287
- cls,
288
- text_config: JinaCLIPTextConfig,
289
- vision_config: JinaCLIPVisionConfig,
290
- **kwargs,
291
- ):
292
- return cls(
293
- text_config=text_config.to_dict(),
294
- vision_config=vision_config.to_dict(),
295
- projection_dim=text_config.projection_dim,
296
- **kwargs,
297
- )
298
-
299
- def to_dict(self):
300
- output = deepcopy(self.__dict__)
301
- output['text_config'] = self.text_config.to_dict()
302
- output['vision_config'] = self.vision_config.to_dict()
303
- output['model_type'] = self.__class__.model_type
304
- return output
custom_st.py DELETED
@@ -1,174 +0,0 @@
1
- import base64
2
- from io import BytesIO
3
- import json
4
- import os
5
- from typing import Any, Dict, List, Optional, Tuple, Union
6
-
7
- from .custom_st_2 import OtherClass
8
- import requests
9
- import torch
10
- from torch import nn
11
- from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoImageProcessor
12
- from PIL import Image
13
-
14
- OtherClass()
15
-
16
- class Transformer(nn.Module):
17
- """Huggingface AutoModel to generate token embeddings.
18
- Loads the correct class, e.g. BERT / RoBERTa etc.
19
-
20
- Args:
21
- model_name_or_path: Huggingface models name
22
- (https://huggingface.co/models)
23
- max_seq_length: Truncate any inputs longer than max_seq_length
24
- model_args: Keyword arguments passed to the Huggingface
25
- Transformers model
26
- tokenizer_args: Keyword arguments passed to the Huggingface
27
- Transformers tokenizer
28
- config_args: Keyword arguments passed to the Huggingface
29
- Transformers config
30
- cache_dir: Cache dir for Huggingface Transformers to store/load
31
- models
32
- do_lower_case: If true, lowercases the input (independent if the
33
- model is cased or not)
34
- tokenizer_name_or_path: Name or path of the tokenizer. When
35
- None, then model_name_or_path is used
36
- """
37
-
38
- def __init__(
39
- self,
40
- model_name_or_path: str,
41
- max_seq_length: Optional[int] = None,
42
- model_args: Optional[Dict[str, Any]] = None,
43
- tokenizer_args: Optional[Dict[str, Any]] = None,
44
- config_args: Optional[Dict[str, Any]] = None,
45
- cache_dir: Optional[str] = None,
46
- do_lower_case: bool = False,
47
- tokenizer_name_or_path: str = None,
48
- ) -> None:
49
- super(Transformer, self).__init__()
50
- self.config_keys = ["max_seq_length", "do_lower_case"]
51
- self.do_lower_case = do_lower_case
52
- if model_args is None:
53
- model_args = {}
54
- if tokenizer_args is None:
55
- tokenizer_args = {}
56
- if config_args is None:
57
- config_args = {}
58
-
59
- config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
60
- self.jina_clip = AutoModel.from_pretrained(
61
- model_name_or_path, config=config, cache_dir=cache_dir, **model_args
62
- )
63
-
64
- if max_seq_length is not None and "model_max_length" not in tokenizer_args:
65
- tokenizer_args["model_max_length"] = max_seq_length
66
- self.tokenizer = AutoTokenizer.from_pretrained(
67
- tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
68
- cache_dir=cache_dir,
69
- **tokenizer_args,
70
- )
71
- self.preprocessor = AutoImageProcessor.from_pretrained(
72
- tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
73
- cache_dir=cache_dir,
74
- **tokenizer_args,
75
- )
76
-
77
- # No max_seq_length set. Try to infer from model
78
- if max_seq_length is None:
79
- if (
80
- hasattr(self.jina_clip, "config")
81
- and hasattr(self.jina_clip.config, "max_position_embeddings")
82
- and hasattr(self.tokenizer, "model_max_length")
83
- ):
84
- max_seq_length = min(self.jina_clip.config.max_position_embeddings, self.tokenizer.model_max_length)
85
-
86
- self.max_seq_length = max_seq_length
87
-
88
- if tokenizer_name_or_path is not None:
89
- self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
90
-
91
- def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
92
- """Returns token_embeddings, cls_token"""
93
- if "input_ids" in features:
94
- embedding = self.jina_clip.get_text_features(input_ids=features["input_ids"])
95
- else:
96
- embedding = self.jina_clip.get_image_features(pixel_values=features["pixel_values"])
97
- return {"sentence_embedding": embedding}
98
-
99
- def get_word_embedding_dimension(self) -> int:
100
- return self.config.text_config.embed_dim
101
-
102
- def decode_data_image(data_image_str):
103
- header, data = data_image_str.split(',', 1)
104
- image_data = base64.b64decode(data)
105
- return Image.open(BytesIO(image_data))
106
-
107
- def tokenize(
108
- self, batch: Union[List[str]], padding: Union[str, bool] = True
109
- ) -> Dict[str, torch.Tensor]:
110
- """Tokenizes a text and maps tokens to token-ids"""
111
- images = []
112
- texts = []
113
- for sample in batch:
114
- if isinstance(sample, str):
115
- if sample.startswith('http'):
116
- response = requests.get(sample)
117
- images.append(Image.open(BytesIO(response.content)).convert('RGB'))
118
- elif sample.startswith('data:image/'):
119
- images.append(self.decode_data_image(sample).convert('RGB'))
120
- else:
121
- # TODO: Make sure that Image.open fails for non-image files
122
- try:
123
- images.append(Image.open(sample).convert('RGB'))
124
- except:
125
- texts.append(sample)
126
- elif isinstance(sample, Image.Image):
127
- images.append(sample.convert('RGB'))
128
-
129
- if images and texts:
130
- raise ValueError('Batch must contain either images or texts, not both')
131
-
132
- if texts:
133
- return self.tokenizer(
134
- texts,
135
- padding=padding,
136
- truncation="longest_first",
137
- return_tensors="pt",
138
- max_length=self.max_seq_length,
139
- )
140
- elif images:
141
- return self.preprocessor(images)
142
- return {}
143
-
144
- def save(self, output_path: str, safe_serialization: bool = True) -> None:
145
- self.jina_clip.save_pretrained(output_path, safe_serialization=safe_serialization)
146
- self.tokenizer.save_pretrained(output_path)
147
- self.preprocessor.save_pretrained(output_path)
148
-
149
- @staticmethod
150
- def load(input_path: str) -> "Transformer":
151
- # Old classes used other config names than 'sentence_bert_config.json'
152
- for config_name in [
153
- "sentence_bert_config.json",
154
- "sentence_roberta_config.json",
155
- "sentence_distilbert_config.json",
156
- "sentence_camembert_config.json",
157
- "sentence_albert_config.json",
158
- "sentence_xlm-roberta_config.json",
159
- "sentence_xlnet_config.json",
160
- ]:
161
- sbert_config_path = os.path.join(input_path, config_name)
162
- if os.path.exists(sbert_config_path):
163
- break
164
-
165
- with open(sbert_config_path) as fIn:
166
- config = json.load(fIn)
167
- # Don't allow configs to set trust_remote_code
168
- if "model_args" in config and "trust_remote_code" in config["model_args"]:
169
- config["model_args"].pop("trust_remote_code")
170
- if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
171
- config["tokenizer_args"].pop("trust_remote_code")
172
- if "config_args" in config and "trust_remote_code" in config["config_args"]:
173
- config["config_args"].pop("trust_remote_code")
174
- return Transformer(model_name_or_path=input_path, **config)
custom_st_2.py DELETED
@@ -1,3 +0,0 @@
1
-
2
- class OtherClass:
3
- pass
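For Sentence Transformers, the deleted custom_st.py above defines the Transformer module that routes plain strings through get_text_features and image inputs (file paths, URLs, data URIs, PIL images) through get_image_features; following the modules.json change, that module is now referenced remotely rather than stored here. A hedged usage sketch, with a hypothetical repository id and image URL:

# Assumption: the repository id and image URL below are placeholders.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-namespace/jina-clip-st", trust_remote_code=True)

# Plain strings are tokenized and encoded by the text tower; image paths/URLs
# are preprocessed and encoded by the vision tower (see tokenize() above).
text_embeddings = model.encode(["A photo of a cat"])
image_embeddings = model.encode(["https://example.com/cat.jpg"])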
eva_model.py DELETED
@@ -1,764 +0,0 @@
1
- # --------------------------------------------------------
2
- # Adapted from EVA CLIP
3
- # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
- # --------------------------------------------------------
5
-
6
- import math
7
- import os
8
- from functools import partial
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- try:
15
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
- except ImportError or ModuleNotFoundError:
17
- from timm.layers import drop_path, to_2tuple, trunc_normal_
18
-
19
- from .rope_embeddings import VisionRotaryEmbeddingFast
20
-
21
- if os.getenv('ENV_TYPE') == 'deepspeed':
22
- try:
23
- from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
24
- except ImportError or ModuleNotFoundError:
25
- from torch.utils.checkpoint import checkpoint
26
- else:
27
- from torch.utils.checkpoint import checkpoint
28
-
29
- try:
30
- import xformers.ops as xops
31
- except ImportError:
32
- xops = None
33
-
34
-
35
- class PatchDropout(nn.Module):
36
- """
37
- https://arxiv.org/abs/2212.00794
38
- """
39
-
40
- def __init__(self, prob, exclude_first_token=True):
41
- super().__init__()
42
- assert 0 <= prob < 1.0
43
- self.prob = prob
44
- self.exclude_first_token = exclude_first_token # exclude CLS token
45
-
46
- def forward(self, x):
47
- if not self.training or self.prob == 0.0:
48
- return x
49
-
50
- if self.exclude_first_token:
51
- cls_tokens, x = x[:, :1], x[:, 1:]
52
- else:
53
- cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
54
-
55
- batch = x.size()[0]
56
- num_tokens = x.size()[1]
57
-
58
- batch_indices = torch.arange(batch)
59
- batch_indices = batch_indices[..., None]
60
-
61
- keep_prob = 1 - self.prob
62
- num_patches_keep = max(1, int(num_tokens * keep_prob))
63
-
64
- rand = torch.randn(batch, num_tokens)
65
- patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
66
-
67
- x = x[batch_indices, patch_indices_keep]
68
-
69
- if self.exclude_first_token:
70
- x = torch.cat((cls_tokens, x), dim=1)
71
-
72
- return x, patch_indices_keep
73
-
74
-
75
- class DropPath(nn.Module):
76
- """Drop paths (Stochastic Depth) per sample (when applied in main path of
77
- residual blocks)."""
78
-
79
- def __init__(self, drop_prob=None):
80
- super(DropPath, self).__init__()
81
- self.drop_prob = drop_prob
82
-
83
- def forward(self, x):
84
- return drop_path(x, self.drop_prob, self.training)
85
-
86
- def extra_repr(self) -> str:
87
- return 'p={}'.format(self.drop_prob)
88
-
89
-
90
- class Mlp(nn.Module):
91
- def __init__(
92
- self,
93
- in_features,
94
- hidden_features=None,
95
- out_features=None,
96
- act_layer=nn.GELU,
97
- norm_layer=nn.LayerNorm,
98
- drop=0.0,
99
- subln=False,
100
- ):
101
- super().__init__()
102
- out_features = out_features or in_features
103
- hidden_features = hidden_features or in_features
104
- self.fc1 = nn.Linear(in_features, hidden_features)
105
- self.act = act_layer()
106
-
107
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
108
-
109
- self.fc2 = nn.Linear(hidden_features, out_features)
110
- self.drop = nn.Dropout(drop)
111
-
112
- def forward(self, x):
113
- x = self.fc1(x)
114
- x = self.act(x)
115
- # x = self.drop(x)
116
- # commit this for the orignal BERT implement
117
- x = self.ffn_ln(x)
118
-
119
- x = self.fc2(x)
120
- x = self.drop(x)
121
- return x
122
-
123
-
124
- class SwiGLU(nn.Module):
125
- def __init__(
126
- self,
127
- in_features,
128
- hidden_features=None,
129
- out_features=None,
130
- act_layer=nn.SiLU,
131
- drop=0.0,
132
- norm_layer=nn.LayerNorm,
133
- subln=False,
134
- ):
135
- super().__init__()
136
- out_features = out_features or in_features
137
- hidden_features = hidden_features or in_features
138
-
139
- self.w1 = nn.Linear(in_features, hidden_features)
140
- self.w2 = nn.Linear(in_features, hidden_features)
141
-
142
- self.act = act_layer()
143
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
144
- self.w3 = nn.Linear(hidden_features, out_features)
145
-
146
- self.drop = nn.Dropout(drop)
147
-
148
- def forward(self, x):
149
- x1 = self.w1(x)
150
- x2 = self.w2(x)
151
- hidden = self.act(x1) * x2
152
- x = self.ffn_ln(hidden)
153
- x = self.w3(x)
154
- x = self.drop(x)
155
- return x
156
-
157
-
158
- class Attention(nn.Module):
159
- def __init__(
160
- self,
161
- dim,
162
- num_heads=8,
163
- qkv_bias=False,
164
- qk_scale=None,
165
- attn_drop=0.0,
166
- proj_drop=0.0,
167
- window_size=None,
168
- attn_head_dim=None,
169
- xattn=False,
170
- rope=None,
171
- subln=False,
172
- norm_layer=nn.LayerNorm,
173
- ):
174
- super().__init__()
175
- self.num_heads = num_heads
176
- head_dim = dim // num_heads
177
- if attn_head_dim is not None:
178
- head_dim = attn_head_dim
179
- all_head_dim = head_dim * self.num_heads
180
- self.scale = qk_scale or head_dim**-0.5
181
-
182
- self.subln = subln
183
- if self.subln:
184
- self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
185
- self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
186
- self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
187
- else:
188
- self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
189
-
190
- if qkv_bias:
191
- self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
192
- self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
193
- else:
194
- self.q_bias = None
195
- self.v_bias = None
196
-
197
- if window_size:
198
- self.window_size = window_size
199
- self.num_relative_distance = (2 * window_size[0] - 1) * (
200
- 2 * window_size[1] - 1
201
- ) + 3
202
- self.relative_position_bias_table = nn.Parameter(
203
- torch.zeros(self.num_relative_distance, num_heads)
204
- ) # 2*Wh-1 * 2*Ww-1, nH
205
- # cls to token & token 2 cls & cls to cls
206
-
207
- # get pair-wise relative position index for each token inside the window
208
- coords_h = torch.arange(window_size[0])
209
- coords_w = torch.arange(window_size[1])
210
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
211
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
212
- relative_coords = (
213
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
214
- ) # 2, Wh*Ww, Wh*Ww
215
- relative_coords = relative_coords.permute(
216
- 1, 2, 0
217
- ).contiguous() # Wh*Ww, Wh*Ww, 2
218
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
219
- relative_coords[:, :, 1] += window_size[1] - 1
220
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
221
- relative_position_index = torch.zeros(
222
- size=(window_size[0] * window_size[1] + 1,) * 2,
223
- dtype=relative_coords.dtype,
224
- )
225
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
226
- relative_position_index[0, 0:] = self.num_relative_distance - 3
227
- relative_position_index[0:, 0] = self.num_relative_distance - 2
228
- relative_position_index[0, 0] = self.num_relative_distance - 1
229
-
230
- self.register_buffer('relative_position_index', relative_position_index)
231
- else:
232
- self.window_size = None
233
- self.relative_position_bias_table = None
234
- self.relative_position_index = None
235
-
236
- self.attn_drop = nn.Dropout(attn_drop)
237
- self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
238
- # self.proj = nn.Linear(all_head_dim, all_head_dim)
239
- self.proj = nn.Linear(all_head_dim, dim)
240
- self.proj_drop = nn.Dropout(proj_drop)
241
- self.xattn = xattn
242
- self.xattn_drop = attn_drop
243
-
244
- self.rope = rope
245
-
246
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
247
- B, N, C = x.shape
248
- if self.subln:
249
- q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
250
- k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
251
- v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
252
-
253
- q = q.reshape(B, N, self.num_heads, -1).permute(
254
- 0, 2, 1, 3
255
- ) # B, num_heads, N, C
256
- k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
257
- v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
258
- else:
259
- qkv_bias = None
260
- if self.q_bias is not None:
261
- qkv_bias = torch.cat(
262
- (
263
- self.q_bias,
264
- torch.zeros_like(self.v_bias, requires_grad=False),
265
- self.v_bias,
266
- )
267
- )
268
-
269
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
270
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
271
- 2, 0, 3, 1, 4
272
- ) # 3, B, num_heads, N, C
273
- q, k, v = qkv[0], qkv[1], qkv[2]
274
-
275
- if self.rope:
276
- # slightly fast impl
277
- q_t = q[:, :, 1:, :]
278
- ro_q_t = self.rope(q_t)
279
- q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
280
-
281
- k_t = k[:, :, 1:, :]
282
- ro_k_t = self.rope(k_t)
283
- k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
284
-
285
- if self.xattn:
286
- if xops is None:
287
- raise ValueError(
288
- "Can't use xattn without xformers. Please 'pip install xformers'"
289
- )
290
- q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
291
- k = k.permute(0, 2, 1, 3)
292
- v = v.permute(0, 2, 1, 3)
293
-
294
- x = xops.memory_efficient_attention(
295
- q,
296
- k,
297
- v,
298
- p=self.xattn_drop,
299
- scale=self.scale,
300
- )
301
- x = x.reshape(B, N, -1)
302
- x = self.inner_attn_ln(x)
303
- x = self.proj(x)
304
- x = self.proj_drop(x)
305
- else:
306
- q = q * self.scale
307
- attn = q @ k.transpose(-2, -1)
308
-
309
- if self.relative_position_bias_table is not None:
310
- relative_position_bias = self.relative_position_bias_table[
311
- self.relative_position_index.view(-1)
312
- ].view(
313
- self.window_size[0] * self.window_size[1] + 1,
314
- self.window_size[0] * self.window_size[1] + 1,
315
- -1,
316
- ) # Wh*Ww,Wh*Ww,nH
317
- relative_position_bias = relative_position_bias.permute(
318
- 2, 0, 1
319
- ).contiguous() # nH, Wh*Ww, Wh*Ww
320
- attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
321
-
322
- if rel_pos_bias is not None:
323
- attn = attn + rel_pos_bias.type_as(attn)
324
-
325
- if attn_mask is not None:
326
- attn_mask = attn_mask.bool()
327
- attn = attn.masked_fill(~attn_mask[:, None, None, :], float('-inf'))
328
-
329
- attn = attn.softmax(dim=-1)
330
- attn = self.attn_drop(attn)
331
-
332
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
333
- x = self.inner_attn_ln(x)
334
- x = self.proj(x)
335
- x = self.proj_drop(x)
336
- return x
337
-
338
-
339
- class Block(nn.Module):
340
- def __init__(
341
- self,
342
- dim,
343
- num_heads,
344
- mlp_ratio=4.0,
345
- qkv_bias=False,
346
- qk_scale=None,
347
- drop=0.0,
348
- attn_drop=0.0,
349
- drop_path=0.0,
350
- init_values=None,
351
- act_layer=nn.GELU,
352
- norm_layer=nn.LayerNorm,
353
- window_size=None,
354
- attn_head_dim=None,
355
- xattn=False,
356
- rope=None,
357
- postnorm=False,
358
- subln=False,
359
- naiveswiglu=False,
360
- ):
361
- super().__init__()
362
- self.norm1 = norm_layer(dim)
363
- self.attn = Attention(
364
- dim,
365
- num_heads=num_heads,
366
- qkv_bias=qkv_bias,
367
- qk_scale=qk_scale,
368
- attn_drop=attn_drop,
369
- proj_drop=drop,
370
- window_size=window_size,
371
- attn_head_dim=attn_head_dim,
372
- xattn=xattn,
373
- rope=rope,
374
- subln=subln,
375
- norm_layer=norm_layer,
376
- )
377
- # NOTE: drop path for stochastic depth, we shall see if this is better
378
- # than dropout here
379
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
380
- self.norm2 = norm_layer(dim)
381
- mlp_hidden_dim = int(dim * mlp_ratio)
382
-
383
- if naiveswiglu:
384
- self.mlp = SwiGLU(
385
- in_features=dim,
386
- hidden_features=mlp_hidden_dim,
387
- subln=subln,
388
- norm_layer=norm_layer,
389
- )
390
- else:
391
- self.mlp = Mlp(
392
- in_features=dim,
393
- hidden_features=mlp_hidden_dim,
394
- act_layer=act_layer,
395
- subln=subln,
396
- drop=drop,
397
- )
398
-
399
- if init_values is not None and init_values > 0:
400
- self.gamma_1 = nn.Parameter(
401
- init_values * torch.ones((dim,)), requires_grad=True
402
- )
403
- self.gamma_2 = nn.Parameter(
404
- init_values * torch.ones((dim,)), requires_grad=True
405
- )
406
- else:
407
- self.gamma_1, self.gamma_2 = None, None
408
-
409
- self.postnorm = postnorm
410
-
411
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
412
- if self.gamma_1 is None:
413
- if self.postnorm:
414
- x = x + self.drop_path(
415
- self.norm1(
416
- self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
417
- )
418
- )
419
- x = x + self.drop_path(self.norm2(self.mlp(x)))
420
- else:
421
- x = x + self.drop_path(
422
- self.attn(
423
- self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
424
- )
425
- )
426
- x = x + self.drop_path(self.mlp(self.norm2(x)))
427
- else:
428
- if self.postnorm:
429
- x = x + self.drop_path(
430
- self.gamma_1
431
- * self.norm1(
432
- self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
433
- )
434
- )
435
- x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
436
- else:
437
- x = x + self.drop_path(
438
- self.gamma_1
439
- * self.attn(
440
- self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
441
- )
442
- )
443
- x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
444
- return x
445
-
446
-
447
- class PatchEmbed(nn.Module):
448
- """Image to Patch Embedding"""
449
-
450
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
451
- super().__init__()
452
- img_size = to_2tuple(img_size)
453
- patch_size = to_2tuple(patch_size)
454
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
455
- self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
456
- self.img_size = img_size
457
- self.patch_size = patch_size
458
- self.num_patches = num_patches
459
-
460
- self.proj = nn.Conv2d(
461
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
462
- )
463
-
464
- def forward(self, x, **kwargs):
465
- target_dtype = self.proj.weight.dtype
466
- B, C, H, W = x.shape
467
- # FIXME look at relaxing size constraints
468
- assert H == self.img_size[0] and W == self.img_size[1], (
469
- f"Input image size ({H}*{W}) doesn't match model "
470
- f'({self.img_size[0]}*{self.img_size[1]}).'
471
- )
472
- x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
473
- return x
474
-
475
-
476
- class RelativePositionBias(nn.Module):
477
- def __init__(self, window_size, num_heads):
478
- super().__init__()
479
- self.window_size = window_size
480
- self.num_relative_distance = (2 * window_size[0] - 1) * (
481
- 2 * window_size[1] - 1
482
- ) + 3
483
- self.relative_position_bias_table = nn.Parameter(
484
- torch.zeros(self.num_relative_distance, num_heads)
485
- ) # 2*Wh-1 * 2*Ww-1, nH
486
- # cls to token & token 2 cls & cls to cls
487
-
488
- # get pair-wise relative position index for each token inside the window
489
- coords_h = torch.arange(window_size[0])
490
- coords_w = torch.arange(window_size[1])
491
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
492
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
493
- relative_coords = (
494
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
495
- ) # 2, Wh*Ww, Wh*Ww
496
- relative_coords = relative_coords.permute(
497
- 1, 2, 0
498
- ).contiguous() # Wh*Ww, Wh*Ww, 2
499
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
500
- relative_coords[:, :, 1] += window_size[1] - 1
501
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
502
- relative_position_index = torch.zeros(
503
- size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
504
- )
505
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
506
- relative_position_index[0, 0:] = self.num_relative_distance - 3
507
- relative_position_index[0:, 0] = self.num_relative_distance - 2
508
- relative_position_index[0, 0] = self.num_relative_distance - 1
509
-
510
- self.register_buffer('relative_position_index', relative_position_index)
511
-
512
- def forward(self):
513
- relative_position_bias = self.relative_position_bias_table[
514
- self.relative_position_index.view(-1)
515
- ].view(
516
- self.window_size[0] * self.window_size[1] + 1,
517
- self.window_size[0] * self.window_size[1] + 1,
518
- -1,
519
- ) # Wh*Ww,Wh*Ww,nH
520
- return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
521
-
522
-
523
- class EVAVisionTransformer(nn.Module):
524
- """Vision Transformer with support for patch or hybrid CNN input stage"""
525
-
526
- def __init__(
527
- self,
528
- img_size=224,
529
- patch_size=16,
530
- in_chans=3,
531
- num_classes=0,
532
- embed_dim=768,
533
- depth=12,
534
- num_heads=12,
535
- mlp_ratio=4.0,
536
- qkv_bias=False,
537
- qk_scale=None,
538
- drop_rate=0.0,
539
- attn_drop_rate=0.0,
540
- drop_path_rate=0.0,
541
- norm_layer=nn.LayerNorm,
542
- init_values=None,
543
- patch_dropout=0.0,
544
- use_abs_pos_emb=True,
545
- use_rel_pos_bias=False,
546
- use_shared_rel_pos_bias=False,
547
- rope=False,
548
- use_mean_pooling=True,
549
- init_scale=0.001,
550
- grad_checkpointing=False,
551
- xattn=False,
552
- postnorm=False,
553
- pt_hw_seq_len=16,
554
- intp_freq=False,
555
- naiveswiglu=False,
556
- subln=False,
557
- proj_type=None,
558
- ):
559
- super().__init__()
560
- self.image_size = img_size
561
- self.num_classes = num_classes
562
- self.num_features = (
563
- self.embed_dim
564
- ) = embed_dim # num_features for consistency with other models
565
-
566
- self.patch_embed = PatchEmbed(
567
- img_size=img_size,
568
- patch_size=patch_size,
569
- in_chans=in_chans,
570
- embed_dim=embed_dim,
571
- )
572
- num_patches = self.patch_embed.num_patches
573
-
574
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
575
- # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
576
- if use_abs_pos_emb:
577
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
578
- else:
579
- self.pos_embed = None
580
- self.pos_drop = nn.Dropout(p=drop_rate)
581
-
582
- if use_shared_rel_pos_bias:
583
- self.rel_pos_bias = RelativePositionBias(
584
- window_size=self.patch_embed.patch_shape, num_heads=num_heads
585
- )
586
- else:
587
- self.rel_pos_bias = None
588
-
589
- if rope:
590
- half_head_dim = embed_dim // num_heads // 2
591
- hw_seq_len = img_size // patch_size
592
- self.rope = VisionRotaryEmbeddingFast(
593
- dim=half_head_dim,
594
- pt_seq_len=pt_hw_seq_len,
595
- ft_seq_len=hw_seq_len if intp_freq else None,
596
- patch_dropout=patch_dropout,
597
- )
598
- else:
599
- self.rope = None
600
-
601
- self.naiveswiglu = naiveswiglu
602
-
603
- dpr = [
604
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
605
- ] # stochastic depth decay rule
606
- self.use_rel_pos_bias = use_rel_pos_bias
607
- self.blocks = nn.ModuleList(
608
- [
609
- Block(
610
- dim=embed_dim,
611
- num_heads=num_heads,
612
- mlp_ratio=mlp_ratio,
613
- qkv_bias=qkv_bias,
614
- qk_scale=qk_scale,
615
- drop=drop_rate,
616
- attn_drop=attn_drop_rate,
617
- drop_path=dpr[i],
618
- norm_layer=norm_layer,
619
- init_values=init_values,
620
- window_size=self.patch_embed.patch_shape
621
- if use_rel_pos_bias
622
- else None,
623
- xattn=xattn,
624
- rope=self.rope,
625
- postnorm=postnorm,
626
- subln=subln,
627
- naiveswiglu=naiveswiglu,
628
- )
629
- for i in range(depth)
630
- ]
631
- )
632
- self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
633
- self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
634
- if (num_classes == embed_dim) and (proj_type is None):
635
- self.head = nn.Identity()
636
- elif proj_type == 'linear':
637
- self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias)
638
- elif proj_type == 'mlp':
639
- hidden_size = (embed_dim + num_classes) // 2
640
- self.proj = nn.Sequential(
641
- nn.Linear(embed_dim, hidden_size, bias=qkv_bias),
642
- nn.GELU(),
643
- nn.Linear(hidden_size, num_classes, bias=qkv_bias),
644
- )
645
-
646
- if self.pos_embed is not None:
647
- trunc_normal_(self.pos_embed, std=0.02)
648
-
649
- trunc_normal_(self.cls_token, std=0.02)
650
-
651
- self.apply(self._init_weights)
652
- self.fix_init_weight()
653
-
654
- if isinstance(self.head, nn.Linear):
655
- trunc_normal_(self.head.weight, std=0.02)
656
- self.head.weight.data.mul_(init_scale)
657
- if qkv_bias:
658
- self.head.bias.data.mul_(init_scale)
659
-
660
- # setting a patch_dropout of 0. would mean it is disabled and this function
661
- # would be the identity fn
662
- self.patch_dropout = (
663
- PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
664
- )
665
-
666
- self.grad_checkpointing = grad_checkpointing
667
-
668
- def fix_init_weight(self):
669
- def rescale(param, layer_id):
670
- param.div_(math.sqrt(2.0 * layer_id))
671
-
672
- for layer_id, layer in enumerate(self.blocks):
673
- rescale(layer.attn.proj.weight.data, layer_id + 1)
674
- if self.naiveswiglu:
675
- rescale(layer.mlp.w3.weight.data, layer_id + 1)
676
- else:
677
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
678
-
679
- def get_cast_dtype(self) -> torch.dtype:
680
- return self.blocks[0].mlp.fc2.weight.dtype
681
-
682
- def _init_weights(self, m):
683
- if isinstance(m, nn.Linear):
684
- trunc_normal_(m.weight, std=0.02)
685
- if m.bias is not None:
686
- nn.init.constant_(m.bias, 0)
687
- elif isinstance(m, nn.LayerNorm):
688
- nn.init.constant_(m.bias, 0)
689
- nn.init.constant_(m.weight, 1.0)
690
-
691
- def get_num_layers(self):
692
- return len(self.blocks)
693
-
694
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
695
- assert (
696
- unlocked_groups == 0
697
- ), 'partial locking not currently supported for this model'
698
- for param in self.parameters():
699
- param.requires_grad = False
700
-
701
- @torch.jit.ignore
702
- def set_grad_checkpointing(self, enable=True):
703
- self.grad_checkpointing = enable
704
-
705
- @torch.jit.ignore
706
- def no_weight_decay(self):
707
- return {'pos_embed', 'cls_token'}
708
-
709
- def get_classifier(self):
710
- return self.head
711
-
712
- def reset_classifier(self, num_classes, global_pool=''):
713
- self.num_classes = num_classes
714
- self.head = (
715
- nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
716
- )
717
-
718
- def forward_features(self, x, return_all_features=False):
719
- x = self.patch_embed(x)
720
- batch_size, seq_len, _ = x.size()
721
-
722
- cls_tokens = self.cls_token.expand(
723
- batch_size, -1, -1
724
- ) # stole cls_tokens impl from Phil Wang, thanks
725
- x = torch.cat((cls_tokens, x), dim=1)
726
- if self.pos_embed is not None:
727
- x = x + self.pos_embed
728
- x = self.pos_drop(x)
729
-
730
- # a patch_dropout of 0. would mean it is disabled and this function would do
731
- # nothing but return what was passed in
732
- if self.rope is not None:
733
- if self.training and not isinstance(self.patch_dropout, nn.Identity):
734
- x, patch_indices_keep = self.patch_dropout(x)
735
- self.rope.forward = partial(
736
- self.rope.forward, patch_indices_keep=patch_indices_keep
737
- )
738
- else:
739
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
740
- x = self.patch_dropout(x)
741
- else:
742
- x = self.patch_dropout(x)
743
-
744
- rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
745
- for blk in self.blocks:
746
- if self.grad_checkpointing:
747
- x = checkpoint(blk, x, (rel_pos_bias,))
748
- else:
749
- x = blk(x, rel_pos_bias=rel_pos_bias)
750
-
751
- if not return_all_features:
752
- x = self.norm(x)
753
- if self.fc_norm is not None:
754
- return self.fc_norm(x.mean(1))
755
- else:
756
- return x[:, 0]
757
- return x
758
-
759
- def forward(self, x, return_all_features=False):
760
- if return_all_features:
761
- return self.forward_features(x, return_all_features)
762
- x = self.forward_features(x)
763
- x = self.head(x)
764
- return x
hf_model.py DELETED
@@ -1,297 +0,0 @@
1
- import re
2
- from typing import Dict, Optional, Tuple
3
-
4
- import torch
5
- import torch.nn as nn
6
- from transformers import AutoConfig, AutoModel, PretrainedConfig
7
- from transformers.modeling_outputs import (
8
- BaseModelOutput,
9
- BaseModelOutputWithPooling,
10
- BaseModelOutputWithPoolingAndCrossAttentions,
11
- )
12
-
13
- """
14
- HF architecture mapping
15
- """
16
-
17
- _HF_ARCH_DICT = {
18
- # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
19
- 'roberta': {
20
- 'config_names': {
21
- 'context_length': 'max_position_embeddings',
22
- 'vocab_size': 'vocab_size',
23
- 'width': 'hidden_size',
24
- 'heads': 'num_attention_heads',
25
- 'layers': 'num_hidden_layers',
26
- 'layer_attr': 'layer',
27
- 'token_embeddings_attr': 'embeddings',
28
- },
29
- 'pooler': 'mean_pooler',
30
- },
31
- # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
32
- 'xlm-roberta': {
33
- 'config_names': {
34
- 'context_length': 'max_position_embeddings',
35
- 'vocab_size': 'vocab_size',
36
- 'width': 'hidden_size',
37
- 'heads': 'num_attention_heads',
38
- 'layers': 'num_hidden_layers',
39
- 'layer_attr': 'layer',
40
- 'token_embeddings_attr': 'embeddings',
41
- },
42
- 'pooler': 'mean_pooler',
43
- },
44
- # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
45
- 'mt5': {
46
- 'config_names': {
47
- # unlimited seqlen
48
- # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
49
- # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
50
- 'context_length': '',
51
- 'vocab_size': 'vocab_size',
52
- 'width': 'd_model',
53
- 'heads': 'num_heads',
54
- 'layers': 'num_layers',
55
- 'layer_attr': 'block',
56
- 'token_embeddings_attr': 'embed_tokens',
57
- },
58
- 'pooler': 'mean_pooler',
59
- },
60
- # https://huggingface.co/docs/transformers/model_doc/bert
61
- 'bert': {
62
- 'config_names': {
63
- 'context_length': 'max_position_embeddings',
64
- 'vocab_size': 'vocab_size',
65
- 'width': 'hidden_size',
66
- 'heads': 'num_attention_heads',
67
- 'layers': 'num_hidden_layers',
68
- },
69
- 'pooler': 'cls_pooler',
70
- },
71
- # https://huggingface.co/docs/transformers/model_doc/m2m_100
72
- 'm2m_100': {
73
- 'config_names': {
74
- 'context_length': 'max_position_embeddings',
75
- 'vocab_size': 'vocab_size',
76
- 'width': 'd_model',
77
- 'heads': 'encoder_attention_heads',
78
- 'layers': 'encoder_layers',
79
- },
80
- 'pooler': 'cls_pooler',
81
- },
82
- }
83
-
84
-
85
- """
86
- Pooling functions
87
- """
88
-
89
- _POOLERS = {}
90
-
91
-
92
- def _camel2snake(s):
93
- return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
94
-
95
-
96
- def register_pooler(cls):
97
- """Decorator registering pooler class"""
98
- _POOLERS[_camel2snake(cls.__name__)] = cls
99
- return cls
100
-
101
-
102
- @register_pooler
103
- class MeanPooler(nn.Module):
104
- """Mean pooling"""
105
-
106
- @staticmethod
107
- def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
108
- masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
109
- return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
110
-
111
-
112
- @register_pooler
113
- class MaxPooler(nn.Module):
114
- """
115
- Max pooling
116
- """
117
-
118
- @staticmethod
119
- def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
120
- masked_output = x.last_hidden_state.masked_fill(
121
- attention_mask.unsqueeze(-1), -torch.inf
122
- )
123
- return masked_output.max(1).values
124
-
125
-
126
- @register_pooler
127
- class ClsPooler(nn.Module):
128
- """
129
- CLS token pooling
130
- """
131
-
132
- def __init__(self, use_pooler_output=True):
133
- super().__init__()
134
- self.cls_token_position = 0
135
- self.use_pooler_output = use_pooler_output
136
-
137
- def forward(self, x: BaseModelOutput, _: torch.Tensor):
138
- if (
139
- self.use_pooler_output
140
- and isinstance(
141
- x,
142
- (
143
- BaseModelOutputWithPooling,
144
- BaseModelOutputWithPoolingAndCrossAttentions,
145
- ),
146
- )
147
- and (x.pooler_output is not None)
148
- ):
149
- return x.pooler_output
150
-
151
- return x.last_hidden_state[:, self.cls_token_position, :]
152
-
153
-
154
- """
155
- HF text model
156
- """
157
-
158
-
159
- class HFTextEncoder(nn.Module):
160
- output_tokens: torch.jit.Final[bool]
161
-
162
- def __init__(
163
- self,
164
- model_name_or_path: str,
165
- output_dim: int,
166
- config: PretrainedConfig = None,
167
- pooler_type: str = None,
168
- proj_type: str = None,
169
- proj_bias: bool = False,
170
- pretrained: bool = True,
171
- output_tokens: bool = False,
172
- trust_remote_code: bool = False,
173
- revision: Optional[str] = None,
174
- model_config_kwargs: Optional[Dict] = None,
175
- ):
176
- super().__init__()
177
- self.output_tokens = output_tokens
178
- self.output_dim = output_dim
179
-
180
- # TODO: find better way to get this information
181
- uses_transformer_pooler = pooler_type == 'cls_pooler'
182
- model_config_kwargs = model_config_kwargs or {}
183
-
184
- if config is None:
185
- self.config = AutoConfig.from_pretrained(
186
- model_name_or_path,
187
- trust_remote_code=trust_remote_code,
188
- code_revision=revision,
189
- )
190
- self.config.update(model_config_kwargs)
191
- create_func, model_args = (
192
- (AutoModel.from_pretrained, model_name_or_path)
193
- if pretrained
194
- else (AutoModel.from_config, self.config)
195
- )
196
- # TODO: do all model configs have this attribute?
197
- # PretrainedConfig does so yes??
198
- if (
199
- hasattr(self.config, 'is_encoder_decoder')
200
- and self.config.is_encoder_decoder
201
- ):
202
- self.transformer = create_func(model_args)
203
- self.transformer = self.transformer.encoder
204
- else:
205
- self.transformer = create_func(
206
- model_args,
207
- trust_remote_code=trust_remote_code,
208
- add_pooling_layer=uses_transformer_pooler,
209
- code_revision=revision,
210
- )
211
- else:
212
- self.config = config
213
- self.config.update(model_config_kwargs)
214
- self.transformer = AutoModel.from_config(self.config)
215
-
216
- if pooler_type is None: # get default arch pooler
217
- pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
218
-
219
- # FIXME downstream users of OpenCLIP models use these attr,
220
- # need to verify valid across all models
221
- self.vocab_size = getattr(self.config, 'vocab_size', 0)
222
- self.context_length = getattr(self.config, 'max_position_embeddings', 0)
223
-
224
- self.pooler = _POOLERS[pooler_type]()
225
-
226
- d_model = getattr(
227
- self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
228
- )
229
- if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
230
- self.proj = nn.Identity()
231
- elif proj_type == 'linear':
232
- self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
233
- elif proj_type == 'mlp':
234
- hidden_size = (d_model + output_dim) // 2
235
- self.proj = nn.Sequential(
236
- nn.Linear(d_model, hidden_size, bias=proj_bias),
237
- nn.GELU(),
238
- nn.Linear(hidden_size, output_dim, bias=proj_bias),
239
- )
240
-
241
- def forward(self, x: torch.Tensor):
242
- attn_mask = (x != self.config.pad_token_id).long()
243
- out = self.transformer(input_ids=x, attention_mask=attn_mask)
244
- pooled_out = self.pooler(out, attn_mask)
245
- projected = self.proj(pooled_out)
246
-
247
- seq_len = out.last_hidden_state.shape[1]
248
- tokens = (
249
- out.last_hidden_state[
250
- :, torch.arange(seq_len) != self.pooler.cls_token_position, :
251
- ]
252
- if isinstance(self.pooler, ClsPooler)
253
- else out.last_hidden_state
254
- )
255
-
256
- if self.output_tokens:
257
- return projected, tokens
258
- return projected
259
-
260
- def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
261
- if not unlocked_layers: # full freezing
262
- for n, p in self.transformer.named_parameters():
263
- p.requires_grad = (
264
- (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
265
- )
266
- return
267
-
268
- encoder = (
269
- self.transformer.encoder
270
- if hasattr(self.transformer, 'encoder')
271
- else self.transformer
272
- )
273
- layer_list = getattr(
274
- encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
275
- )
276
- print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
277
- embeddings = getattr(
278
- self.transformer,
279
- _HF_ARCH_DICT[self.config.model_type]['config_names'][
280
- 'token_embeddings_attr'
281
- ],
282
- )
283
- modules = [embeddings, *layer_list][:-unlocked_layers]
284
- # freeze layers
285
- for module in modules:
286
- for n, p in module.named_parameters():
287
- p.requires_grad = (
288
- (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
289
- )
290
-
291
- @torch.jit.ignore
292
- def set_grad_checkpointing(self, _=True):
293
- self.transformer.gradient_checkpointing_enable()
294
-
295
- def init_parameters(self):
296
- pass
297
-
modeling_clip.py DELETED
@@ -1,570 +0,0 @@
1
- # coding=utf-8
2
- #
3
- # Code mainly copied from:
4
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
5
- # and adjusted for Jina CLIP
6
-
7
- from functools import partial
8
- from typing import List, Optional, Tuple, Union
9
- from io import BytesIO
10
- import requests
11
- import base64
12
- import numpy as np
13
- import torch
14
- import torch.nn.functional as f
15
- import torch.utils.checkpoint
16
- from torch import nn
17
- from transformers import (
18
- AutoImageProcessor,
19
- AutoTokenizer,
20
- BatchEncoding,
21
- BatchFeature,
22
- PreTrainedModel,
23
- logging,
24
- )
25
- from transformers.models.clip.modeling_clip import (
26
- CLIPOutput,
27
- CLIPTextModelOutput,
28
- CLIPVisionModelOutput,
29
- clip_loss,
30
- )
31
-
32
- try:
33
- from tqdm.autonotebook import trange
34
-
35
- has_tqdm = True
36
- except ImportError:
37
- has_tqdm = False
38
-
39
- from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
40
- from .eva_model import EVAVisionTransformer
41
- from .hf_model import HFTextEncoder
42
- # needed for HF to correctly import in cache
43
- from .rope_embeddings import VisionRotaryEmbeddingFast # noqa: F401
44
- from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform # noqa: F401
45
-
46
- logger = logging.get_logger(__name__)
47
-
48
-
49
- """ Jina CLIP model implementation """
50
-
51
-
52
- class LayerNorm(nn.LayerNorm):
53
- """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
-
55
- def forward(self, x: torch.Tensor):
56
- origtype = x.dtype
57
- x = f.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
- return x.to(origtype)
59
-
60
-
61
- def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
62
- return HFTextEncoder(
63
- model_name_or_path=config.hf_model_name_or_path,
64
- output_dim=config.embed_dim,
65
- pooler_type=config.pooler_type,
66
- proj_type=config.proj_type,
67
- proj_bias=config.proj_bias,
68
- pretrained=False,
69
- output_tokens=False,
70
- trust_remote_code=True,
71
- revision=None,
72
- model_config_kwargs=config.hf_model_config_kwargs,
73
- )
74
-
75
-
76
- def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
77
- norm_layer = partial(LayerNorm, eps=1e-6)
78
-
79
- if config.fused_layer_norm:
80
- try:
81
- from apex.normalization import FusedLayerNorm
82
-
83
- norm_layer = partial(FusedLayerNorm, eps=1e-6)
84
- except (ModuleNotFoundError, ImportError):
85
- logger.warning('Please install apex to use fused layer norm, ignoring')
86
-
87
- return EVAVisionTransformer(
88
- img_size=config.image_size,
89
- patch_size=config.patch_size,
90
- num_classes=config.embed_dim,
91
- use_mean_pooling=False,
92
- init_values=config.ls_init_value,
93
- patch_dropout=config.patch_dropout,
94
- embed_dim=config.width,
95
- depth=config.layers,
96
- num_heads=config.width // config.head_width,
97
- mlp_ratio=config.mlp_ratio,
98
- qkv_bias=config.qkv_bias,
99
- drop_path_rate=config.drop_path_rate,
100
- norm_layer=norm_layer,
101
- xattn=config.x_attention,
102
- rope=config.rope_embeddings,
103
- postnorm=config.post_norm,
104
- pt_hw_seq_len=config.pt_hw_seq_len,
105
- intp_freq=config.intp_freq,
106
- naiveswiglu=config.naive_swiglu,
107
- subln=config.subln,
108
- proj_type=config.proj_type,
109
- )
110
-
111
-
112
- class JinaCLIPPreTrainedModel(PreTrainedModel):
113
- """
114
- An abstract class to handle weights initialization and a simple interface for
115
- downloading and loading pretrained models.
116
- """
117
-
118
- config_class = JinaCLIPConfig
119
- base_model_prefix = 'clip'
120
- supports_gradient_checkpointing = True
121
-
122
- def _init_weights(self, module):
123
- """Initialize the weights"""
124
- if isinstance(module, JinaCLIPModel):
125
- if isinstance(module.text_projection, nn.Linear):
126
- nn.init.normal_(
127
- module.text_projection.weight,
128
- std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
129
- )
130
- if isinstance(module.visual_projection, nn.Linear):
131
- nn.init.normal_(
132
- module.visual_projection.weight,
133
- std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
134
- )
135
- if isinstance(module, nn.LayerNorm):
136
- module.bias.data.zero_()
137
- module.weight.data.fill_(1.0)
138
- if isinstance(module, nn.Linear) and module.bias is not None:
139
- module.bias.data.zero_()
140
-
141
-
142
- class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
143
- config_class = JinaCLIPTextConfig
144
-
145
- def __init__(self, config: JinaCLIPTextConfig):
146
- super().__init__(config)
147
- self.text_model = _build_text_tower(config)
148
- self.post_init()
149
-
150
- def forward(
151
- self,
152
- input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
153
- return_dict: Optional[bool] = None,
154
- *_,
155
- **__,
156
- ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
157
- return_dict = (
158
- return_dict if return_dict is not None else self.config.use_return_dict
159
- )
160
- x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
161
- feats = self.text_model(x=x)
162
- out = CLIPTextModelOutput(text_embeds=feats)
163
- return out if return_dict else out.to_tuple()
164
-
165
-
166
- class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
167
- config_class = JinaCLIPVisionConfig
168
- main_input_name = 'pixel_values'
169
-
170
- def __init__(self, config: JinaCLIPVisionConfig):
171
- super().__init__(config)
172
- self.vision_model = _build_vision_tower(config)
173
- self.post_init()
174
-
175
- def forward(
176
- self,
177
- pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
178
- return_dict: Optional[bool] = None,
179
- *_,
180
- **__,
181
- ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
182
- return_dict = (
183
- return_dict if return_dict is not None else self.config.use_return_dict
184
- )
185
- x = (
186
- pixel_values.pixel_values
187
- if isinstance(pixel_values, BatchFeature)
188
- else pixel_values
189
- )
190
- feats = self.vision_model(x=x)
191
- out = CLIPVisionModelOutput(image_embeds=feats)
192
- return out if return_dict else out.to_tuple()
193
-
194
-
195
- class JinaCLIPModel(JinaCLIPPreTrainedModel):
196
- config_class = JinaCLIPConfig
197
-
198
- def __init__(self, config: JinaCLIPConfig):
199
- super().__init__(config)
200
-
201
- if not isinstance(config.text_config, JinaCLIPTextConfig):
202
- raise ValueError(
203
- 'Attribute config.text_config is expected to be of type '
204
- f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
205
- )
206
-
207
- if not isinstance(config.vision_config, JinaCLIPVisionConfig):
208
- raise ValueError(
209
- 'Attribute config.vision_config is expected to be of type '
210
- f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
211
- )
212
-
213
- text_config = config.text_config
214
- vision_config = config.vision_config
215
-
216
- if config.use_text_flash_attn is not None:
217
- text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
218
- if config.use_vision_xformers is not None:
219
- vision_config.x_attention = config.use_vision_xformers
220
-
221
- self.add_projections = config.add_projections
222
- self.projection_dim = config.projection_dim
223
- self.text_embed_dim = text_config.embed_dim
224
- self.vision_embed_dim = vision_config.embed_dim
225
-
226
- self.text_model = _build_text_tower(text_config)
227
- self.vision_model = _build_vision_tower(vision_config)
228
- self.logit_scale = nn.Parameter(
229
- torch.tensor(self.config.logit_scale_init_value)
230
- )
231
-
232
- if self.add_projections:
233
- self.visual_projection = nn.Linear(
234
- self.vision_embed_dim, self.projection_dim, bias=False
235
- )
236
- self.text_projection = nn.Linear(
237
- self.text_embed_dim, self.projection_dim, bias=False
238
- )
239
- else:
240
- self.visual_projection = nn.Identity()
241
- self.text_projection = nn.Identity()
242
-
243
- self.tokenizer = None
244
- self.preprocess = None
245
- self.post_init()
246
-
247
- def get_tokenizer(self):
248
- if not self.tokenizer:
249
- self.tokenizer = AutoTokenizer.from_pretrained(
250
- self.config._name_or_path, trust_remote_code=True
251
- )
252
- return self.tokenizer
253
-
254
- def get_preprocess(self):
255
- if not self.preprocess:
256
- self.preprocess = AutoImageProcessor.from_pretrained(
257
- self.config._name_or_path, trust_remote_code=True
258
- )
259
- return self.preprocess
260
-
261
- def get_text_features(
262
- self,
263
- input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
264
- *_,
265
- **__,
266
- ) -> torch.FloatTensor:
267
- x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
268
- return self.text_projection(self.text_model(x=x))
269
-
270
- def get_image_features(
271
- self,
272
- pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
273
- *_,
274
- **__,
275
- ) -> torch.FloatTensor:
276
- x = (
277
- pixel_values.pixel_values
278
- if isinstance(pixel_values, BatchFeature)
279
- else pixel_values
280
- )
281
- return self.visual_projection(self.vision_model(x=x))
282
-
283
- @torch.inference_mode()
284
- def encode_text(
285
- self,
286
- sentences: Union[str, List[str]],
287
- batch_size: int = 32,
288
- show_progress_bar: Optional[bool] = None,
289
- convert_to_numpy: bool = True,
290
- convert_to_tensor: bool = False,
291
- device: Optional[torch.device] = None,
292
- normalize_embeddings: bool = True,
293
- **tokenizer_kwargs,
294
- ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
295
- """
296
- Computes sentence embeddings
297
- Args:
298
- sentences(`str` or `List[str]`):
299
- Sentence or sentences to be encoded
300
- batch_size(`int`, *optional*, defaults to 32):
301
- Batch size for the computation
302
- show_progress_bar(`bool`, *optional*, defaults to None):
303
- Show a progress bar when encoding sentences.
304
- If set to None, progress bar is only shown when
305
- `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
306
- convert_to_numpy(`bool`, *optional*, defaults to True):
307
- If true, the output is a list of numpy vectors.
308
- Else, it is a list of pytorch tensors.
309
- convert_to_tensor(`bool`, *optional*, defaults to False):
310
- If true, you get one large tensor as return.
311
- Overwrites any setting from convert_to_numpy
312
- device(`torch.device`, *optional*, defaults to None):
313
- Which torch.device to use for the computation
314
- normalize_embeddings(`bool`, *optional*, defaults to True):
315
- If set to true, returned vectors will have length 1. In that case,
316
- the faster dot-product (util.dot_score) instead of cosine similarity
317
- can be used.
318
- tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
319
- Keyword arguments for the tokenizer
320
- Returns:
321
- By default, a list of tensors is returned.
322
- If convert_to_tensor, a stacked tensor is returned.
323
- If convert_to_numpy, a numpy matrix is returned.
324
- """
325
- is_training = self.training
326
- self.eval()
327
- all_embeddings = []
328
-
329
- self.tokenizer = self.get_tokenizer()
330
-
331
- if show_progress_bar is None:
332
- show_progress_bar = (
333
- logger.getEffectiveLevel() == logging.INFO
334
- or logger.getEffectiveLevel() == logging.DEBUG
335
- )
336
-
337
- if convert_to_tensor:
338
- convert_to_numpy = False
339
-
340
- input_was_string = False
341
- if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
342
- sentences = [sentences]
343
- input_was_string = True
344
-
345
- if device is not None:
346
- self.to(device)
347
-
348
- permutation = np.argsort([-len(i) for i in sentences])
349
- inverse_permutation = np.argsort(permutation)
350
- sentences = [sentences[idx] for idx in permutation]
351
-
352
- tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
353
- tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
354
- tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
355
-
356
- if has_tqdm:
357
- range_iter = trange(
358
- 0,
359
- len(sentences),
360
- batch_size,
361
- desc='Encoding',
362
- disable=not show_progress_bar,
363
- )
364
- else:
365
- range_iter = range(0, len(sentences), batch_size)
366
-
367
- for i in range_iter:
368
- encoded_input = self.tokenizer(
369
- sentences[i : i + batch_size],
370
- return_tensors='pt',
371
- **tokenizer_kwargs,
372
- ).to(self.device)
373
-
374
- embeddings = self.get_text_features(input_ids=encoded_input)
375
- if normalize_embeddings:
376
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
377
- if convert_to_numpy:
378
- embeddings = embeddings.cpu()
379
- all_embeddings.extend(embeddings)
380
-
381
- all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
382
-
383
- if convert_to_tensor:
384
- all_embeddings = torch.stack(all_embeddings)
385
- elif convert_to_numpy:
386
- all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
387
-
388
- if input_was_string:
389
- all_embeddings = all_embeddings[0]
390
-
391
- self.train(is_training)
392
- return all_embeddings
393
-
394
- def decode_data_image(data_image_str):
395
- header, data = data_image_str.split(',', 1)
396
- image_data = base64.b64decode(data)
397
- return Image.open(BytesIO(image_data))
398
-
399
- @torch.inference_mode()
400
- def encode_image(
401
- self,
402
- images: Union[str, List[Union[str, "Image.Image"]]],
403
- batch_size: int = 32,
404
- show_progress_bar: Optional[bool] = None,
405
- convert_to_numpy: bool = True,
406
- convert_to_tensor: bool = False,
407
- device: Optional[torch.device] = None,
408
- normalize_embeddings: bool = True,
409
- ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
410
- """
411
- Computes image embeddings.
412
-
413
- Args:
414
- images(`str` or `List[Union[str, Image.Image]]`):
415
- image paths, URLs, PIL images, or data:image/ strings to be encoded
416
- batch_size(`int`, *optional*, defaults to 32):
417
- Batch size for the computation
418
- show_progress_bar(`bool`, *optional*, defaults to None):
419
- Show a progress bar when encoding images.
420
- If set to None, progress bar is only shown when
421
- `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
422
- convert_to_numpy(`bool`, *optional*, defaults to True):
423
- If true, the output is a list of numpy vectors.
424
- Else, it is a list of pytorch tensors.
425
- convert_to_tensor(`bool`, *optional*, defaults to False):
426
- If true, you get one large tensor as return.
427
- Overwrites any setting from convert_to_numpy
428
- device(`torch.device`, *optional*, defaults to None):
429
- Which torch.device to use for the computation
430
- normalize_embeddings(`bool`, *optional*, defaults to True):
431
- If set to true, returned vectors will have length 1. In that case,
432
- the faster dot-product (util.dot_score) instead of cosine similarity
433
- can be used.
434
- Returns:
435
- By default, a list of tensors is returned.
436
- If convert_to_tensor, a stacked tensor is returned.
437
- If convert_to_numpy, a numpy matrix is returned.
438
- """
439
-
440
- is_training = self.training
441
- self.eval()
442
-
443
- self.preprocess = self.get_preprocess()
444
- all_embeddings = []
445
-
446
- if show_progress_bar is None:
447
- show_progress_bar = (
448
- logger.getEffectiveLevel() == logging.INFO
449
- or logger.getEffectiveLevel() == logging.DEBUG
450
- )
451
-
452
- if convert_to_tensor:
453
- convert_to_numpy = False
454
-
455
- input_was_single_img = False
456
- if isinstance(images, str) or not hasattr(images, '__len__'):
457
- images = [images]
458
- input_was_single_img = True
459
-
460
- if device is not None:
461
- self.to(device)
462
-
463
- permutation = np.argsort([-len(str(i)) for i in images])
464
- inverse_permutation = np.argsort(permutation)
465
- images = [images[idx] for idx in permutation]
466
-
467
- if has_tqdm:
468
- range_iter = trange(
469
- 0,
470
- len(images),
471
- batch_size,
472
- desc='Encoding',
473
- disable=not show_progress_bar,
474
- )
475
- else:
476
- range_iter = range(0, len(images), batch_size)
477
-
478
- from PIL import Image
479
-
480
- for i in range_iter:
481
- batch_images = images[i:i+batch_size]
482
- processed_inputs = []
483
-
484
- for img in batch_images:
485
- if isinstance(img, str):
486
- if img.startswith('http'):
487
- response = requests.get(img)
488
- image = Image.open(BytesIO(response.content)).convert('RGB')
489
- elif img.startswith('data:image/'):
490
- image = decode_data_image(img).convert('RGB')
491
- else:
492
- image = Image.open(img).convert('RGB')
493
- elif isinstance(img, Image.Image):
494
- image = img.convert('RGB')
495
- else:
496
- raise ValueError("Unsupported image format")
497
-
498
- processed_inputs.append(image)
499
-
500
- processed_inputs = self.preprocess(processed_inputs)
501
- processed_inputs = processed_inputs.to(self.device)
502
- embeddings = self.get_image_features(processed_inputs)
503
-
504
- if normalize_embeddings:
505
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
506
- if convert_to_numpy:
507
- embeddings = embeddings.cpu()
508
- all_embeddings.extend(embeddings)
509
-
510
- all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
511
-
512
- if convert_to_tensor:
513
- all_embeddings = torch.stack(all_embeddings)
514
- elif convert_to_numpy:
515
- all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
516
-
517
- if input_was_single_img:
518
- all_embeddings = all_embeddings[0]
519
-
520
- self.train(is_training)
521
- return all_embeddings
522
-
523
- def forward(
524
- self,
525
- input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
526
- pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
527
- return_dict: Optional[bool] = None,
528
- return_loss: Optional[bool] = None,
529
- *_,
530
- **__,
531
- ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
532
- return_dict = (
533
- return_dict if return_dict is not None else self.config.use_return_dict
534
- )
535
- image_embeds = self.get_image_features(pixel_values=pixel_values)
536
- text_embeds = self.get_text_features(input_ids=input_ids)
537
-
538
- # normalized features
539
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
540
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
541
-
542
- # cosine similarity as logits
543
- logit_scale = self.logit_scale.exp()
544
- logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
545
- logits_per_image = logits_per_text.t()
546
-
547
- loss = None
548
- if return_loss:
549
- loss = clip_loss(logits_per_text)
550
-
551
- if not return_dict:
552
- output = (
553
- logits_per_image,
554
- logits_per_text,
555
- text_embeds,
556
- image_embeds,
557
- None,
558
- None,
559
- )
560
- return ((loss,) + output) if loss is not None else output
561
-
562
- return CLIPOutput(
563
- loss=loss,
564
- logits_per_image=logits_per_image,
565
- logits_per_text=logits_per_text,
566
- text_embeds=text_embeds,
567
- image_embeds=image_embeds,
568
- text_model_output=None,
569
- vision_model_output=None,
570
- )
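For reference, a short usage sketch of the encode_text / encode_image helpers removed above. It assumes a jina-clip checkpoint such as jinaai/jina-clip-v1 that loads this implementation as remote code, plus a placeholder local image file; both helpers L2-normalize by default, so cosine similarity reduces to a dot product.

```python
# Usage sketch, assuming a jina-clip checkpoint (e.g. jinaai/jina-clip-v1) and trust_remote_code.
import numpy as np
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

text_embeddings = model.encode_text(['A blue cat', 'A red cat'])   # numpy, shape (2, dim)
image_embeddings = model.encode_image(['blue_cat.jpg'])            # 'blue_cat.jpg' is a placeholder path

# Embeddings are unit-normalized by default, so the dot product is the cosine similarity.
print(np.asarray(text_embeddings) @ np.asarray(image_embeddings).T)
```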
modules.json CHANGED
@@ -3,6 +3,6 @@
3
  "idx": 0,
4
  "name": "0",
5
  "path": "",
6
- "type": "jina-clip-implementation-st.6d5aa7d8b428eaba8d7908d86f43f5dd5ad6ad93.custom_st.Transformer"
7
  }
8
  ]
 
3
  "idx": 0,
4
  "name": "0",
5
  "path": "",
6
+ "type": "tomaarsen/jina-clip-implementation-st--custom_st.Transformer"
7
  }
8
  ]
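With modules.json now pointing at the remote custom_st.Transformer, the model can be loaded directly through Sentence Transformers. The sketch below assumes a jina-clip checkpoint such as jinaai/jina-clip-v1 that ships this modules.json; how image inputs are detected lives in the remote custom_st.py, which is not shown here.

```python
# Usage sketch, assuming a jina-clip checkpoint (e.g. jinaai/jina-clip-v1) that ships this modules.json.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('jinaai/jina-clip-v1', trust_remote_code=True)

text_embeddings = model.encode(['A blue cat', 'A red cat'])
# The remote custom_st.Transformer is expected to route image inputs (paths, URLs,
# PIL images) through the vision tower; exact input handling is defined in custom_st.py.
image_embeddings = model.encode(['blue_cat.jpg'])  # placeholder path
```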
processing_clip.py DELETED
@@ -1,88 +0,0 @@
1
- # coding=utf-8
2
- #
3
- # Code mainly copied from:
4
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
5
- # and adjusted for Jina CLIP
6
-
7
- from typing import Tuple, Union
8
-
9
- import torch
10
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
11
- from transformers.image_utils import ImageInput, make_list_of_images
12
- from transformers.models.clip import CLIPProcessor
13
-
14
- from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform
15
-
16
- """ Jina CLIP processor implementation """
17
-
18
-
19
- class JinaCLIPProcessor(CLIPProcessor):
20
- image_processor_class = 'AutoImageProcessor'
21
- tokenizer_class = 'AutoTokenizer'
22
-
23
-
24
- """ Jina CLIP image processor implementation """
25
-
26
-
27
- class JinaCLIPImageProcessor(BaseImageProcessor):
28
- model_input_names = ['pixel_values']
29
- _valid_processor_keys = [
30
- 'size',
31
- 'mean',
32
- 'std',
33
- 'resize_mode',
34
- 'interpolation',
35
- 'fill_color',
36
- ]
37
-
38
- def __init__(
39
- self,
40
- size: Union[int, Tuple[int, int]] = 224,
41
- mean: Union[float, Tuple[float]] = OPENAI_DATASET_MEAN,
42
- std: Union[float, Tuple[float]] = OPENAI_DATASET_STD,
43
- resize_mode: str = 'shortest',
44
- interpolation: str = 'bicubic',
45
- fill_color: int = 0,
46
- **kwargs,
47
- ) -> None:
48
- super().__init__(**kwargs)
49
- self.size = size
50
- self.mean = mean
51
- self.std = std
52
- self.resize_mode = resize_mode
53
- self.interpolation = interpolation
54
- self.fill_color = fill_color
55
- self.transform = self._build_transform()
56
-
57
- def _build_transform(self):
58
- return image_transform(
59
- image_size=self.size,
60
- is_train=False,
61
- mean=self.mean,
62
- std=self.std,
63
- resize_mode=self.resize_mode,
64
- interpolation=self.interpolation,
65
- fill_color=self.fill_color,
66
- aug_cfg=None,
67
- )
68
-
69
- def to_dict(self):
70
- output = super().to_dict()
71
- output.pop('transform')
72
- return output
73
-
74
- def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
75
-
76
- _transform_needs_rebuild = False
77
- for k, v in kwargs.items():
78
- if k in self._valid_processor_keys:
79
- if v != getattr(self, k):
80
- setattr(self, k, v)
81
- _transform_needs_rebuild = True
82
-
83
- if _transform_needs_rebuild:
84
- self.transform = self._build_transform()
85
-
86
- images = make_list_of_images(images)
87
- out = torch.stack([self.transform(image) for image in images], dim=0)
88
- return BatchFeature(data={'pixel_values': out})
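A quick sketch of the processor above in isolation; it relies only on the class as defined in this file and a synthetic PIL image.

```python
# Usage sketch for JinaCLIPImageProcessor as defined above (synthetic input image).
from PIL import Image

processor = JinaCLIPImageProcessor(size=224, resize_mode='shortest')
image = Image.new('RGB', (640, 480), color=(30, 120, 200))  # stand-in image

batch = processor.preprocess(image)   # BatchFeature holding a 'pixel_values' tensor
print(batch['pixel_values'].shape)    # torch.Size([1, 3, 224, 224])
```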
rope_embeddings.py DELETED
@@ -1,165 +0,0 @@
1
- # --------------------------------------------------------
2
- # Adapted from EVA CLIP
3
- # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
- # --------------------------------------------------------
5
-
6
- import logging
7
- from math import pi
8
-
9
- import torch
10
- from einops import rearrange, repeat
11
- from torch import nn
12
-
13
-
14
- def broadcast(tensors, dim=-1):
15
- num_tensors = len(tensors)
16
- shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
17
- assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
18
- shape_len = list(shape_lens)[0]
19
- dim = (dim + shape_len) if dim < 0 else dim
20
- dims = list(zip(*map(lambda t: list(t.shape), tensors)))
21
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
22
- assert all(
23
- [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
24
- ), 'invalid dimensions for broadcastable concatenation'
25
- max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
26
- expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
27
- expanded_dims.insert(dim, (dim, dims[dim]))
28
- expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
29
- tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
30
- return torch.cat(tensors, dim=dim)
31
-
32
-
33
- def rotate_half(x):
34
- x = rearrange(x, '... (d r) -> ... d r', r=2)
35
- x1, x2 = x.unbind(dim=-1)
36
- x = torch.stack((-x2, x1), dim=-1)
37
- return rearrange(x, '... d r -> ... (d r)')
38
-
39
-
40
- class VisionRotaryEmbedding(nn.Module):
41
- def __init__(
42
- self,
43
- dim,
44
- pt_seq_len,
45
- ft_seq_len=None,
46
- custom_freqs=None,
47
- freqs_for='lang',
48
- theta=10000,
49
- max_freq=10,
50
- num_freqs=1,
51
- ):
52
- super().__init__()
53
- if custom_freqs:
54
- freqs = custom_freqs
55
- elif freqs_for == 'lang':
56
- freqs = 1.0 / (
57
- theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
58
- )
59
- elif freqs_for == 'pixel':
60
- freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
61
- elif freqs_for == 'constant':
62
- freqs = torch.ones(num_freqs).float()
63
- else:
64
- raise ValueError(f'unknown modality {freqs_for}')
65
-
66
- if ft_seq_len is None:
67
- ft_seq_len = pt_seq_len
68
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
69
-
70
- freqs_h = torch.einsum('..., f -> ... f', t, freqs)
71
- freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
72
-
73
- freqs_w = torch.einsum('..., f -> ... f', t, freqs)
74
- freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
75
-
76
- freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
77
-
78
- self.register_buffer('freqs_cos', freqs.cos())
79
- self.register_buffer('freqs_sin', freqs.sin())
80
-
81
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
82
-
83
- def forward(self, t, start_index=0):
84
- rot_dim = self.freqs_cos.shape[-1]
85
- end_index = start_index + rot_dim
86
- assert rot_dim <= t.shape[-1], (
87
- f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in '
88
- f'all the positions {rot_dim}'
89
- )
90
- t_left, t, t_right = (
91
- t[..., :start_index],
92
- t[..., start_index:end_index],
93
- t[..., end_index:],
94
- )
95
- t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
96
-
97
- return torch.cat((t_left, t, t_right), dim=-1)
98
-
99
-
100
- class VisionRotaryEmbeddingFast(nn.Module):
101
- def __init__(
102
- self,
103
- dim,
104
- pt_seq_len,
105
- ft_seq_len=None,
106
- custom_freqs=None,
107
- freqs_for='lang',
108
- theta=10000,
109
- max_freq=10,
110
- num_freqs=1,
111
- patch_dropout=0.0,
112
- ):
113
- super().__init__()
114
- if custom_freqs:
115
- freqs = custom_freqs
116
- elif freqs_for == 'lang':
117
- freqs = 1.0 / (
118
- theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
119
- )
120
- elif freqs_for == 'pixel':
121
- freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
122
- elif freqs_for == 'constant':
123
- freqs = torch.ones(num_freqs).float()
124
- else:
125
- raise ValueError(f'unknown modality {freqs_for}')
126
-
127
- if ft_seq_len is None:
128
- ft_seq_len = pt_seq_len
129
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
130
-
131
- freqs = torch.einsum('..., f -> ... f', t, freqs)
132
- freqs = repeat(freqs, '... n -> ... (n r)', r=2)
133
- freqs = broadcast((freqs[:, None, :], freqs[None, :, :]), dim=-1)
134
-
135
- freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
136
- freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
137
-
138
- self.patch_dropout = patch_dropout
139
-
140
- self.register_buffer('freqs_cos', freqs_cos)
141
- self.register_buffer('freqs_sin', freqs_sin)
142
-
143
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
144
-
145
- def forward(self, t, patch_indices_keep=None):
146
- if patch_indices_keep is not None:
147
- batch = t.size()[0]
148
- batch_indices = torch.arange(batch)
149
- batch_indices = batch_indices[..., None]
150
-
151
- freqs_cos = repeat(
152
- self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
153
- )
154
- freqs_sin = repeat(
155
- self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
156
- )
157
-
158
- freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
159
- freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
160
- freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
161
- freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
162
-
163
- return t * freqs_cos + rotate_half(t) * freqs_sin
164
-
165
- return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
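A small numerical check of the rotation primitive: rotate_half pairs up channels (x1, x2) and maps them to (-x2, x1), so x * cos + rotate_half(x) * sin rotates every channel pair by the position-dependent angle. A standalone sketch:

```python
# Standalone numerical sketch of the rotation used by VisionRotaryEmbeddingFast.
import torch
from einops import rearrange


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), '... d r -> ... (d r)')


theta = torch.tensor(0.5)                    # one rotation angle for all pairs
x = torch.tensor([[1.0, 0.0, 0.0, 1.0]])     # two channel pairs: (1, 0) and (0, 1)
rotated = x * theta.cos() + rotate_half(x) * theta.sin()
print(rotated)  # [[cos(0.5), sin(0.5), -sin(0.5), cos(0.5)]]
```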
transform.py DELETED
@@ -1,458 +0,0 @@
1
- import numbers
2
- import random
3
- import warnings
4
- from dataclasses import asdict, dataclass
5
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
-
7
- import torch
8
- import torchvision.transforms.functional as F
9
- from torchvision.transforms import (
10
- CenterCrop,
11
- ColorJitter,
12
- Compose,
13
- Grayscale,
14
- InterpolationMode,
15
- Normalize,
16
- RandomResizedCrop,
17
- Resize,
18
- ToTensor,
19
- )
20
- from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
21
-
22
- OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
23
- OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
24
-
25
-
26
- @dataclass
27
- class PreprocessCfg:
28
- size: Union[int, Tuple[int, int]] = 224
29
- mode: str = 'RGB'
30
- mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
31
- std: Tuple[float, ...] = OPENAI_DATASET_STD
32
- interpolation: str = 'bicubic'
33
- resize_mode: str = 'shortest'
34
- fill_color: int = 0
35
-
36
- def __post_init__(self):
37
- assert self.mode in ('RGB',)
38
-
39
- @property
40
- def num_channels(self):
41
- return 3
42
-
43
- @property
44
- def input_size(self):
45
- return (self.num_channels,) + (self.size, self.size)
46
-
47
-
48
- _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
49
-
50
-
51
- def merge_preprocess_dict(
52
- base: Union[PreprocessCfg, Dict],
53
- overlay: Dict,
54
- ):
55
- """Merge overlay key-value pairs on top of base preprocess cfg or dict.
56
- Input dicts are filtered based on PreprocessCfg fields.
57
- """
58
- if isinstance(base, PreprocessCfg):
59
- base_clean = asdict(base)
60
- else:
61
- base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
62
- if overlay:
63
- overlay_clean = {
64
- k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
65
- }
66
- base_clean.update(overlay_clean)
67
- return base_clean
68
-
69
-
70
- def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
71
- return merge_preprocess_dict(base, kwargs)
72
-
73
-
74
- @dataclass
75
- class AugmentationCfg:
76
- scale: Tuple[float, float] = (0.9, 1.0)
77
- ratio: Optional[Tuple[float, float]] = None
78
- color_jitter: Optional[
79
- Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
80
- ] = None
81
- re_prob: Optional[float] = None
82
- re_count: Optional[int] = None
83
- use_timm: bool = False
84
-
85
- # params for simclr_jitter_gray
86
- color_jitter_prob: float = None
87
- gray_scale_prob: float = None
88
-
89
-
90
- def _setup_size(size, error_msg):
91
- if isinstance(size, numbers.Number):
92
- return int(size), int(size)
93
-
94
- if isinstance(size, Sequence) and len(size) == 1:
95
- return size[0], size[0]
96
-
97
- if len(size) != 2:
98
- raise ValueError(error_msg)
99
-
100
- return size
101
-
102
-
103
- class ResizeKeepRatio:
104
- """Resize and Keep Ratio
105
-
106
- Copy & paste from `timm`
107
- """
108
-
109
- def __init__(
110
- self,
111
- size,
112
- longest=0.0,
113
- interpolation=InterpolationMode.BICUBIC,
114
- random_scale_prob=0.0,
115
- random_scale_range=(0.85, 1.05),
116
- random_aspect_prob=0.0,
117
- random_aspect_range=(0.9, 1.11),
118
- ):
119
- if isinstance(size, (list, tuple)):
120
- self.size = tuple(size)
121
- else:
122
- self.size = (size, size)
123
- self.interpolation = interpolation
124
- self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
125
- self.random_scale_prob = random_scale_prob
126
- self.random_scale_range = random_scale_range
127
- self.random_aspect_prob = random_aspect_prob
128
- self.random_aspect_range = random_aspect_range
129
-
130
- @staticmethod
131
- def get_params(
132
- img,
133
- target_size,
134
- longest,
135
- random_scale_prob=0.0,
136
- random_scale_range=(0.85, 1.05),
137
- random_aspect_prob=0.0,
138
- random_aspect_range=(0.9, 1.11),
139
- ):
140
- """Get parameters"""
141
- source_size = img.size[::-1] # h, w
142
- h, w = source_size
143
- target_h, target_w = target_size
144
- ratio_h = h / target_h
145
- ratio_w = w / target_w
146
- ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
147
- 1.0 - longest
148
- )
149
- if random_scale_prob > 0 and random.random() < random_scale_prob:
150
- ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
151
- ratio_factor = (ratio_factor, ratio_factor)
152
- else:
153
- ratio_factor = (1.0, 1.0)
154
- if random_aspect_prob > 0 and random.random() < random_aspect_prob:
155
- aspect_factor = random.uniform(
156
- random_aspect_range[0], random_aspect_range[1]
157
- )
158
- ratio_factor = (
159
- ratio_factor[0] / aspect_factor,
160
- ratio_factor[1] * aspect_factor,
161
- )
162
- size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
163
- return size
164
-
165
- def __call__(self, img):
166
- """
167
- Args:
168
- img (PIL Image): Image to be cropped and resized.
169
-
170
- Returns:
171
- PIL Image: Resized, padded to at least target size, possibly
172
- cropped to exactly target size
173
- """
174
- size = self.get_params(
175
- img,
176
- self.size,
177
- self.longest,
178
- self.random_scale_prob,
179
- self.random_scale_range,
180
- self.random_aspect_prob,
181
- self.random_aspect_range,
182
- )
183
- img = F.resize(img, size, self.interpolation)
184
- return img
185
-
186
- def __repr__(self):
187
- format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
188
- format_string += f', interpolation={self.interpolation}'
189
- format_string += f', longest={self.longest:.3f})'
190
- return format_string
191
-
192
-
193
- def center_crop_or_pad(
194
- img: torch.Tensor, output_size: List[int], fill=0
195
- ) -> torch.Tensor:
196
- """Center crops and/or pads the given image.
197
- If the image is torch Tensor, it is expected
198
- to have [..., H, W] shape, where ... means an arbitrary number of leading
199
- dimensions. If image size is smaller than output size along any edge, image is
200
- padded with 0 and then center cropped.
201
-
202
- Args:
203
- img (PIL Image or Tensor): Image to be cropped.
204
- output_size (sequence or int): (height, width) of the crop box. If int or
205
- sequence with single int, it is used for both directions.
206
- fill (int, Tuple[int]): Padding color
207
-
208
- Returns:
209
- PIL Image or Tensor: Cropped image.
210
- """
211
- if isinstance(output_size, numbers.Number):
212
- output_size = (int(output_size), int(output_size))
213
- elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
214
- output_size = (output_size[0], output_size[0])
215
-
216
- _, image_height, image_width = F.get_dimensions(img)
217
- crop_height, crop_width = output_size
218
-
219
- if crop_width > image_width or crop_height > image_height:
220
- padding_ltrb = [
221
- (crop_width - image_width) // 2 if crop_width > image_width else 0,
222
- (crop_height - image_height) // 2 if crop_height > image_height else 0,
223
- (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
224
- (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
225
- ]
226
- img = F.pad(img, padding_ltrb, fill=fill)
227
- _, image_height, image_width = F.get_dimensions(img)
228
- if crop_width == image_width and crop_height == image_height:
229
- return img
230
-
231
- crop_top = int(round((image_height - crop_height) / 2.0))
232
- crop_left = int(round((image_width - crop_width) / 2.0))
233
- return F.crop(img, crop_top, crop_left, crop_height, crop_width)
234
-
235
-
236
- class CenterCropOrPad(torch.nn.Module):
237
- """Crops the given image at the center.
238
- If the image is torch Tensor, it is expected
239
- to have [..., H, W] shape, where ... means an arbitrary number of leading
240
- dimensions. If image size is smaller than output size along any edge, image is
241
- padded with 0 and then center cropped.
242
-
243
- Args:
244
- size (sequence or int): Desired output size of the crop. If size is an
245
- int instead of sequence like (h, w), a square crop (size, size) is
246
- made. If provided a sequence of length 1, it will be interpreted as
247
- (size[0], size[0]).
248
- """
249
-
250
- def __init__(self, size, fill=0):
251
- super().__init__()
252
- self.size = _setup_size(
253
- size, error_msg='Please provide only two dimensions (h, w) for size.'
254
- )
255
- self.fill = fill
256
-
257
- def forward(self, img):
258
- """
259
- Args:
260
- img (PIL Image or Tensor): Image to be cropped.
261
-
262
- Returns:
263
- PIL Image or Tensor: Cropped image.
264
- """
265
- return center_crop_or_pad(img, self.size, fill=self.fill)
266
-
267
- def __repr__(self) -> str:
268
- return f'{self.__class__.__name__}(size={self.size})'
269
-
270
-
271
- def _convert_to_rgb(image):
272
- return image.convert('RGB')
273
-
274
-
275
- class _ColorJitter(object):
276
- """
277
- Apply Color Jitter to the PIL image with a specified probability.
278
- """
279
-
280
- def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
281
- assert 0.0 <= p <= 1.0
282
- self.p = p
283
- self.transf = ColorJitter(
284
- brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
285
- )
286
-
287
- def __call__(self, img):
288
- if random.random() < self.p:
289
- return self.transf(img)
290
- else:
291
- return img
292
-
293
-
294
- class _GrayScale(object):
295
- """
296
- Apply Gray Scale to the PIL image with a specified probability.
297
- """
298
-
299
- def __init__(self, p=0.2):
300
- assert 0.0 <= p <= 1.0
301
- self.p = p
302
- self.transf = Grayscale(num_output_channels=3)
303
-
304
- def __call__(self, img):
305
- if random.random() < self.p:
306
- return self.transf(img)
307
- else:
308
- return img
309
-
310
-
311
- def image_transform(
312
- image_size: Union[int, Tuple[int, int]],
313
- is_train: bool,
314
- mean: Optional[Tuple[float, ...]] = None,
315
- std: Optional[Tuple[float, ...]] = None,
316
- resize_mode: Optional[str] = None,
317
- interpolation: Optional[str] = None,
318
- fill_color: int = 0,
319
- aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
320
- ):
321
- mean = mean or OPENAI_DATASET_MEAN
322
- if not isinstance(mean, (list, tuple)):
323
- mean = (mean,) * 3
324
-
325
- std = std or OPENAI_DATASET_STD
326
- if not isinstance(std, (list, tuple)):
327
- std = (std,) * 3
328
-
329
- interpolation = interpolation or 'bicubic'
330
- assert interpolation in ['bicubic', 'bilinear', 'random']
331
- # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for
332
- # inference if set
333
- interpolation_mode = (
334
- InterpolationMode.BILINEAR
335
- if interpolation == 'bilinear'
336
- else InterpolationMode.BICUBIC
337
- )
338
-
339
- resize_mode = resize_mode or 'shortest'
340
- assert resize_mode in ('shortest', 'longest', 'squash')
341
-
342
- if isinstance(aug_cfg, dict):
343
- aug_cfg = AugmentationCfg(**aug_cfg)
344
- else:
345
- aug_cfg = aug_cfg or AugmentationCfg()
346
-
347
- normalize = Normalize(mean=mean, std=std)
348
-
349
- if is_train:
350
- aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
351
- use_timm = aug_cfg_dict.pop('use_timm', False)
352
- if use_timm:
353
- from timm.data import create_transform # timm can still be optional
354
-
355
- if isinstance(image_size, (tuple, list)):
356
- assert len(image_size) >= 2
357
- input_size = (3,) + image_size[-2:]
358
- else:
359
- input_size = (3, image_size, image_size)
360
-
361
- aug_cfg_dict.setdefault('color_jitter', None) # disable by default
362
- # drop extra non-timm items
363
- aug_cfg_dict.pop('color_jitter_prob', None)
364
- aug_cfg_dict.pop('gray_scale_prob', None)
365
-
366
- train_transform = create_transform(
367
- input_size=input_size,
368
- is_training=True,
369
- hflip=0.0,
370
- mean=mean,
371
- std=std,
372
- re_mode='pixel',
373
- interpolation=interpolation,
374
- **aug_cfg_dict,
375
- )
376
- else:
377
- train_transform = [
378
- RandomResizedCrop(
379
- image_size,
380
- scale=aug_cfg_dict.pop('scale'),
381
- interpolation=InterpolationMode.BICUBIC,
382
- ),
383
- _convert_to_rgb,
384
- ]
385
- if aug_cfg.color_jitter_prob:
386
- assert (
387
- aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
388
- )
389
- train_transform.extend(
390
- [_ColorJitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)]
391
- )
392
- if aug_cfg.gray_scale_prob:
393
- train_transform.extend([_GrayScale(aug_cfg.gray_scale_prob)])
394
- train_transform.extend(
395
- [
396
- ToTensor(),
397
- normalize,
398
- ]
399
- )
400
- train_transform = Compose(train_transform)
401
- if aug_cfg_dict:
402
- warnings.warn(
403
- f'Unused augmentation cfg items, specify `use_timm` to use '
404
- f'({list(aug_cfg_dict.keys())}).'
405
- )
406
- return train_transform
407
- else:
408
- if resize_mode == 'longest':
409
- transforms = [
410
- ResizeKeepRatio(
411
- image_size, interpolation=interpolation_mode, longest=1
412
- ),
413
- CenterCropOrPad(image_size, fill=fill_color),
414
- ]
415
- elif resize_mode == 'squash':
416
- if isinstance(image_size, int):
417
- image_size = (image_size, image_size)
418
- transforms = [
419
- Resize(image_size, interpolation=interpolation_mode),
420
- ]
421
- else:
422
- assert resize_mode == 'shortest'
423
- if not isinstance(image_size, (tuple, list)):
424
- image_size = (image_size, image_size)
425
- if image_size[0] == image_size[1]:
426
- # simple case, use torchvision built-in Resize w/ shortest edge mode
427
- # (scalar size arg)
428
- transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
429
- else:
430
- # resize shortest edge to matching target dim for non-square target
431
- transforms = [ResizeKeepRatio(image_size)]
432
- transforms += [CenterCrop(image_size)]
433
-
434
- transforms.extend(
435
- [
436
- _convert_to_rgb,
437
- ToTensor(),
438
- normalize,
439
- ]
440
- )
441
- return Compose(transforms)
442
-
443
-
444
- def image_transform_v2(
445
- cfg: PreprocessCfg,
446
- is_train: bool,
447
- aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
448
- ):
449
- return image_transform(
450
- image_size=cfg.size,
451
- is_train=is_train,
452
- mean=cfg.mean,
453
- std=cfg.std,
454
- interpolation=cfg.interpolation,
455
- resize_mode=cfg.resize_mode,
456
- fill_color=cfg.fill_color,
457
- aug_cfg=aug_cfg,
458
- )
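For reference, building the inference-time pipeline from this file and running a synthetic image through it; with the defaults this is Resize (shortest edge, bicubic) -> CenterCrop -> convert to RGB -> ToTensor -> Normalize with the OpenAI CLIP mean/std.

```python
# Usage sketch, assuming transform.py above is importable (synthetic input image).
from PIL import Image

preprocess = image_transform(
    image_size=224,
    is_train=False,
    resize_mode='shortest',
    interpolation='bicubic',
)
image = Image.new('RGB', (320, 512))  # stand-in image
tensor = preprocess(image)
print(tensor.shape)  # torch.Size([3, 224, 224])
```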