VictorSanh committed
Commit e06a98d
1 Parent(s): 9956005

align implementation with transformers + include NaViT-style changes (these changes are backward compatible)

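Before the file-by-file diff, a minimal sketch of what "backward compatible" means here in practice. It assumes the updated `configuration_siglip.py` from this commit is importable locally; class names and defaults are taken from the diff below, everything else is illustrative.

```python
# Sketch only: exercise the updated configuration classes from this commit.
# Assumes the new configuration_siglip.py sits next to this script.
from configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Defaults after this commit: vocab_size=32000 and hidden_size=768 for the text
# config, patch_size=16 for the vision config, _flash_attn_2_enabled=True on both.
text_config = SiglipTextConfig()
vision_config = SiglipVisionConfig()
assert text_config.vocab_size == 32000
assert vision_config.patch_size == 16

# Arguments dropped from the signatures (e.g. projection_dim, initializer_range)
# are still accepted: they fall through **kwargs to PretrainedConfig instead of
# raising, which is what keeps existing config.json files loadable.
old_style = SiglipVisionConfig(hidden_size=144, image_size=30, projection_dim=512)
print(old_style.image_size)  # 30

config = SiglipConfig(
    text_config=text_config.to_dict(),
    vision_config=vision_config.to_dict(),
)
print(config.vision_config.hidden_size)  # 768
```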
config.json CHANGED
@@ -21,7 +21,7 @@
     "vocab_size": 32000
   },
   "torch_dtype": "float32",
-  "transformers_version": "4.35.2",
+  "transformers_version": "4.37.0.dev0",
   "vision_config": {
     "hidden_size": 144,
     "image_size": 30,
configuration_siglip.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -15,16 +15,9 @@
15
  """ Siglip model configuration"""
16
 
17
  import os
18
- from collections import OrderedDict
19
- from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
20
-
21
-
22
- if TYPE_CHECKING:
23
- from transformers.processing_utils import ProcessorMixin
24
- from transformers.utils import TensorType
25
 
26
  from transformers.configuration_utils import PretrainedConfig
27
- from transformers.onnx import OnnxConfig
28
  from transformers.utils import logging
29
 
30
 
@@ -46,16 +39,16 @@ class SiglipTextConfig(PretrainedConfig):
46
  documentation from [`PretrainedConfig`] for more information.
47
 
48
  Args:
49
- vocab_size (`int`, *optional*, defaults to 49408):
50
  Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
51
  the `inputs_ids` passed when calling [`SiglipModel`].
52
- hidden_size (`int`, *optional*, defaults to 512):
53
  Dimensionality of the encoder layers and the pooler layer.
54
- intermediate_size (`int`, *optional*, defaults to 2048):
55
  Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
56
  num_hidden_layers (`int`, *optional*, defaults to 12):
57
  Number of hidden layers in the Transformer encoder.
58
- num_attention_heads (`int`, *optional*, defaults to 8):
59
  Number of attention heads for each attention layer in the Transformer encoder.
60
  max_position_embeddings (`int`, *optional*, defaults to 64):
61
  The maximum sequence length that this model might ever be used with. Typically set this to something large
@@ -63,15 +56,16 @@ class SiglipTextConfig(PretrainedConfig):
63
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
64
  The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
65
  `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
66
- layer_norm_eps (`float`, *optional*, defaults to 1e-6):
67
  The epsilon used by the layer normalization layers.
68
  attention_dropout (`float`, *optional*, defaults to 0.0):
69
  The dropout ratio for the attention probabilities.
70
- initializer_range (`float`, *optional*, defaults to 0.02):
71
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
- initializer_factor (`float`, *optional*, defaults to 1):
73
- A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
74
- testing).
 
75
 
76
  Example:
77
 
@@ -87,27 +81,26 @@ class SiglipTextConfig(PretrainedConfig):
87
  >>> # Accessing the model configuration
88
  >>> configuration = model.config
89
  ```"""
 
90
  model_type = "siglip_text_model"
91
 
92
  def __init__(
93
  self,
94
- vocab_size=49408,
95
- hidden_size=512,
96
- intermediate_size=2048,
97
- projection_dim=512,
98
  num_hidden_layers=12,
99
- num_attention_heads=8,
100
  max_position_embeddings=64,
101
  hidden_act="gelu_pytorch_tanh",
102
  layer_norm_eps=1e-6,
103
  attention_dropout=0.0,
104
- initializer_range=0.02,
105
- initializer_factor=1.0,
106
  # This differs from `CLIPTokenizer`'s default and from openai/siglip
107
  # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
108
  pad_token_id=1,
109
  bos_token_id=49406,
110
  eos_token_id=49407,
 
111
  **kwargs,
112
  ):
113
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -115,15 +108,13 @@ class SiglipTextConfig(PretrainedConfig):
115
  self.vocab_size = vocab_size
116
  self.hidden_size = hidden_size
117
  self.intermediate_size = intermediate_size
118
- self.projection_dim = projection_dim
119
  self.num_hidden_layers = num_hidden_layers
120
  self.num_attention_heads = num_attention_heads
121
  self.max_position_embeddings = max_position_embeddings
122
  self.layer_norm_eps = layer_norm_eps
123
  self.hidden_act = hidden_act
124
- self.initializer_range = initializer_range
125
- self.initializer_factor = initializer_factor
126
  self.attention_dropout = attention_dropout
 
127
 
128
  @classmethod
129
  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
@@ -163,22 +154,19 @@ class SiglipVisionConfig(PretrainedConfig):
163
  Number of hidden layers in the Transformer encoder.
164
  num_attention_heads (`int`, *optional*, defaults to 12):
165
  Number of attention heads for each attention layer in the Transformer encoder.
 
 
166
  image_size (`int`, *optional*, defaults to 224):
167
  The size (resolution) of each image.
168
- patch_size (`int`, *optional*, defaults to 32):
169
  The size (resolution) of each patch.
170
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
171
  The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
172
  `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
173
- layer_norm_eps (`float`, *optional*, defaults to 1e-6):
174
  The epsilon used by the layer normalization layers.
175
  attention_dropout (`float`, *optional*, defaults to 0.0):
176
  The dropout ratio for the attention probabilities.
177
- initializer_range (`float`, *optional*, defaults to 0.02):
178
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
179
- initializer_factor (`float`, *optional*, defaults to 1):
180
- A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
181
- testing).
182
 
183
  Example:
184
 
@@ -201,34 +189,30 @@ class SiglipVisionConfig(PretrainedConfig):
201
  self,
202
  hidden_size=768,
203
  intermediate_size=3072,
204
- projection_dim=512,
205
  num_hidden_layers=12,
206
  num_attention_heads=12,
207
  num_channels=3,
208
  image_size=224,
209
- patch_size=32,
210
  hidden_act="gelu_pytorch_tanh",
211
  layer_norm_eps=1e-6,
212
  attention_dropout=0.0,
213
- initializer_range=0.02,
214
- initializer_factor=1.0,
215
  **kwargs,
216
  ):
217
  super().__init__(**kwargs)
218
 
219
  self.hidden_size = hidden_size
220
  self.intermediate_size = intermediate_size
221
- self.projection_dim = projection_dim
222
  self.num_hidden_layers = num_hidden_layers
223
  self.num_attention_heads = num_attention_heads
224
  self.num_channels = num_channels
225
  self.patch_size = patch_size
226
  self.image_size = image_size
227
- self.initializer_range = initializer_range
228
- self.initializer_factor = initializer_factor
229
  self.attention_dropout = attention_dropout
230
  self.layer_norm_eps = layer_norm_eps
231
  self.hidden_act = hidden_act
 
232
 
233
  @classmethod
234
  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
@@ -264,10 +248,6 @@ class SiglipConfig(PretrainedConfig):
264
  Dictionary of configuration options used to initialize [`SiglipTextConfig`].
265
  vision_config (`dict`, *optional*):
266
  Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
267
- projection_dim (`int`, *optional*, defaults to 512):
268
- Dimentionality of text and vision projection layers.
269
- logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
270
- The inital value of the *logit_scale* paramter. Default is used as per the original Siglip implementation.
271
  kwargs (*optional*):
272
  Dictionary of keyword arguments.
273
 
@@ -297,79 +277,9 @@ class SiglipConfig(PretrainedConfig):
297
 
298
  model_type = "siglip"
299
 
300
- def __init__(
301
- self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
302
- ):
303
- # If `_config_dict` exist, we use them for the backward compatibility.
304
- # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
305
- # of confusion!).
306
- text_config_dict = kwargs.pop("text_config_dict", None)
307
- vision_config_dict = kwargs.pop("vision_config_dict", None)
308
-
309
  super().__init__(**kwargs)
310
 
311
- # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
312
- # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
313
- # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
314
- if text_config_dict is not None:
315
- if text_config is None:
316
- text_config = {}
317
-
318
- # This is the complete result when using `text_config_dict`.
319
- _text_config_dict = SiglipTextConfig(**text_config_dict).to_dict()
320
-
321
- # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
322
- for key, value in _text_config_dict.items():
323
- if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
324
- # If specified in `text_config_dict`
325
- if key in text_config_dict:
326
- message = (
327
- f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
328
- f'The value `text_config_dict["{key}"]` will be used instead.'
329
- )
330
- # If inferred from default argument values (just to be super careful)
331
- else:
332
- message = (
333
- f"`text_config_dict` is provided which will be used to initialize `SiglipTextConfig`. The "
334
- f'value `text_config["{key}"]` will be overriden.'
335
- )
336
- logger.warning(message)
337
-
338
- # Update all values in `text_config` with the ones in `_text_config_dict`.
339
- text_config.update(_text_config_dict)
340
-
341
- if vision_config_dict is not None:
342
- if vision_config is None:
343
- vision_config = {}
344
-
345
- # This is the complete result when using `vision_config_dict`.
346
- _vision_config_dict = SiglipVisionConfig(**vision_config_dict).to_dict()
347
- # convert keys to string instead of integer
348
- if "id2label" in _vision_config_dict:
349
- _vision_config_dict["id2label"] = {
350
- str(key): value for key, value in _vision_config_dict["id2label"].items()
351
- }
352
-
353
- # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
354
- for key, value in _vision_config_dict.items():
355
- if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
356
- # If specified in `vision_config_dict`
357
- if key in vision_config_dict:
358
- message = (
359
- f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
360
- f'values. The value `vision_config_dict["{key}"]` will be used instead.'
361
- )
362
- # If inferred from default argument values (just to be super careful)
363
- else:
364
- message = (
365
- f"`vision_config_dict` is provided which will be used to initialize `SiglipVisionConfig`. "
366
- f'The value `vision_config["{key}"]` will be overriden.'
367
- )
368
- logger.warning(message)
369
-
370
- # Update all values in `vision_config` with the ones in `_vision_config_dict`.
371
- vision_config.update(_vision_config_dict)
372
-
373
  if text_config is None:
374
  text_config = {}
375
  logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
@@ -381,8 +291,6 @@ class SiglipConfig(PretrainedConfig):
381
  self.text_config = SiglipTextConfig(**text_config)
382
  self.vision_config = SiglipVisionConfig(**vision_config)
383
 
384
- self.projection_dim = projection_dim
385
- self.logit_scale_init_value = logit_scale_init_value
386
  self.initializer_factor = 1.0
387
 
388
  @classmethod
@@ -396,49 +304,3 @@ class SiglipConfig(PretrainedConfig):
396
  """
397
 
398
  return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
399
-
400
-
401
- class SiglipOnnxConfig(OnnxConfig):
402
- @property
403
- def inputs(self) -> Mapping[str, Mapping[int, str]]:
404
- return OrderedDict(
405
- [
406
- ("input_ids", {0: "batch", 1: "sequence"}),
407
- ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
408
- ("attention_mask", {0: "batch", 1: "sequence"}),
409
- ]
410
- )
411
-
412
- @property
413
- def outputs(self) -> Mapping[str, Mapping[int, str]]:
414
- return OrderedDict(
415
- [
416
- ("logits_per_image", {0: "batch"}),
417
- ("logits_per_text", {0: "batch"}),
418
- ("text_embeds", {0: "batch"}),
419
- ("image_embeds", {0: "batch"}),
420
- ]
421
- )
422
-
423
- @property
424
- def atol_for_validation(self) -> float:
425
- return 1e-4
426
-
427
- def generate_dummy_inputs(
428
- self,
429
- processor: "ProcessorMixin",
430
- batch_size: int = -1,
431
- seq_length: int = -1,
432
- framework: Optional["TensorType"] = None,
433
- ) -> Mapping[str, Any]:
434
- text_input_dict = super().generate_dummy_inputs(
435
- processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
436
- )
437
- image_input_dict = super().generate_dummy_inputs(
438
- processor.image_processor, batch_size=batch_size, framework=framework
439
- )
440
- return {**text_input_dict, **image_input_dict}
441
-
442
- @property
443
- def default_onnx_opset(self) -> int:
444
- return 14
 
1
  # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
15
  """ Siglip model configuration"""
16
 
17
  import os
18
+ from typing import Union
19
 
20
  from transformers.configuration_utils import PretrainedConfig
 
21
  from transformers.utils import logging
22
 
23
 
 
39
  documentation from [`PretrainedConfig`] for more information.
40
 
41
  Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
  Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
44
  the `inputs_ids` passed when calling [`SiglipModel`].
45
+ hidden_size (`int`, *optional*, defaults to 768):
46
  Dimensionality of the encoder layers and the pooler layer.
47
+ intermediate_size (`int`, *optional*, defaults to 3072):
48
  Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
49
  num_hidden_layers (`int`, *optional*, defaults to 12):
50
  Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 12):
52
  Number of attention heads for each attention layer in the Transformer encoder.
53
  max_position_embeddings (`int`, *optional*, defaults to 64):
54
  The maximum sequence length that this model might ever be used with. Typically set this to something large
 
56
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
57
  The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
58
  `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
60
  The epsilon used by the layer normalization layers.
61
  attention_dropout (`float`, *optional*, defaults to 0.0):
62
  The dropout ratio for the attention probabilities.
63
+ pad_token_id (`int`, *optional*, defaults to 1):
64
+ The id of the padding token in the vocabulary.
65
+ bos_token_id (`int`, *optional*, defaults to 49406):
66
+ The id of the beginning-of-sequence token in the vocabulary.
67
+ eos_token_id (`int`, *optional*, defaults to 49407):
68
+ The id of the end-of-sequence token in the vocabulary.
69
 
70
  Example:
71
 
 
81
  >>> # Accessing the model configuration
82
  >>> configuration = model.config
83
  ```"""
84
+
85
  model_type = "siglip_text_model"
86
 
87
  def __init__(
88
  self,
89
+ vocab_size=32000,
90
+ hidden_size=768,
91
+ intermediate_size=3072,
 
92
  num_hidden_layers=12,
93
+ num_attention_heads=12,
94
  max_position_embeddings=64,
95
  hidden_act="gelu_pytorch_tanh",
96
  layer_norm_eps=1e-6,
97
  attention_dropout=0.0,
 
 
98
  # This differs from `CLIPTokenizer`'s default and from openai/siglip
99
  # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
100
  pad_token_id=1,
101
  bos_token_id=49406,
102
  eos_token_id=49407,
103
+ _flash_attn_2_enabled=True,
104
  **kwargs,
105
  ):
106
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
108
  self.vocab_size = vocab_size
109
  self.hidden_size = hidden_size
110
  self.intermediate_size = intermediate_size
 
111
  self.num_hidden_layers = num_hidden_layers
112
  self.num_attention_heads = num_attention_heads
113
  self.max_position_embeddings = max_position_embeddings
114
  self.layer_norm_eps = layer_norm_eps
115
  self.hidden_act = hidden_act
 
 
116
  self.attention_dropout = attention_dropout
117
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
118
 
119
  @classmethod
120
  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
 
154
  Number of hidden layers in the Transformer encoder.
155
  num_attention_heads (`int`, *optional*, defaults to 12):
156
  Number of attention heads for each attention layer in the Transformer encoder.
157
+ num_channels (`int`, *optional*, defaults to 3):
158
+ Number of channels in the input images.
159
  image_size (`int`, *optional*, defaults to 224):
160
  The size (resolution) of each image.
161
+ patch_size (`int`, *optional*, defaults to 16):
162
  The size (resolution) of each patch.
163
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
164
  The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
165
  `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
166
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
167
  The epsilon used by the layer normalization layers.
168
  attention_dropout (`float`, *optional*, defaults to 0.0):
169
  The dropout ratio for the attention probabilities.
170
 
171
  Example:
172
 
 
189
  self,
190
  hidden_size=768,
191
  intermediate_size=3072,
 
192
  num_hidden_layers=12,
193
  num_attention_heads=12,
194
  num_channels=3,
195
  image_size=224,
196
+ patch_size=16,
197
  hidden_act="gelu_pytorch_tanh",
198
  layer_norm_eps=1e-6,
199
  attention_dropout=0.0,
200
+ _flash_attn_2_enabled=True,
 
201
  **kwargs,
202
  ):
203
  super().__init__(**kwargs)
204
 
205
  self.hidden_size = hidden_size
206
  self.intermediate_size = intermediate_size
 
207
  self.num_hidden_layers = num_hidden_layers
208
  self.num_attention_heads = num_attention_heads
209
  self.num_channels = num_channels
210
  self.patch_size = patch_size
211
  self.image_size = image_size
 
 
212
  self.attention_dropout = attention_dropout
213
  self.layer_norm_eps = layer_norm_eps
214
  self.hidden_act = hidden_act
215
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
216
 
217
  @classmethod
218
  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
 
248
  Dictionary of configuration options used to initialize [`SiglipTextConfig`].
249
  vision_config (`dict`, *optional*):
250
  Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
251
  kwargs (*optional*):
252
  Dictionary of keyword arguments.
253
 
 
277
 
278
  model_type = "siglip"
279
 
280
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
281
  super().__init__(**kwargs)
282
 
283
  if text_config is None:
284
  text_config = {}
285
  logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
 
291
  self.text_config = SiglipTextConfig(**text_config)
292
  self.vision_config = SiglipVisionConfig(**vision_config)
293
 
 
 
294
  self.initializer_factor = 1.0
295
 
296
  @classmethod
 
304
  """
305
 
306
  return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
 
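A short sketch of the new `_flash_attn_2_enabled` flag introduced by the configuration diff above. The flag and `is_flash_attn_2_available` appear in this commit; exactly how the modeling code consumes the flag is only partially visible here, so the comments describe an assumption rather than a guarantee.

```python
# Hypothetical usage of the `_flash_attn_2_enabled` flag added to the config
# classes above (assumes configuration_siglip.py is importable locally).
from transformers.utils import is_flash_attn_2_available

from configuration_siglip import SiglipVisionConfig

# Assumption: when the flag is False, the modeling code falls back to the
# eager attention path, so only enable it if flash_attn is installed.
vision_config = SiglipVisionConfig(_flash_attn_2_enabled=is_flash_attn_2_available())
print(vision_config._flash_attn_2_enabled)
```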
image_processing_siglip.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -14,17 +14,16 @@
14
  # limitations under the License.
15
  """Image processor class for SigLIP."""
16
 
17
- from typing import Dict, Optional, Union
18
-
19
- import numpy as np
20
 
21
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
  from transformers.image_transforms import (
23
- rescale,
24
  resize,
25
  to_channel_dimension_format,
26
  )
27
  from transformers.image_utils import (
 
 
28
  ChannelDimension,
29
  ImageInput,
30
  PILImageResampling,
@@ -54,7 +53,7 @@ class SiglipImageProcessor(BaseImageProcessor):
54
  `do_resize` in the `preprocess` method.
55
  size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
56
  Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
57
- resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
58
  Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
59
  do_rescale (`bool`, *optional*, defaults to `True`):
60
  Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
@@ -62,6 +61,16 @@ class SiglipImageProcessor(BaseImageProcessor):
62
  rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
63
  Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
64
  method.
65
  """
66
 
67
  model_input_names = ["pixel_values"]
@@ -73,57 +82,24 @@ class SiglipImageProcessor(BaseImageProcessor):
73
  resample: PILImageResampling = PILImageResampling.BILINEAR,
74
  do_rescale: bool = True,
75
  rescale_factor: Union[int, float] = 1 / 255,
 
 
 
76
  **kwargs,
77
  ) -> None:
78
  super().__init__(**kwargs)
79
  size = size if size is not None else {"height": 224, "width": 224}
80
- size = get_size_dict(size, default_to_square=False)
 
81
 
82
  self.do_resize = do_resize
83
  self.size = size
84
  self.resample = resample
85
  self.do_rescale = do_rescale
86
  self.rescale_factor = rescale_factor
87
-
88
- def rescale(
89
- self,
90
- image: np.ndarray,
91
- rescale_factor: float,
92
- data_format: Optional[Union[str, ChannelDimension]] = None,
93
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
94
- **kwargs,
95
- ) -> np.ndarray:
96
- """
97
- Rescale an image by a scale factor. image = image * scale, after which image = image * 2 - 1.
98
-
99
- Args:
100
- image (`np.ndarray`):
101
- Image to rescale.
102
- scale (`float`):
103
- The scaling factor to rescale pixel values by.
104
- data_format (`str` or `ChannelDimension`, *optional*):
105
- The channel dimension format for the output image. If unset, the channel dimension format of the input
106
- image is used. Can be one of:
107
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
108
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
109
- input_data_format (`ChannelDimension` or `str`, *optional*):
110
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
111
- from the input image. Can be one of:
112
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
113
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
114
-
115
- Returns:
116
- `np.ndarray`: The rescaled image.
117
- """
118
- # first, rescale to 0->1
119
- rescaled_image = rescale(
120
- image, scale=rescale_factor, data_format=data_format, input_data_format=input_data_format, **kwargs
121
- )
122
-
123
- # next, rescale to -1->1
124
- rescaled_image = 2 * rescaled_image - 1
125
-
126
- return rescaled_image
127
 
128
  def preprocess(
129
  self,
@@ -133,6 +109,9 @@ class SiglipImageProcessor(BaseImageProcessor):
133
  resample: PILImageResampling = None,
134
  do_rescale: bool = None,
135
  rescale_factor: float = None,
 
 
 
136
  return_tensors: Optional[Union[str, TensorType]] = None,
137
  data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
138
  input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -156,6 +135,13 @@ class SiglipImageProcessor(BaseImageProcessor):
156
  Whether to rescale the image.
157
  rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
158
  Rescale factor to rescale the image by if `do_rescale` is set to `True`.
159
  return_tensors (`str` or `TensorType`, *optional*):
160
  The type of tensors to return. Can be one of:
161
  - Unset: Return a list of `np.ndarray`.
@@ -181,6 +167,9 @@ class SiglipImageProcessor(BaseImageProcessor):
181
  resample = resample if resample is not None else self.resample
182
  do_rescale = do_rescale if do_rescale is not None else self.do_rescale
183
  rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
 
 
 
184
 
185
  images = make_list_of_images(images)
186
 
@@ -210,14 +199,21 @@ class SiglipImageProcessor(BaseImageProcessor):
210
  input_data_format = infer_channel_dimension_format(images[0])
211
 
212
  if do_resize:
 
213
  images = [
214
- resize(image=image, size=(size["width"], size["height"]), resample=resample, input_data_format=input_data_format)
215
  for image in images
216
  ]
217
 
218
  if do_rescale:
219
  images = [
220
- self.rescale(image=image, rescale_factor=rescale_factor, input_data_format=input_data_format)
221
  for image in images
222
  ]
223
 
 
1
  # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
14
  # limitations under the License.
15
  """Image processor class for SigLIP."""
16
 
17
+ from typing import Dict, List, Optional, Union
 
 
18
 
19
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
20
  from transformers.image_transforms import (
 
21
  resize,
22
  to_channel_dimension_format,
23
  )
24
  from transformers.image_utils import (
25
+ IMAGENET_STANDARD_MEAN,
26
+ IMAGENET_STANDARD_STD,
27
  ChannelDimension,
28
  ImageInput,
29
  PILImageResampling,
 
53
  `do_resize` in the `preprocess` method.
54
  size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
55
  Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
56
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
57
  Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
58
  do_rescale (`bool`, *optional*, defaults to `True`):
59
  Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
 
61
  rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
62
  Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
63
  method.
64
+ do_normalize (`bool`, *optional*, defaults to `True`):
65
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
66
+ `do_normalize` in the `preprocess` method.
67
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
68
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
69
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
70
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
71
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
72
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
73
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
74
  """
75
 
76
  model_input_names = ["pixel_values"]
 
82
  resample: PILImageResampling = PILImageResampling.BILINEAR,
83
  do_rescale: bool = True,
84
  rescale_factor: Union[int, float] = 1 / 255,
85
+ do_normalize: bool = True,
86
+ image_mean: Optional[Union[float, List[float]]] = None,
87
+ image_std: Optional[Union[float, List[float]]] = None,
88
  **kwargs,
89
  ) -> None:
90
  super().__init__(**kwargs)
91
  size = size if size is not None else {"height": 224, "width": 224}
92
+ image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
93
+ image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
94
 
95
  self.do_resize = do_resize
96
  self.size = size
97
  self.resample = resample
98
  self.do_rescale = do_rescale
99
  self.rescale_factor = rescale_factor
100
+ self.do_normalize = do_normalize
101
+ self.image_mean = image_mean
102
+ self.image_std = image_std
103
 
104
  def preprocess(
105
  self,
 
109
  resample: PILImageResampling = None,
110
  do_rescale: bool = None,
111
  rescale_factor: float = None,
112
+ do_normalize: bool = None,
113
+ image_mean: Optional[Union[float, List[float]]] = None,
114
+ image_std: Optional[Union[float, List[float]]] = None,
115
  return_tensors: Optional[Union[str, TensorType]] = None,
116
  data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
117
  input_data_format: Optional[Union[str, ChannelDimension]] = None,
 
135
  Whether to rescale the image.
136
  rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
137
  Rescale factor to rescale the image by if `do_rescale` is set to `True`.
138
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
139
+ Whether to normalize the image.
140
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
141
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
142
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
143
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
144
+ `True`.
145
  return_tensors (`str` or `TensorType`, *optional*):
146
  The type of tensors to return. Can be one of:
147
  - Unset: Return a list of `np.ndarray`.
 
167
  resample = resample if resample is not None else self.resample
168
  do_rescale = do_rescale if do_rescale is not None else self.do_rescale
169
  rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
170
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
171
+ image_mean = image_mean if image_mean is not None else self.image_mean
172
+ image_std = image_std if image_std is not None else self.image_std
173
 
174
  images = make_list_of_images(images)
175
 
 
199
  input_data_format = infer_channel_dimension_format(images[0])
200
 
201
  if do_resize:
202
+ height, width = size["height"], size["width"]
203
  images = [
204
+ resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
205
  for image in images
206
  ]
207
 
208
  if do_rescale:
209
  images = [
210
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
211
+ for image in images
212
+ ]
213
+
214
+ if do_normalize:
215
+ images = [
216
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
217
  for image in images
218
  ]
219
 
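The image-processor change above swaps the old custom rescale-to-[-1, 1] override for the standard rescale + normalize path, with mean and std defaulting to 0.5 (IMAGENET_STANDARD_MEAN / IMAGENET_STANDARD_STD). A small numerical check, not part of the commit, of why the two pipelines produce the same pixel values:

```python
# Sanity check that rescale (1/255) followed by normalize(mean=0.5, std=0.5)
# reproduces the old "rescale to [0, 1], then 2 * x - 1" behaviour.
import numpy as np

image = np.random.randint(0, 256, size=(3, 30, 30)).astype(np.float32)

# Old path: custom rescale override mapping to [-1, 1].
old = 2 * (image * (1 / 255)) - 1

# New path: standard rescale, then per-channel normalization with mean=std=0.5.
new = (image * (1 / 255) - 0.5) / 0.5

assert np.allclose(old, new, atol=1e-6)
```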
modeling_siglip.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -15,20 +15,27 @@
15
  """ PyTorch Siglip model."""
16
 
17
 
 
 
18
  from dataclasses import dataclass
19
  from typing import Any, Optional, Tuple, Union
20
 
 
21
  import torch
 
22
  import torch.utils.checkpoint
23
  from torch import nn
 
24
 
25
  from transformers.activations import ACT2FN
 
26
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import (
29
  ModelOutput,
30
  add_start_docstrings,
31
  add_start_docstrings_to_model_forward,
 
32
  logging,
33
  replace_return_docstrings,
34
  )
@@ -44,33 +51,122 @@ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
  # See all SigLIP models at https://huggingface.co/models?filter=siglip
45
  ]
46
 
47
 
48
- # Copied from transformers.models.bart.modeling_bart._expand_mask
49
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
50
- """
51
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
52
  """
53
- bsz, src_len = mask.size()
54
- tgt_len = tgt_len if tgt_len is not None else src_len
 
 
55
 
56
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
57
 
58
- inverted_mask = 1.0 - expanded_mask
59
 
60
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
61
 
62
 
63
- # contrastive loss function, adapted from
64
- # https://sachinruk.github.io/blog/2021-03-07-siglip.html
65
- def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
66
- return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
67
 
68
 
69
- # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->siglip
70
- def siglip_loss(similarity: torch.Tensor) -> torch.Tensor:
71
- caption_loss = contrastive_loss(similarity)
72
- image_loss = contrastive_loss(similarity.t())
73
- return (caption_loss + image_loss) / 2.0
74
 
75
 
76
  @dataclass
@@ -149,8 +245,7 @@ class SiglipOutput(ModelOutput):
149
  text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
150
  The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
151
  image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
152
- The image embeddings obtained by applying the projection layer to the pooled output of
153
- [`SiglipVisionModel`].
154
  text_model_output(`BaseModelOutputWithPooling`):
155
  The output of the [`SiglipTextModel`].
156
  vision_model_output(`BaseModelOutputWithPooling`):
@@ -188,17 +283,44 @@ class SiglipVisionEmbeddings(nn.Module):
188
  padding="valid",
189
  )
190
 
191
- self.num_patches = (self.image_size // self.patch_size) ** 2
 
192
  self.num_positions = self.num_patches
193
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
194
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
195
 
196
- def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
 
197
 
198
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
199
  embeddings = patch_embeds.flatten(2).transpose(1, 2)
200
 
201
- embeddings = embeddings + self.position_embedding(self.position_ids)
202
  return embeddings
203
 
204
 
@@ -236,10 +358,10 @@ class SiglipTextEmbeddings(nn.Module):
236
  return embeddings
237
 
238
 
239
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
240
  class SiglipAttention(nn.Module):
241
  """Multi-headed attention from 'Attention Is All You Need' paper"""
242
 
 
243
  def __init__(self, config):
244
  super().__init__()
245
  self.config = config
@@ -259,86 +381,245 @@ class SiglipAttention(nn.Module):
259
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
260
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
261
 
262
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
263
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
264
-
265
  def forward(
266
  self,
267
  hidden_states: torch.Tensor,
268
  attention_mask: Optional[torch.Tensor] = None,
269
- causal_attention_mask: Optional[torch.Tensor] = None,
270
  output_attentions: Optional[bool] = False,
271
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
272
  """Input shape: Batch x Time x Channel"""
273
 
274
- bsz, tgt_len, embed_dim = hidden_states.size()
275
 
276
- # get query proj
277
- query_states = self.q_proj(hidden_states) * self.scale
278
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
279
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
280
 
281
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
282
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
283
- key_states = key_states.view(*proj_shape)
284
- value_states = value_states.view(*proj_shape)
285
 
286
- src_len = key_states.size(1)
287
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
288
 
289
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
290
  raise ValueError(
291
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
292
  f" {attn_weights.size()}"
293
  )
294
 
295
- # apply the causal_attention_mask first
296
- if causal_attention_mask is not None:
297
- if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
298
- raise ValueError(
299
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
300
- f" {causal_attention_mask.size()}"
301
- )
302
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
303
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
304
-
305
  if attention_mask is not None:
306
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
307
  raise ValueError(
308
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
309
  )
310
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
311
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
312
 
313
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
 
 
314
 
315
- if output_attentions:
316
- # this operation is a bit akward, but it's required to
317
- # make sure that attn_weights keeps its gradient.
318
- # In order to do so, attn_weights have to reshaped
319
- # twice and have to be reused in the following
320
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
321
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
322
- else:
323
- attn_weights_reshaped = None
324
 
325
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
326
 
327
- attn_output = torch.bmm(attn_probs, value_states)
328
 
329
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
330
- raise ValueError(
331
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
332
- f" {attn_output.size()}"
333
  )
334
 
335
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
336
- attn_output = attn_output.transpose(1, 2)
337
- attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
338
 
 
 
 
 
 
339
  attn_output = self.out_proj(attn_output)
340
 
341
- return attn_output, attn_weights_reshaped
342
 
343
 
344
  # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
@@ -362,7 +643,11 @@ class SiglipEncoderLayer(nn.Module):
362
  def __init__(self, config: SiglipConfig):
363
  super().__init__()
364
  self.embed_dim = config.hidden_size
365
- self.self_attn = SiglipAttention(config)
366
  self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
367
  self.mlp = SiglipMLP(config)
368
  self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -371,16 +656,15 @@ class SiglipEncoderLayer(nn.Module):
371
  self,
372
  hidden_states: torch.Tensor,
373
  attention_mask: torch.Tensor,
374
- causal_attention_mask: torch.Tensor,
375
  output_attentions: Optional[bool] = False,
376
  ) -> Tuple[torch.FloatTensor]:
377
  """
378
  Args:
379
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
380
- attention_mask (`torch.FloatTensor`): attention mask of size
381
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
382
- `(config.encoder_attention_heads,)`.
383
- output_attentions (`bool`, *optional*):
384
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
385
  returned tensors for more detail.
386
  """
@@ -390,7 +674,6 @@ class SiglipEncoderLayer(nn.Module):
390
  hidden_states, attn_weights = self.self_attn(
391
  hidden_states=hidden_states,
392
  attention_mask=attention_mask,
393
- causal_attention_mask=causal_attention_mask,
394
  output_attentions=output_attentions,
395
  )
396
  hidden_states = residual + hidden_states
@@ -420,39 +703,45 @@ class SiglipPreTrainedModel(PreTrainedModel):
420
 
421
  def _init_weights(self, module):
422
  """Initialize the weights"""
423
- factor = self.config.initializer_factor
424
- if isinstance(module, SiglipTextEmbeddings):
425
- module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
426
- module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
427
- elif isinstance(module, SiglipVisionEmbeddings):
428
- factor = self.config.initializer_factor
429
- nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
430
- nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
 
 
431
  elif isinstance(module, SiglipAttention):
432
- factor = self.config.initializer_factor
433
- in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
434
- out_proj_std = (module.embed_dim**-0.5) * factor
435
- nn.init.normal_(module.q_proj.weight, std=in_proj_std)
436
- nn.init.normal_(module.k_proj.weight, std=in_proj_std)
437
- nn.init.normal_(module.v_proj.weight, std=in_proj_std)
438
- nn.init.normal_(module.out_proj.weight, std=out_proj_std)
 
439
  elif isinstance(module, SiglipMLP):
440
- factor = self.config.initializer_factor
441
- in_proj_std = (
442
- (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
443
- )
444
- fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
445
- nn.init.normal_(module.fc1.weight, std=fc_std)
446
- nn.init.normal_(module.fc2.weight, std=in_proj_std)
447
- if isinstance(module, nn.LayerNorm):
 
 
 
 
 
 
 
 
 
448
  module.bias.data.zero_()
449
  module.weight.data.fill_(1.0)
450
- if isinstance(module, nn.Linear) and module.bias is not None:
451
- module.bias.data.zero_()
452
-
453
- def _set_gradient_checkpointing(self, module, value=False):
454
- if isinstance(module, SiglipEncoder):
455
- module.gradient_checkpointing = value
456
 
457
 
458
  SIGLIP_START_DOCSTRING = r"""
@@ -571,11 +860,11 @@ class SiglipEncoder(nn.Module):
571
  self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
572
  self.gradient_checkpointing = False
573
 
 
574
  def forward(
575
  self,
576
  inputs_embeds,
577
  attention_mask: Optional[torch.Tensor] = None,
578
- causal_attention_mask: Optional[torch.Tensor] = None,
579
  output_attentions: Optional[bool] = None,
580
  output_hidden_states: Optional[bool] = None,
581
  return_dict: Optional[bool] = None,
@@ -592,13 +881,6 @@ class SiglipEncoder(nn.Module):
592
  - 1 for tokens that are **not masked**,
593
  - 0 for tokens that are **masked**.
594
 
595
- [What are attention masks?](../glossary#attention-mask)
596
- causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
597
- Causal mask for the text model. Mask values selected in `[0, 1]`:
598
-
599
- - 1 for tokens that are **not masked**,
600
- - 0 for tokens that are **masked**.
601
-
602
  [What are attention masks?](../glossary#attention-mask)
603
  output_attentions (`bool`, *optional*):
604
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -619,28 +901,20 @@ class SiglipEncoder(nn.Module):
619
  all_attentions = () if output_attentions else None
620
 
621
  hidden_states = inputs_embeds
622
- for idx, encoder_layer in enumerate(self.layers):
623
  if output_hidden_states:
624
  encoder_states = encoder_states + (hidden_states,)
625
  if self.gradient_checkpointing and self.training:
626
-
627
- def create_custom_forward(module):
628
- def custom_forward(*inputs):
629
- return module(*inputs, output_attentions)
630
-
631
- return custom_forward
632
-
633
- layer_outputs = torch.utils.checkpoint.checkpoint(
634
- create_custom_forward(encoder_layer),
635
  hidden_states,
636
  attention_mask,
637
- causal_attention_mask,
638
  )
639
  else:
640
  layer_outputs = encoder_layer(
641
  hidden_states,
642
  attention_mask,
643
- causal_attention_mask,
644
  output_attentions=output_attentions,
645
  )
646
 
@@ -699,16 +973,15 @@ class SiglipTextTransformer(nn.Module):
699
 
700
  hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
701
 
702
- # note: SigLIP's text model does not use q causal mask, unlike the original CLIP model.
703
  # expand attention_mask
704
  if attention_mask is not None:
705
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
706
- attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
707
 
708
  encoder_outputs = self.encoder(
709
  inputs_embeds=hidden_states,
710
- attention_mask=None,
711
- causal_attention_mask=None,
712
  output_attentions=output_attentions,
713
  output_hidden_states=output_hidden_states,
714
  return_dict=return_dict,
@@ -775,7 +1048,8 @@ class SiglipTextModel(SiglipPreTrainedModel):
775
  >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
776
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
777
 
778
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
 
779
 
780
  >>> outputs = model(**inputs)
781
  >>> last_hidden_state = outputs.last_hidden_state
@@ -809,6 +1083,7 @@ class SiglipVisionTransformer(nn.Module):
809
  def forward(
810
  self,
811
  pixel_values,
 
812
  output_attentions: Optional[bool] = None,
813
  output_hidden_states: Optional[bool] = None,
814
  return_dict: Optional[bool] = None,
@@ -823,10 +1098,29 @@ class SiglipVisionTransformer(nn.Module):
823
  )
824
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
825
 
826
- hidden_states = self.embeddings(pixel_values)
827
 
828
  encoder_outputs = self.encoder(
829
  inputs_embeds=hidden_states,
830
  output_attentions=output_attentions,
831
  output_hidden_states=output_hidden_states,
832
  return_dict=return_dict,
@@ -835,8 +1129,10 @@ class SiglipVisionTransformer(nn.Module):
835
  last_hidden_state = encoder_outputs[0]
836
  last_hidden_state = self.post_layernorm(last_hidden_state)
837
 
838
-
839
- pooled_output = self.head(last_hidden_state)
 
 
840
 
841
  if not return_dict:
842
  return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -860,11 +1156,13 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
860
  self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
861
  self.mlp = SiglipMLP(config)
862
 
863
- def forward(self, hidden_state):
864
  batch_size = hidden_state.shape[0]
865
  probe = self.probe.repeat(batch_size, 1, 1)
866
 
867
- hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
 
 
868
 
869
  residual = hidden_state
870
  hidden_state = self.layernorm(hidden_state)
@@ -921,7 +1219,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
921
 
922
  >>> outputs = model(**inputs)
923
  >>> last_hidden_state = outputs.last_hidden_state
924
- >>> pooled_output = outputs.pooler_output # pooled CLS states
925
  ```"""
926
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
927
 
@@ -955,19 +1253,11 @@ class SiglipModel(SiglipPreTrainedModel):
955
  text_config = config.text_config
956
  vision_config = config.vision_config
957
 
958
- self.text_model = SiglipTextModel(text_config)
959
- self.vision_model = SiglipVisionModel(vision_config)
960
 
961
- self.temperature = nn.Parameter(
962
- torch.randn(
963
- 1,
964
- )
965
- )
966
- self.bias = nn.Parameter(
967
- torch.randn(
968
- 1,
969
- )
970
- )
971
 
972
  # Initialize weights and apply final processing
973
  self.post_init()
@@ -990,13 +1280,16 @@ class SiglipModel(SiglipPreTrainedModel):
990
  Examples:
991
 
992
  ```python
993
- >>> from transformers import AutoTokenizer, SiglipModel
 
994
 
995
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
996
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
997
 
998
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
999
- >>> text_features = model.get_text_features(**inputs)
 
 
1000
  ```"""
1001
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1002
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1036,9 +1329,10 @@ class SiglipModel(SiglipPreTrainedModel):
1036
  ```python
1037
  >>> from PIL import Image
1038
  >>> import requests
1039
- >>> from transformers import AutoProcessor, SiglipModel
 
1040
 
1041
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1042
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1043
 
1044
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -1046,7 +1340,8 @@ class SiglipModel(SiglipPreTrainedModel):
1046
 
1047
  >>> inputs = processor(images=image, return_tensors="pt")
1048
 
1049
- >>> image_features = model.get_image_features(**inputs)
 
1050
  ```"""
1051
  # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1052
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1087,21 +1382,26 @@ class SiglipModel(SiglipPreTrainedModel):
1087
  ```python
1088
  >>> from PIL import Image
1089
  >>> import requests
1090
- >>> from transformers import AutoProcessor, SiglipModel
 
1091
 
1092
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1093
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1094
 
1095
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1096
  >>> image = Image.open(requests.get(url, stream=True).raw)
1097
 
1098
- >>> inputs = processor(
1099
- ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1100
- ... )
1101
 
1102
- >>> outputs = model(**inputs)
1103
- >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1104
- >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1105
  ```"""
1106
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1107
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1134,11 +1434,9 @@ class SiglipModel(SiglipPreTrainedModel):
1134
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1135
 
1136
  # cosine similarity as logits
1137
- logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.temperature.exp() + self.bias
1138
  logits_per_image = logits_per_text.t()
1139
 
1140
- z = torch.matmul(image_embeds, text_embeds.t()) * self.temperature.exp()
1141
-
1142
  loss = None
1143
  if return_loss:
1144
  raise NotImplementedError("SigLIP loss to be implemented")
 
1
  # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
15
  """ PyTorch Siglip model."""
16
 
17
 
18
+ import math
19
+ import warnings
20
  from dataclasses import dataclass
21
  from typing import Any, Optional, Tuple, Union
22
 
23
+ import numpy as np
24
  import torch
25
+ import torch.nn.functional as F
26
  import torch.utils.checkpoint
27
  from torch import nn
28
+ from torch.nn.init import _calculate_fan_in_and_fan_out
29
 
30
  from transformers.activations import ACT2FN
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
32
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
33
  from transformers.modeling_utils import PreTrainedModel
34
  from transformers.utils import (
35
  ModelOutput,
36
  add_start_docstrings,
37
  add_start_docstrings_to_model_forward,
38
+ is_flash_attn_2_available,
39
  logging,
40
  replace_return_docstrings,
41
  )
 
51
  # See all SigLIP models at https://huggingface.co/models?filter=siglip
52
  ]
53
 
54
+ if is_flash_attn_2_available():
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+
58
+
59
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
60
+ def _get_unpad_data(attention_mask):
61
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
62
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
63
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
64
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
65
+ return (
66
+ indices,
67
+ cu_seqlens,
68
+ max_seqlen_in_batch,
69
+ )
70
+
71
+
72
+ def _trunc_normal_(tensor, mean, std, a, b):
73
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
74
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
75
+ def norm_cdf(x):
76
+ # Computes standard normal cumulative distribution function
77
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
78
+
79
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
80
+ warnings.warn(
81
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
82
+ "The distribution of values may be incorrect.",
83
+ stacklevel=2,
84
+ )
85
 
86
+ # Values are generated by using a truncated uniform distribution and
87
+ # then using the inverse CDF for the normal distribution.
88
+ # Get upper and lower cdf values
89
+ l = norm_cdf((a - mean) / std)
90
+ u = norm_cdf((b - mean) / std)
91
+
92
+ # Uniformly fill tensor with values from [l, u], then translate to
93
+ # [2l-1, 2u-1].
94
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
95
+
96
+ # Use inverse cdf transform for normal distribution to get truncated
97
+ # standard normal
98
+ if tensor.dtype == torch.bfloat16:
99
+ tensor = tensor.to(torch.float32)
100
+ tensor.erfinv_()
101
+ tensor = tensor.to(torch.bfloat16)
102
+ else:
103
+ tensor.erfinv_()
104
+
105
+ # Transform to proper mean, std
106
+ tensor.mul_(std * math.sqrt(2.0))
107
+ tensor.add_(mean)
108
+
109
+ # Clamp to ensure it's in the proper range
110
+ tensor.clamp_(min=a, max=b)
111
+
112
+
113
+ def trunc_normal_tf_(
114
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
115
+ ) -> torch.Tensor:
116
+ """Fills the input Tensor with values drawn from a truncated
117
+ normal distribution. The values are effectively drawn from the
118
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
119
+ with values outside :math:`[a, b]` redrawn until they are within
120
+ the bounds. The method used for generating the random values works
121
+ best when :math:`a \\leq \text{mean} \\leq b`.
122
+
123
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
124
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
125
+ and the result is subsequently scaled and shifted by the mean and std args.
126
+
127
+ Args:
128
+ tensor: an n-dimensional `torch.Tensor`
129
+ mean: the mean of the normal distribution
130
+ std: the standard deviation of the normal distribution
131
+ a: the minimum cutoff value
132
+ b: the maximum cutoff value
133
  """
134
+ with torch.no_grad():
135
+ _trunc_normal_(tensor, 0, 1.0, a, b)
136
+ tensor.mul_(std).add_(mean)
137
+
138
 
139
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
140
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
141
+ if mode == "fan_in":
142
+ denom = fan_in
143
+ elif mode == "fan_out":
144
+ denom = fan_out
145
+ elif mode == "fan_avg":
146
+ denom = (fan_in + fan_out) / 2
147
 
148
+ variance = scale / denom
149
 
150
+ if distribution == "truncated_normal":
151
+ # constant is stddev of standard normal truncated to (-2, 2)
152
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
153
+ elif distribution == "normal":
154
+ with torch.no_grad():
155
+ tensor.normal_(std=math.sqrt(variance))
156
+ elif distribution == "uniform":
157
+ bound = math.sqrt(3 * variance)
158
+ with torch.no_grad():
159
+ tensor.uniform_(-bound, bound)
160
+ else:
161
+ raise ValueError(f"invalid distribution {distribution}")
162
 
163
 
164
+ def lecun_normal_(tensor):
165
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
 
 
166
 
167
 
168
+ def default_flax_embed_init(tensor):
169
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
 
 
 
170
 
171
 
172
  @dataclass
 
245
  text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
246
  The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
247
  image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
248
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
 
249
  text_model_output(`BaseModelOutputWithPooling`):
250
  The output of the [`SiglipTextModel`].
251
  vision_model_output(`BaseModelOutputWithPooling`):
 
283
  padding="valid",
284
  )
285
 
286
+ self.num_patches_per_side = self.image_size // self.patch_size
287
+ self.num_patches = self.num_patches_per_side**2
288
  self.num_positions = self.num_patches
289
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
290
 
291
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
292
+ batch_size = pixel_values.size(0)
293
 
294
+ patch_embeds = self.patch_embedding(pixel_values)
295
  embeddings = patch_embeds.flatten(2).transpose(1, 2)
296
 
297
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
298
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
299
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
300
+ position_ids = torch.full(
301
+ size=(
302
+ batch_size,
303
+ max_nb_patches_h * max_nb_patches_w,
304
+ ),
305
+ fill_value=0,
306
+ )
307
+
308
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
309
+ nb_patches_h = p_attn_mask[:, 0].sum()
310
+ nb_patches_w = p_attn_mask[0].sum()
311
+
312
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
313
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
314
+
315
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
316
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
317
+
318
+ pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
319
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
320
+
321
+ position_ids = position_ids.to(self.position_embedding.weight.device)
322
+
323
+ embeddings = embeddings + self.position_embedding(position_ids)
324
  return embeddings
325
 
326
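To make the bucketing above concrete, here is a small standalone sketch (with made-up sizes, not taken from this checkpoint) of how fractional patch coordinates are mapped onto the fixed `num_patches_per_side` grid of learned position embeddings:

```python
import torch

num_patches_per_side = 8           # hypothetical pretraining grid of 8 x 8 patches
nb_patches_h, nb_patches_w = 4, 6  # an image whose valid (unpadded) area covers 4 x 6 patches

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)  # tensor([0.00, 0.25, 0.50, 0.75])
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# each fractional coordinate is assigned to one of the 8 buckets of the original grid
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)  # tensor([0, 2, 4, 6])
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([24]) -> one position id per valid patch
```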
 
 
358
  return embeddings
359
 
360
 
 
361
  class SiglipAttention(nn.Module):
362
  """Multi-headed attention from 'Attention Is All You Need' paper"""
363
 
364
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
365
  def __init__(self, config):
366
  super().__init__()
367
  self.config = config
 
381
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
382
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
383
 
 
 
 
384
  def forward(
385
  self,
386
  hidden_states: torch.Tensor,
387
  attention_mask: Optional[torch.Tensor] = None,
 
388
  output_attentions: Optional[bool] = False,
389
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
390
  """Input shape: Batch x Time x Channel"""
391
 
392
+ batch_size, q_len, _ = hidden_states.size()
393
 
394
+ query_states = self.q_proj(hidden_states)
395
+ key_states = self.k_proj(hidden_states)
396
+ value_states = self.v_proj(hidden_states)
 
397
 
398
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
399
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
400
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
401
 
402
+ k_v_seq_len = key_states.shape[-2]
403
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
404
 
405
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
406
  raise ValueError(
407
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
408
  f" {attn_weights.size()}"
409
  )
410
411
  if attention_mask is not None:
412
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
413
  raise ValueError(
414
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
415
  )
416
+ attn_weights = attn_weights + attention_mask
 
417
 
418
+ # upcast attention to fp32
419
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
420
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
421
+ attn_output = torch.matmul(attn_weights, value_states)
422
 
423
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
424
+ raise ValueError(
425
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
426
+ f" {attn_output.size()}"
427
+ )
 
 
 
 
428
 
429
+ attn_output = attn_output.transpose(1, 2).contiguous()
430
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
431
 
432
+ attn_output = self.out_proj(attn_output)
433
 
434
+ return attn_output, attn_weights
435
+
436
+
437
+ class SiglipFlashAttention2(SiglipAttention):
438
+ """
439
+ Siglip flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
440
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
441
+ flash attention and deal with padding tokens in case the input contains any of them.
442
+ """
443
+
444
+ def __init__(self, *args, **kwargs):
445
+ super().__init__(*args, **kwargs)
446
+ self.is_causal = False # Hack to make sure we don't use a causal mask
447
+
448
+ def forward(
449
+ self,
450
+ hidden_states: torch.Tensor,
451
+ attention_mask: Optional[torch.LongTensor] = None,
452
+ position_ids: Optional[torch.LongTensor] = None,
453
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
454
+ output_attentions: bool = False,
455
+ use_cache: bool = False,
456
+ **kwargs,
457
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
458
+ output_attentions = False
459
+
460
+ bsz, q_len, _ = hidden_states.size()
461
+
462
+ query_states = self.q_proj(hidden_states)
463
+ key_states = self.k_proj(hidden_states)
464
+ value_states = self.v_proj(hidden_states)
465
+
466
+ # Flash attention requires the input to have the shape
467
+ # batch_size x seq_length x head_dim x hidden_dim
468
+ # therefore we just need to keep the original shape
469
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
470
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
471
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
472
+
473
+ kv_seq_len = key_states.shape[-2]
474
+ if past_key_value is not None:
475
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
476
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
477
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
478
+
479
+ # if past_key_value is not None:
480
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
481
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
482
+
483
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
484
+ # to be able to avoid many of these transpose/reshape/view.
485
+ query_states = query_states.transpose(1, 2)
486
+ key_states = key_states.transpose(1, 2)
487
+ value_states = value_states.transpose(1, 2)
488
+
489
+ dropout_rate = self.dropout if self.training else 0.0
490
+
491
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
492
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
493
+ # cast them back in the correct dtype just to be sure everything works as expected.
494
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
495
+ # in fp32. (LlamaRMSNorm handles it correctly)
496
+
497
+ input_dtype = query_states.dtype
498
+ if input_dtype == torch.float32:
499
+ if torch.is_autocast_enabled():
500
+ target_dtype = torch.get_autocast_gpu_dtype()
501
+ # Handle the case where the model is quantized
502
+ elif hasattr(self.config, "_pre_quantization_dtype"):
503
+ target_dtype = self.config._pre_quantization_dtype
504
+ else:
505
+ target_dtype = self.q_proj.weight.dtype
506
+
507
+ logger.warning_once(
508
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
509
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
510
+ f" {target_dtype}."
511
  )
512
 
513
+ query_states = query_states.to(target_dtype)
514
+ key_states = key_states.to(target_dtype)
515
+ value_states = value_states.to(target_dtype)
516
 
517
+ attn_output = self._flash_attention_forward(
518
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
519
+ )
520
+
521
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
522
  attn_output = self.out_proj(attn_output)
523
 
524
+ if not output_attentions:
525
+ attn_weights = None
526
+
527
+ return attn_output, attn_weights
528
+
529
+ def _flash_attention_forward(
530
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
531
+ ):
532
+ """
533
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
534
+ first unpad the input, then computes the attention scores and pad the final attention scores.
535
+
536
+ Args:
537
+ query_states (`torch.Tensor`):
538
+ Input query states to be passed to Flash Attention API
539
+ key_states (`torch.Tensor`):
540
+ Input key states to be passed to Flash Attention API
541
+ value_states (`torch.Tensor`):
542
+ Input value states to be passed to Flash Attention API
543
+ attention_mask (`torch.Tensor`):
544
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
545
+ position of padding tokens and 1 for the position of non-padding tokens.
546
+ dropout (`int`, *optional*):
547
+ Attention dropout
548
+ softmax_scale (`float`, *optional*):
549
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
550
+ """
551
+
552
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
553
+ causal = self.is_causal and query_length != 1
554
+
555
+ # Contains at least one padding token in the sequence
556
+ if attention_mask is not None:
557
+ batch_size = query_states.shape[0]
558
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
559
+ query_states, key_states, value_states, attention_mask, query_length
560
+ )
561
+
562
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
563
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
564
+
565
+ attn_output_unpad = flash_attn_varlen_func(
566
+ query_states,
567
+ key_states,
568
+ value_states,
569
+ cu_seqlens_q=cu_seqlens_q,
570
+ cu_seqlens_k=cu_seqlens_k,
571
+ max_seqlen_q=max_seqlen_in_batch_q,
572
+ max_seqlen_k=max_seqlen_in_batch_k,
573
+ dropout_p=dropout,
574
+ softmax_scale=softmax_scale,
575
+ causal=causal,
576
+ )
577
+
578
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
579
+ else:
580
+ attn_output = flash_attn_func(
581
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
582
+ )
583
+
584
+ return attn_output
585
+
586
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
587
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
588
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
589
+
590
+ key_layer = index_first_axis(
591
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
592
+ )
593
+ value_layer = index_first_axis(
594
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
595
+ )
596
+ if query_length == kv_seq_len:
597
+ query_layer = index_first_axis(
598
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
599
+ )
600
+ cu_seqlens_q = cu_seqlens_k
601
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
602
+ indices_q = indices_k
603
+ elif query_length == 1:
604
+ max_seqlen_in_batch_q = 1
605
+ cu_seqlens_q = torch.arange(
606
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
607
+ ) # There is a memcpy here, that is very bad.
608
+ indices_q = cu_seqlens_q[:-1]
609
+ query_layer = query_layer.squeeze(1)
610
+ else:
611
+ # The -q_len: slice assumes left padding.
612
+ attention_mask = attention_mask[:, -query_length:]
613
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
614
+
615
+ return (
616
+ query_layer,
617
+ key_layer,
618
+ value_layer,
619
+ indices_q,
620
+ (cu_seqlens_q, cu_seqlens_k),
621
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
622
+ )
623
 
624
 
625
  # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
 
643
  def __init__(self, config: SiglipConfig):
644
  super().__init__()
645
  self.embed_dim = config.hidden_size
646
+ self.self_attn = (
647
+ SiglipAttention(config)
648
+ if not getattr(config, "_flash_attn_2_enabled", False)
649
+ else SiglipFlashAttention2(config)
650
+ )
651
  self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
652
  self.mlp = SiglipMLP(config)
653
  self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
656
  self,
657
  hidden_states: torch.Tensor,
658
  attention_mask: torch.Tensor,
 
659
  output_attentions: Optional[bool] = False,
660
  ) -> Tuple[torch.FloatTensor]:
661
  """
662
  Args:
663
+ hidden_states (`torch.FloatTensor`):
664
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
665
+ attention_mask (`torch.FloatTensor`):
666
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
667
+ output_attentions (`bool`, *optional*, defaults to `False`):
668
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
669
  returned tensors for more detail.
670
  """
 
674
  hidden_states, attn_weights = self.self_attn(
675
  hidden_states=hidden_states,
676
  attention_mask=attention_mask,
 
677
  output_attentions=output_attentions,
678
  )
679
  hidden_states = residual + hidden_states
 
703
 
704
  def _init_weights(self, module):
705
  """Initialize the weights"""
706
+
707
+ if isinstance(module, SiglipVisionEmbeddings):
708
+ width = (
709
+ self.config.vision_config.hidden_size
710
+ if isinstance(self.config, SiglipConfig)
711
+ else self.config.hidden_size
712
+ )
713
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
714
+ elif isinstance(module, nn.Embedding):
715
+ default_flax_embed_init(module.weight)
716
  elif isinstance(module, SiglipAttention):
717
+ nn.init.normal_(module.q_proj.weight)
718
+ nn.init.normal_(module.k_proj.weight)
719
+ nn.init.normal_(module.v_proj.weight)
720
+ nn.init.normal_(module.out_proj.weight)
721
+ nn.init.zeros_(module.q_proj.bias)
722
+ nn.init.zeros_(module.k_proj.bias)
723
+ nn.init.zeros_(module.v_proj.bias)
724
+ nn.init.zeros_(module.out_proj.bias)
725
  elif isinstance(module, SiglipMLP):
726
+ nn.init.normal_(module.fc1.weight)
727
+ nn.init.normal_(module.fc2.weight)
728
+ nn.init.normal_(module.fc1.bias, std=1e-6)
729
+ nn.init.normal_(module.fc2.bias, std=1e-6)
730
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
731
+ nn.init.normal_(module.probe.data)
732
+ nn.init.normal_(module.attention.in_proj_weight.data)
733
+ nn.init.zeros_(module.attention.in_proj_bias.data)
734
+ elif isinstance(module, SiglipModel):
735
+ logit_scale_init = torch.log(torch.tensor(1.0))
736
+ module.logit_scale.data.fill_(logit_scale_init)
737
+ module.logit_bias.data.zero_()
738
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
739
+ lecun_normal_(module.weight)
740
+ if module.bias is not None:
741
+ nn.init.zeros_(module.bias)
742
+ elif isinstance(module, nn.LayerNorm):
743
  module.bias.data.zero_()
744
  module.weight.data.fill_(1.0)
745
 
746
 
747
  SIGLIP_START_DOCSTRING = r"""
 
860
  self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
861
  self.gradient_checkpointing = False
862
 
863
+ # Ignore copy
864
  def forward(
865
  self,
866
  inputs_embeds,
867
  attention_mask: Optional[torch.Tensor] = None,
 
868
  output_attentions: Optional[bool] = None,
869
  output_hidden_states: Optional[bool] = None,
870
  return_dict: Optional[bool] = None,
 
881
  - 1 for tokens that are **not masked**,
882
  - 0 for tokens that are **masked**.
883
 
 
 
 
 
 
 
 
884
  [What are attention masks?](../glossary#attention-mask)
885
  output_attentions (`bool`, *optional*):
886
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
 
901
  all_attentions = () if output_attentions else None
902
 
903
  hidden_states = inputs_embeds
904
+ for encoder_layer in self.layers:
905
  if output_hidden_states:
906
  encoder_states = encoder_states + (hidden_states,)
907
  if self.gradient_checkpointing and self.training:
908
+ layer_outputs = self._gradient_checkpointing_func(
909
+ encoder_layer.__call__,
910
  hidden_states,
911
  attention_mask,
912
+ output_attentions,
913
  )
914
  else:
915
  layer_outputs = encoder_layer(
916
  hidden_states,
917
  attention_mask,
 
918
  output_attentions=output_attentions,
919
  )
920
 
 
973
 
974
  hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
975
 
976
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
977
  # expand attention_mask
978
  if attention_mask is not None:
979
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
980
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
981
 
982
  encoder_outputs = self.encoder(
983
  inputs_embeds=hidden_states,
984
+ attention_mask=attention_mask,
 
985
  output_attentions=output_attentions,
986
  output_hidden_states=output_hidden_states,
987
  return_dict=return_dict,
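For reference, `_prepare_4d_attention_mask` turns the `[batch_size, seq_len]` padding mask into the additive `[batch_size, 1, tgt_len, src_len]` mask that the eager attention adds to its scores; a rough standalone equivalent (a sketch, not the transformers implementation) would be:

```python
import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask_2d: [batch_size, src_len] with 1 = attend, 0 = padding
    bsz, src_len = mask_2d.shape
    mask = mask_2d[:, None, None, :].to(dtype)    # [bsz, 1, 1, src_len]
    mask = mask.expand(bsz, 1, src_len, src_len)  # broadcast over the query dimension
    inverted = 1.0 - mask
    # padded positions become a very large negative value so softmax ignores them
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

print(expand_padding_mask(torch.tensor([[1, 1, 0]]), torch.float32).shape)  # torch.Size([1, 1, 3, 3])
```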
 
1048
  >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1049
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1050
 
1051
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1052
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1053
 
1054
  >>> outputs = model(**inputs)
1055
  >>> last_hidden_state = outputs.last_hidden_state
 
1083
  def forward(
1084
  self,
1085
  pixel_values,
1086
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
1087
  output_attentions: Optional[bool] = None,
1088
  output_hidden_states: Optional[bool] = None,
1089
  return_dict: Optional[bool] = None,
 
1098
  )
1099
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1100
 
1101
+ batch_size = pixel_values.size(0)
1102
+ if patch_attention_mask is None:
1103
+ patch_attention_mask = torch.ones(
1104
+ size=(
1105
+ batch_size,
1106
+ pixel_values.size(2) // self.config.patch_size,
1107
+ pixel_values.size(3) // self.config.patch_size,
1108
+ ),
1109
+ dtype=torch.bool,
1110
+ device=pixel_values.device,
1111
+ )
1112
+
1113
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
1114
+
1115
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1116
 
1117
  encoder_outputs = self.encoder(
1118
  inputs_embeds=hidden_states,
1119
+ attention_mask=(
1120
+ _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1121
+ if not self.config._flash_attn_2_enabled
1122
+ else patch_attention_mask
1123
+ ),
1124
  output_attentions=output_attentions,
1125
  output_hidden_states=output_hidden_states,
1126
  return_dict=return_dict,
 
1129
  last_hidden_state = encoder_outputs[0]
1130
  last_hidden_state = self.post_layernorm(last_hidden_state)
1131
 
1132
+ pooled_output = self.head(
1133
+ hidden_state=last_hidden_state,
1134
+ attention_mask=patch_attention_mask,
1135
+ )
1136
 
1137
  if not return_dict:
1138
  return (last_hidden_state, pooled_output) + encoder_outputs[1:]
 
1156
  self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1157
  self.mlp = SiglipMLP(config)
1158
 
1159
+ def forward(self, hidden_state, attention_mask):
1160
  batch_size = hidden_state.shape[0]
1161
  probe = self.probe.repeat(batch_size, 1, 1)
1162
 
1163
+ hidden_state = self.attention(
1164
+ query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
1165
+ )[0]
1166
 
1167
  residual = hidden_state
1168
  hidden_state = self.layernorm(hidden_state)
 
1219
 
1220
  >>> outputs = model(**inputs)
1221
  >>> last_hidden_state = outputs.last_hidden_state
1222
+ >>> pooled_output = outputs.pooler_output # pooled features
1223
  ```"""
1224
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1225
 
 
1253
  text_config = config.text_config
1254
  vision_config = config.vision_config
1255
 
1256
+ self.text_model = SiglipTextTransformer(text_config)
1257
+ self.vision_model = SiglipVisionTransformer(vision_config)
1258
 
1259
+ self.logit_scale = nn.Parameter(torch.randn(1))
1260
+ self.logit_bias = nn.Parameter(torch.randn(1))
 
 
 
 
 
 
 
 
1261
 
1262
  # Initialize weights and apply final processing
1263
  self.post_init()
 
1280
  Examples:
1281
 
1282
  ```python
1283
+ >>> from transformers import AutoTokenizer, AutoModel
1284
+ >>> import torch
1285
 
1286
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1287
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1288
 
1289
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1290
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1291
+ >>> with torch.no_grad():
1292
+ ... text_features = model.get_text_features(**inputs)
1293
  ```"""
1294
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1295
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1329
  ```python
1330
  >>> from PIL import Image
1331
  >>> import requests
1332
+ >>> from transformers import AutoProcessor, AutoModel
1333
+ >>> import torch
1334
 
1335
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1336
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1337
 
1338
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 
1340
 
1341
  >>> inputs = processor(images=image, return_tensors="pt")
1342
 
1343
+ >>> with torch.no_grad():
1344
+ ... image_features = model.get_image_features(**inputs)
1345
  ```"""
1346
  # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1347
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1382
  ```python
1383
  >>> from PIL import Image
1384
  >>> import requests
1385
+ >>> from transformers import AutoProcessor, AutoModel
1386
+ >>> import torch
1387
 
1388
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1389
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1390
 
1391
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1392
  >>> image = Image.open(requests.get(url, stream=True).raw)
1393
 
1394
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1395
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1396
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1397
 
1398
+ >>> with torch.no_grad():
1399
+ ... outputs = model(**inputs)
1400
+
1401
+ >>> logits_per_image = outputs.logits_per_image
1402
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1403
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1404
+ 31.9% that image 0 is 'a photo of 2 cats'
1405
  ```"""
1406
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1407
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1434
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1435
 
1436
  # cosine similarity as logits
1437
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1438
  logits_per_image = logits_per_text.t()
1439
 
 
 
1440
  loss = None
1441
  if return_loss:
1442
  raise NotImplementedError("SigLIP loss to be implemented")
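The `NotImplementedError` above marks where the pairwise sigmoid loss from the SigLIP paper would go; one possible sketch of it (my reading of the paper, not code from this commit) is:

```python
import torch
import torch.nn.functional as F

def siglip_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: [num_texts, num_images], already scaled by logit_scale and shifted by logit_bias
    n = logits_per_text.size(0)
    # +1 for matching (diagonal) text-image pairs, -1 for every other pair
    labels = 2 * torch.eye(n, device=logits_per_text.device) - 1
    # -log sigmoid(label * logit), summed over all pairs and averaged over the batch
    return -F.logsigmoid(labels * logits_per_text).sum() / n

print(siglip_loss(torch.randn(4, 4)))
```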
processing_siglip.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for SigLIP.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.image_utils import ImageInput
23
+ from transformers.processing_utils import ProcessorMixin
24
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25
+ from transformers.utils import TensorType
26
+
27
+
28
+ class SiglipProcessor(ProcessorMixin):
29
+ r"""
30
+ Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
31
+
32
+ [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
33
+ [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
34
+
35
+ Args:
36
+ image_processor ([`SiglipImageProcessor`]):
37
+ The image processor is a required input.
38
+ tokenizer ([`SiglipTokenizer`]):
39
+ The tokenizer is a required input.
40
+ """
41
+
42
+ attributes = ["image_processor", "tokenizer"]
43
+ image_processor_class = "SiglipImageProcessor"
44
+ tokenizer_class = "SiglipTokenizer"
45
+
46
+ def __init__(self, image_processor, tokenizer):
47
+ super().__init__(image_processor, tokenizer)
48
+
49
+ def __call__(
50
+ self,
51
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
52
+ images: ImageInput = None,
53
+ padding: Union[bool, str, PaddingStrategy] = False,
54
+ truncation: Union[bool, str, TruncationStrategy] = None,
55
+ max_length: int = None,
56
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
57
+ ) -> BatchFeature:
58
+ """
59
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
60
+ and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
61
+ the text. To prepare the image(s), this method forwards the `images` argument to
62
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
63
+ of the above two methods for more information.
64
+
65
+ Args:
66
+ text (`str`, `List[str]`, `List[List[str]]`):
67
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
68
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
69
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
70
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
71
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
72
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
73
+ number of channels, H and W are image height and width.
74
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
75
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
76
+ index) among:
77
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
78
+ sequence if provided).
79
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
80
+ acceptable input length for the model if that argument is not provided.
81
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
82
+ lengths).
83
+ max_length (`int`, *optional*):
84
+ Maximum length of the returned list and optionally padding length (see above).
85
+ truncation (`bool`, *optional*):
86
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
87
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
88
+ If set, will return tensors of a particular framework. Acceptable values are:
89
+
90
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
91
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
92
+ - `'np'`: Return NumPy `np.ndarray` objects.
93
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
94
+
95
+ Returns:
96
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
97
+
98
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
99
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
100
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
101
+ `None`).
102
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
103
+ """
104
+
105
+ if text is None and images is None:
106
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
107
+
108
+ if text is not None:
109
+ encoding = self.tokenizer(
110
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
111
+ )
112
+
113
+ if images is not None:
114
+ image_features = self.image_processor(images, return_tensors=return_tensors)
115
+
116
+ if text is not None and images is not None:
117
+ encoding["pixel_values"] = image_features.pixel_values
118
+ return encoding
119
+ elif text is not None:
120
+ return encoding
121
+ else:
122
+ return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
123
+
124
+ def decode(self, *args, **kwargs):
125
+ """
126
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
127
+ the docstring of this method for more information.
128
+ """
129
+ return self.tokenizer.decode(*args, **kwargs)
130
+
131
+ def batch_decode(self, *args, **kwargs):
132
+ """
133
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
134
+ refer to the docstring of this method for more information.
135
+ """
136
+ return self.tokenizer.batch_decode(*args, **kwargs)
137
+
138
+ @property
139
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
140
+ def model_input_names(self):
141
+ tokenizer_input_names = self.tokenizer.model_input_names
142
+ image_processor_input_names = self.image_processor.model_input_names
143
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
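A minimal end-to-end sketch of the processor defined above, mirroring the docstring examples earlier in this commit (it assumes the `google/siglip-base-patch16-224` checkpoint):

```python
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = AutoModel.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# padding="max_length" matches how the model was trained (see the docstrings above)
inputs = processor(text=["a photo of 2 cats"], images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
print(torch.sigmoid(outputs.logits_per_image))
```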
tokenization_siglip.py ADDED
@@ -0,0 +1,389 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization class for SigLIP model."""
16
+
17
+ import os
18
+ import re
19
+ import string
20
+ import warnings
21
+ from shutil import copyfile
22
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
23
+
24
+ import sentencepiece as spm
25
+
26
+ from transformers.convert_slow_tokenizer import import_protobuf
27
+ from transformers.tokenization_utils import PreTrainedTokenizer
28
+ from transformers.tokenization_utils_base import AddedToken
29
+
30
+
31
+ if TYPE_CHECKING:
32
+ from transformers.tokenization_utils_base import TextInput
33
+ from transformers.utils import logging, requires_backends
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
39
+
40
+ PRETRAINED_VOCAB_FILES_MAP = {
41
+ "vocab_file": {
42
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/spiece.model",
43
+ }
44
+ }
45
+
46
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
47
+ "google/siglip-base-patch16-224": 256,
48
+ }
49
+
50
+ SPIECE_UNDERLINE = "▁"
51
+
52
+
53
+ class SiglipTokenizer(PreTrainedTokenizer):
54
+ """
55
+ Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
56
+
57
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
58
+ this superclass for more information regarding those methods.
59
+
60
+ Args:
61
+ vocab_file (`str`):
62
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
63
+ contains the vocabulary necessary to instantiate a tokenizer.
64
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
65
+ The end of sequence token.
66
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
67
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
68
+ token instead.
69
+ pad_token (`str`, *optional*, defaults to `"</s>"`):
70
+ The token used for padding, for example when batching sequences of different lengths.
71
+ additional_special_tokens (`List[str]`, *optional*):
72
+ Additional special tokens used by the tokenizer.
73
+ sp_model_kwargs (`dict`, *optional*):
74
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
75
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
76
+ to set:
77
+
78
+ - `enable_sampling`: Enable subword regularization.
79
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
80
+
81
+ - `nbest_size = {0,1}`: No sampling is performed.
82
+ - `nbest_size > 1`: samples from the nbest_size results.
83
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
84
+ using forward-filtering-and-backward-sampling algorithm.
85
+
86
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
87
+ BPE-dropout.
88
+ model_max_length (`int`, *optional*, defaults to 64):
89
+ The maximum length (in number of tokens) for model inputs.
90
+ do_lower_case (`bool`, *optional*, defaults to `True`):
91
+ Whether or not to lowercase the input when tokenizing.
92
+ """
93
+
94
+ vocab_files_names = VOCAB_FILES_NAMES
95
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
96
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
97
+ model_input_names = ["input_ids", "attention_mask"]
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_file,
102
+ eos_token="</s>",
103
+ unk_token="<unk>",
104
+ pad_token="</s>",
105
+ additional_special_tokens=None,
106
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
107
+ model_max_length=64,
108
+ do_lower_case=True,
109
+ **kwargs,
110
+ ) -> None:
111
+ requires_backends(self, "protobuf")
112
+
113
+ pad_token = (
114
+ AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
115
+ if isinstance(pad_token, str)
116
+ else pad_token
117
+ )
118
+ unk_token = (
119
+ AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
120
+ if isinstance(unk_token, str)
121
+ else unk_token
122
+ )
123
+ eos_token = (
124
+ AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
125
+ if isinstance(eos_token, str)
126
+ else eos_token
127
+ )
128
+
129
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
130
+
131
+ self.do_lower_case = do_lower_case
132
+ self.vocab_file = vocab_file
133
+
134
+ self.sp_model = self.get_spm_processor()
135
+ self.vocab_file = vocab_file
136
+
137
+ super().__init__(
138
+ eos_token=eos_token,
139
+ unk_token=unk_token,
140
+ pad_token=pad_token,
141
+ additional_special_tokens=additional_special_tokens,
142
+ sp_model_kwargs=self.sp_model_kwargs,
143
+ model_max_length=model_max_length,
144
+ do_lower_case=do_lower_case,
145
+ **kwargs,
146
+ )
147
+
148
+ def get_spm_processor(self):
149
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
150
+ with open(self.vocab_file, "rb") as f:
151
+ sp_model = f.read()
152
+ model_pb2 = import_protobuf()
153
+ model = model_pb2.ModelProto.FromString(sp_model)
154
+ normalizer_spec = model_pb2.NormalizerSpec()
155
+ normalizer_spec.add_dummy_prefix = False
156
+ model.normalizer_spec.MergeFrom(normalizer_spec)
157
+ sp_model = model.SerializeToString()
158
+ tokenizer.LoadFromSerializedProto(sp_model)
159
+ return tokenizer
160
+
161
+ @property
162
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
163
+ def vocab_size(self):
164
+ return self.sp_model.get_piece_size()
165
+
166
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
167
+ def get_vocab(self):
168
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
169
+ vocab.update(self.added_tokens_encoder)
170
+ return vocab
171
+
172
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
173
+ def get_special_tokens_mask(
174
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
175
+ ) -> List[int]:
176
+ """
177
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
178
+ special tokens using the tokenizer `prepare_for_model` method.
179
+
180
+ Args:
181
+ token_ids_0 (`List[int]`):
182
+ List of IDs.
183
+ token_ids_1 (`List[int]`, *optional*):
184
+ Optional second list of IDs for sequence pairs.
185
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
186
+ Whether or not the token list is already formatted with special tokens for the model.
187
+
188
+ Returns:
189
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
190
+ """
191
+ if already_has_special_tokens:
192
+ return super().get_special_tokens_mask(
193
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
194
+ )
195
+
196
+ # normal case: some special tokens
197
+ if token_ids_1 is None:
198
+ return ([0] * len(token_ids_0)) + [1]
199
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
200
+
201
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
202
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
203
+ """Do not add eos again if user already added it."""
204
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
205
+ warnings.warn(
206
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
207
+ " eos tokens being added."
208
+ )
209
+ return token_ids
210
+ else:
211
+ return token_ids + [self.eos_token_id]
212
+
213
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
214
+ def create_token_type_ids_from_sequences(
215
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
216
+ ) -> List[int]:
217
+ """
218
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
219
+ use of token type ids, therefore a list of zeros is returned.
220
+
221
+ Args:
222
+ token_ids_0 (`List[int]`):
223
+ List of IDs.
224
+ token_ids_1 (`List[int]`, *optional*):
225
+ Optional second list of IDs for sequence pairs.
226
+
227
+ Returns:
228
+ `List[int]`: List of zeros.
229
+ """
230
+ eos = [self.eos_token_id]
231
+
232
+ if token_ids_1 is None:
233
+ return len(token_ids_0 + eos) * [0]
234
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
235
+
236
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
237
+ def build_inputs_with_special_tokens(
238
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
239
+ ) -> List[int]:
240
+ """
241
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
242
+ adding special tokens. A sequence has the following format:
243
+
244
+ - single sequence: `X </s>`
245
+ - pair of sequences: `A </s> B </s>`
246
+
247
+ Args:
248
+ token_ids_0 (`List[int]`):
249
+ List of IDs to which the special tokens will be added.
250
+ token_ids_1 (`List[int]`, *optional*):
251
+ Optional second list of IDs for sequence pairs.
252
+
253
+ Returns:
254
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
255
+ """
256
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
257
+ if token_ids_1 is None:
258
+ return token_ids_0
259
+ else:
260
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
261
+ return token_ids_0 + token_ids_1
262
+
263
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
264
+ def __getstate__(self):
265
+ state = self.__dict__.copy()
266
+ state["sp_model"] = None
267
+ return state
268
+
269
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
270
+ def __setstate__(self, d):
271
+ self.__dict__ = d
272
+
273
+ # for backward compatibility
274
+ if not hasattr(self, "sp_model_kwargs"):
275
+ self.sp_model_kwargs = {}
276
+
277
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
278
+ self.sp_model.Load(self.vocab_file)
279
+
280
+ def remove_punctuation(self, text: str) -> str:
281
+ return text.translate(str.maketrans("", "", string.punctuation))
282
+
283
+ # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
284
+ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
285
+ """Returns canonicalized `text` (puncuation removed).
286
+
287
+ Args:
288
+ text (`str`):
289
+ String to be canonicalized.
290
+ keep_punctuation_exact_string (`str`, *optional*):
291
+ If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
292
+ (but will still remove '{' and '}' that appear separately).
293
+ """
294
+ if keep_punctuation_exact_string:
295
+ text = keep_punctuation_exact_string.join(
296
+ self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
297
+ )
298
+ else:
299
+ text = self.remove_punctuation(text)
300
+ text = re.sub(r"\s+", " ", text)
301
+ text = text.strip()
302
+
303
+ return text
304
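An informal example of the canonicalization step (assuming an instantiated `SiglipTokenizer` named `tokenizer`):

```python
# hypothetical instance, e.g. tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")
text = "A photo, of {} Siglip dogs!"

# punctuation is stripped and whitespace collapsed, but the exact string "{}" survives
print(tokenizer.canonicalize_text(text, keep_punctuation_exact_string="{}"))
# A photo of {} Siglip dogs

# without the escape hatch, "{" and "}" are removed like any other punctuation
print(tokenizer.canonicalize_text(text))
# A photo of Siglip dogs
```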
+
305
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
306
+ """
307
+ Converts a string to a list of tokens.
308
+ """
309
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
310
+
311
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
312
+ tokens = tokens[1:]
313
+ return tokens
314
+
315
+ @property
316
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
317
+ def unk_token_length(self):
318
+ return len(self.sp_model.encode(str(self.unk_token)))
319
+
320
+ def _tokenize(self, text, **kwargs):
321
+ """
322
+ Returns a tokenized string.
323
+
324
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
325
+ SPIECE_UNDERLINE.
326
+
327
+ For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
328
+
329
+ Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
330
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
331
+ """
332
+ text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
333
+ tokens = self.sp_model.encode(text, out_type=str)
334
+
335
+ # 1. Encode string + prefix ex: "<unk> Hey"
336
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
337
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
338
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
339
+
340
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
341
+ def _convert_token_to_id(self, token):
342
+ """Converts a token (str) in an id using the vocab."""
343
+ return self.sp_model.piece_to_id(token)
344
+
345
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
346
+ def _convert_id_to_token(self, index):
347
+ """Converts an index (integer) in a token (str) using the vocab."""
348
+ token = self.sp_model.IdToPiece(index)
349
+ return token
350
+
351
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string
352
+ def convert_tokens_to_string(self, tokens):
353
+ """Converts a sequence of tokens (string) in a single string."""
354
+ current_sub_tokens = []
355
+ # since we manually add the prefix space, we have to remove it
356
+ tokens[0] = tokens[0].lstrip(SPIECE_UNDERLINE)
357
+ out_string = ""
358
+ prev_is_special = False
359
+ for token in tokens:
360
+ # make sure that special tokens are not decoded using sentencepiece model
361
+ if token in self.all_special_tokens:
362
+ if not prev_is_special:
363
+ out_string += " "
364
+ out_string += self.sp_model.decode(current_sub_tokens) + token
365
+ prev_is_special = True
366
+ current_sub_tokens = []
367
+ else:
368
+ current_sub_tokens.append(token)
369
+ prev_is_special = False
370
+ out_string += self.sp_model.decode(current_sub_tokens)
371
+ return out_string.strip()
372
+
373
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
374
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
375
+ if not os.path.isdir(save_directory):
376
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
377
+ return
378
+ out_vocab_file = os.path.join(
379
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
380
+ )
381
+
382
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
383
+ copyfile(self.vocab_file, out_vocab_file)
384
+ elif not os.path.isfile(self.vocab_file):
385
+ with open(out_vocab_file, "wb") as fi:
386
+ content_spiece_model = self.sp_model.serialized_model_proto()
387
+ fi.write(content_spiece_model)
388
+
389
+ return (out_vocab_file,)