LinWeizheDragon commited on
Commit
3704bcf
1 Parent(s): 586935b

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./FLMR",
3
+ "architectures": [
4
+ "FLMRModelForRetrieval"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_flmr.FLMRConfig",
8
+ "AutoModel": "modeling_flmr.FLMRModelForRetrieval"
9
+ },
10
+ "context_concat_output_from_text_encoder": true,
11
+ "context_concat_output_from_vision_encoder": false,
12
+ "dim": 128,
13
+ "initializer_range": 0.02,
14
+ "load_cpu_extension": false,
15
+ "mapping_network_prefix_length": 32,
16
+ "mask_instruction_token": null,
17
+ "mask_punctuation": true,
18
+ "model_type": "flmr",
19
+ "query_concat_output_from_text_encoder": true,
20
+ "query_concat_output_from_vision_encoder": true,
21
+ "separate_query_and_context_text_encoder": false,
22
+ "separate_query_and_context_vision_encoder": false,
23
+ "text_config": {
24
+ "architectures": [
25
+ "BertForMaskedLM"
26
+ ],
27
+ "gradient_checkpointing": false,
28
+ "model_type": "flmr_text_model",
29
+ "use_cache": true,
30
+ "vocab_size": 30531
31
+ },
32
+ "torch_dtype": "float32",
33
+ "transformer_mapping_config_base": null,
34
+ "transformer_mapping_cross_attention_length": 32,
35
+ "transformer_mapping_num_hidden_layers": null,
36
+ "transformers_version": "4.37.2",
37
+ "use_transformer_mapping_network": false,
38
+ "use_vision_encoder": true,
39
+ "vision_config": {
40
+ "dropout": 0.0,
41
+ "model_type": "flmr_vision_model"
42
+ },
43
+ "vision_model_version": "openai/clip-vit-base-patch32"
44
+ }
configuration_flmr.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2010, FLMR authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ FLMR model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ FLMR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "LinWeizheDragon/PreFLMR_ViT-L": "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/config.json",
28
+ "LinWeizheDragon/FLMR": "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ # Modified from transformers.models.clip.configuration_clip.CLIPVisionConfig with CLIP -> FLMR
33
+ class FLMRVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`FLMRVisionModel`]. It is used to instantiate a
36
+ FLMR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the FLMR
38
+ [openai/flmr-vit-base-patch32](https://huggingface.co/openai/flmr-vit-base-patch32) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ intermediate_size (`int`, *optional*, defaults to 3072):
47
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
48
+ projection_dim (`int`, *optional*, defaults to 512):
49
+ Dimentionality of text and vision projection layers.
50
+ num_hidden_layers (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 12):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ num_channels (`int`, *optional*, defaults to 3):
55
+ The number of input channels.
56
+ image_size (`int`, *optional*, defaults to 224):
57
+ The size (resolution) of each image.
58
+ patch_size (`int`, *optional*, defaults to 32):
59
+ The size (resolution) of each patch.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
61
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
62
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
63
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
64
+ The epsilon used by the layer normalization layers.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio for the attention probabilities.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ initializer_factor (`float`, *optional*, defaults to 1.0):
70
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
71
+ testing).
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import FLMRVisionConfig, FLMRVisionModel
77
+
78
+ >>> # Initializing a FLMRVisionConfig with LinWeizheDragon/FLMR style configuration
79
+ >>> configuration = FLMRVisionConfig()
80
+
81
+ >>> # Initializing a FLMRVisionModel (with random weights) from the LinWeizheDragon/FLMR style configuration
82
+ >>> model = FLMRVisionModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "flmr_vision_model"
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size=768,
93
+ intermediate_size=3072,
94
+ projection_dim=512,
95
+ num_hidden_layers=12,
96
+ num_attention_heads=12,
97
+ num_channels=3,
98
+ image_size=224,
99
+ patch_size=32,
100
+ hidden_act="quick_gelu",
101
+ layer_norm_eps=1e-5,
102
+ attention_dropout=0.0,
103
+ initializer_range=0.02,
104
+ initializer_factor=1.0,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+
109
+ self.hidden_size = hidden_size
110
+ self.intermediate_size = intermediate_size
111
+ self.projection_dim = projection_dim
112
+ self.num_hidden_layers = num_hidden_layers
113
+ self.num_attention_heads = num_attention_heads
114
+ self.num_channels = num_channels
115
+ self.patch_size = patch_size
116
+ self.image_size = image_size
117
+ self.initializer_range = initializer_range
118
+ self.initializer_factor = initializer_factor
119
+ self.attention_dropout = attention_dropout
120
+ self.layer_norm_eps = layer_norm_eps
121
+ self.hidden_act = hidden_act
122
+
123
+ @classmethod
124
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
125
+ cls._set_token_in_kwargs(kwargs)
126
+
127
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
128
+
129
+ # get the vision config dict if we are loading from a CLIPConfig
130
+ if config_dict.get("model_type") == "clip":
131
+ config_dict = config_dict["vision_config"]
132
+
133
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
134
+ logger.warning(
135
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
136
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
137
+ )
138
+
139
+ return cls.from_dict(config_dict, **kwargs)
140
+
141
+
142
+ # Modified from transformers.models.dpr.configuration_dpr.DPRConfig with DPR -> FLMR
143
+ class FLMRTextConfig(PretrainedConfig):
144
+ r"""
145
+ [`FLMRTextConfig`] is the configuration class to store the configuration of a *FLMRTextModel*.
146
+
147
+ This is the configuration class to store the configuration of a [`FLMRTextModel`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
148
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
149
+ configuration to that of the DPRContextEncoder
150
+ [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
151
+ architecture.
152
+
153
+ This class is a subclass of [`BertConfig`]. Please check the superclass for the documentation of all kwargs.
154
+
155
+ Args:
156
+ vocab_size (`int`, *optional*, defaults to 30522):
157
+ Vocabulary size of the FLMR model. Defines the different tokens that can be represented by the *inputs_ids*
158
+ passed to the forward method of [`BertModel`].
159
+ hidden_size (`int`, *optional*, defaults to 768):
160
+ Dimensionality of the encoder layers and the pooler layer.
161
+ num_hidden_layers (`int`, *optional*, defaults to 12):
162
+ Number of hidden layers in the Transformer encoder.
163
+ num_attention_heads (`int`, *optional*, defaults to 12):
164
+ Number of attention heads for each attention layer in the Transformer encoder.
165
+ intermediate_size (`int`, *optional*, defaults to 3072):
166
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
167
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
168
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
169
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
170
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
171
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
172
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
173
+ The dropout ratio for the attention probabilities.
174
+ max_position_embeddings (`int`, *optional*, defaults to 512):
175
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
176
+ just in case (e.g., 512 or 1024 or 2048).
177
+ type_vocab_size (`int`, *optional*, defaults to 2):
178
+ The vocabulary size of the *token_type_ids* passed into [`BertModel`].
179
+ initializer_range (`float`, *optional*, defaults to 0.02):
180
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
181
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
182
+ The epsilon used by the layer normalization layers.
183
+ pad_token_id (`int`, *optional*, defaults to 0):
184
+ Padding token id.
185
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
186
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
187
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
188
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
189
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
190
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
191
+ projection_dim (`int`, *optional*, defaults to 0):
192
+ Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
193
+ projection is done.
194
+
195
+ Example:
196
+
197
+ ```python
198
+ >>> from transformers import FLMRTextConfig, FLMRTextModel
199
+
200
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
201
+ >>> configuration = FLMRTextConfig()
202
+
203
+ >>> # Initializing a model (with random weights) from the LinWeizheDragon/FLMR style configuration
204
+ >>> model = FLMRTextModel(configuration)
205
+
206
+ >>> # Accessing the model configuration
207
+ >>> configuration = model.config
208
+ ```"""
209
+
210
+ model_type = "flmr_text_model"
211
+
212
+ def __init__(
213
+ self,
214
+ vocab_size=30522,
215
+ hidden_size=768,
216
+ num_hidden_layers=12,
217
+ num_attention_heads=12,
218
+ intermediate_size=3072,
219
+ hidden_act="gelu",
220
+ hidden_dropout_prob=0.1,
221
+ attention_probs_dropout_prob=0.1,
222
+ max_position_embeddings=512,
223
+ type_vocab_size=2,
224
+ initializer_range=0.02,
225
+ layer_norm_eps=1e-12,
226
+ pad_token_id=0,
227
+ position_embedding_type="absolute",
228
+ projection_dim: int = 0,
229
+ **kwargs,
230
+ ):
231
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
232
+
233
+ self.vocab_size = vocab_size
234
+ self.hidden_size = hidden_size
235
+ self.num_hidden_layers = num_hidden_layers
236
+ self.num_attention_heads = num_attention_heads
237
+ self.hidden_act = hidden_act
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_dropout_prob = hidden_dropout_prob
240
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
241
+ self.max_position_embeddings = max_position_embeddings
242
+ self.type_vocab_size = type_vocab_size
243
+ self.initializer_range = initializer_range
244
+ self.layer_norm_eps = layer_norm_eps
245
+ self.projection_dim = projection_dim
246
+ self.position_embedding_type = position_embedding_type
247
+
248
+
249
+ class FLMRConfig(PretrainedConfig):
250
+ r"""
251
+ [`FLMRConfig`] is the configuration class to store the configuration of a *FLMRModelForRetrieval*.
252
+ This is the configuration class to store the configuration of a [`FLMRModelForRetrieval`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
253
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
254
+ configuration to that of the FLMR
255
+ [LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G)
256
+ architecture.
257
+
258
+ Args:
259
+ vision_config (`FLMRVisionConfig`, *optional*):
260
+ Configuration for the vision encoder.
261
+ text_config (`FLMRTextConfig`, *optional*):
262
+ Configuration for the text encoder.
263
+ mask_punctuation (`bool`, *optional*, defaults to `True`):
264
+ Whether to mask punctuation tokens in the input.
265
+ mapping_network_prefix_length (`int`, *optional*, defaults to 32):
266
+ The output length of the linear mapping network.
267
+ dim (`int`, *optional*, defaults to 128):
268
+ The late-interaction dimension of the model. The output of the text encoder, vision encoder, transformer mapping network should all be projected to this dimension for late-interaction scoring.
269
+ use_vision_encoder (`bool`, *optional*, defaults to `True`):
270
+ Whether to load the vision encoder. When no vision encoder is loaded, `image_features` should be used in the forward pass rather than `pixel_values`.
271
+ initializer_range (`float`, *optional*, defaults to 0.02):
272
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
273
+ separate_query_and_context_text_encoder (`bool`, *optional*, defaults to `False`):
274
+ Whether to use separate text encoders for query and context.
275
+ separate_query_and_context_vision_encoder (`bool`, *optional*, defaults to `False`):
276
+ Whether to use separate vision encoders for query and context.
277
+ query_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `True`):
278
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the query.
279
+ query_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
280
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the query.
281
+ context_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `False`):
282
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the context.
283
+ context_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
284
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the context.
285
+ use_transformer_mapping_network (`bool`, *optional*, defaults to `False`):
286
+ Whether to add a transformer mapping network to map the features from the vision encoder to the embedding space. This option is used in PreFLMR.
287
+ transformer_mapping_config_base (`str`, *optional*):
288
+ The base configuration for the transformer mapping network. This option is used in PreFLMR. An example of this argument is `bert-base-uncased`.
289
+ transformer_mapping_num_hidden_layers (`int`, *optional*):
290
+ The number of hidden layers in the transformer mapping network. This option is used in PreFLMR.
291
+ load_cpu_extension (`bool`, *optional*, defaults to `False`):
292
+ Whether to load the CPU extension. Only set this to `True` if a CPU is used in training and inference. In any case, GPU is recommended for training and inference.
293
+ mask_instruction_token (`str`, *optional*):
294
+ The token that indicates the end of the input instruction. All tokens before this token (the first one in a sequence) will be masked. This option is used in PreFLMR.
295
+ transformer_mapping_cross_attention_length (`int`, *optional*, defaults to 32):
296
+ The length of the cross attention in the transformer mapping network. This option is used in PreFLMR.
297
+ vision_model_version (`str`, *optional*, defaults to `"openai/clip-vit-base-patch32"`):
298
+ The version of the vision model being used in this FLMR model.
299
+ This option is used in performing retrieval only. Though it does not affect the model architecture, it is highly recommended to set this argument so that it properly reflects the version of the vision model being used in the FLMR model. This arugment will be saved in the model configuration, and it can be read by the indexing engine. The indexing engine will use this argument to initialize an image processor, which can process the input image files. Find more details under `examples/research_projects/flmr-retrieval`.
300
+
301
+ Example:
302
+
303
+ ```python
304
+ >>> from transformers import FLMRConfig, FLMRModelForRetrieval
305
+
306
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
307
+ >>> configuration = FLMRConfig()
308
+
309
+ >>> # Initializing a model (with random weights) from the FLMR style configuration
310
+ >>> model = FLMRModelForRetrieval(configuration)
311
+
312
+ >>> # Accessing the model configuration
313
+ >>> configuration = model.config
314
+ ```"""
315
+
316
+ model_type = "flmr"
317
+
318
+ def __init__(
319
+ self,
320
+ vision_config: FLMRVisionConfig = None,
321
+ text_config: FLMRTextConfig = None,
322
+ mask_punctuation: bool = True,
323
+ mapping_network_prefix_length: int = 32,
324
+ dim: int = 128,
325
+ use_vision_encoder: bool = True,
326
+ initializer_range: float = 0.02,
327
+ separate_query_and_context_text_encoder: bool = False,
328
+ separate_query_and_context_vision_encoder: bool = False,
329
+ query_concat_output_from_vision_encoder: bool = True,
330
+ query_concat_output_from_text_encoder: bool = True,
331
+ context_concat_output_from_vision_encoder: bool = False,
332
+ context_concat_output_from_text_encoder: bool = True,
333
+ use_transformer_mapping_network: bool = False,
334
+ transformer_mapping_config_base: str = None,
335
+ transformer_mapping_num_hidden_layers: int = None,
336
+ load_cpu_extension: bool = False,
337
+ mask_instruction_token: str = None,
338
+ transformer_mapping_cross_attention_length: int = 32,
339
+ vision_model_version: str = "openai/clip-vit-base-patch32",
340
+ **kwargs,
341
+ ):
342
+ super().__init__(**kwargs)
343
+
344
+ if vision_config is None:
345
+ vision_config = {}
346
+ if text_config is None:
347
+ text_config = {}
348
+
349
+ if not isinstance(vision_config, FLMRVisionConfig):
350
+ vision_config = FLMRVisionConfig(**vision_config)
351
+ if not isinstance(text_config, FLMRTextConfig):
352
+ text_config = FLMRTextConfig(**text_config)
353
+
354
+ self.vision_config = vision_config
355
+ self.text_config = text_config
356
+ self.dim = dim
357
+ self.initializer_range = initializer_range
358
+ self.mask_punctuation = mask_punctuation
359
+ self.mapping_network_prefix_length = mapping_network_prefix_length
360
+ self.use_vision_encoder = use_vision_encoder
361
+ self.separate_query_and_context_text_encoder = separate_query_and_context_text_encoder
362
+ self.separate_query_and_context_vision_encoder = separate_query_and_context_vision_encoder
363
+ self.query_concat_output_from_vision_encoder = query_concat_output_from_vision_encoder
364
+ self.query_concat_output_from_text_encoder = query_concat_output_from_text_encoder
365
+ self.context_concat_output_from_vision_encoder = context_concat_output_from_vision_encoder
366
+ self.context_concat_output_from_text_encoder = context_concat_output_from_text_encoder
367
+ self.use_transformer_mapping_network = use_transformer_mapping_network
368
+ self.transformer_mapping_config_base = transformer_mapping_config_base
369
+ self.transformer_mapping_num_hidden_layers = transformer_mapping_num_hidden_layers
370
+ self.load_cpu_extension = load_cpu_extension
371
+ self.mask_instruction_token = mask_instruction_token
372
+ self.transformer_mapping_cross_attention_length = transformer_mapping_cross_attention_length
373
+ self.vision_model_version = vision_model_version
374
+
375
+ @classmethod
376
+ def from_text_vision_configs(cls, text_config: FLMRTextConfig, vision_config: FLMRVisionConfig, **kwargs):
377
+ r"""
378
+ Instantiate a [`FLMRConfig`] (or a derived class) from FLMR text model configuration and FLMR vision model
379
+ configuration.
380
+
381
+ Returns:
382
+ [`FLMRConfig`]: An instance of a configuration object
383
+ """
384
+
385
+ return cls(text_config=text_config, vision_config=vision_config, **kwargs)
context_tokenizer/added_tokens.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<BOC>": 30527,
3
+ "<BOK>": 30529,
4
+ "<BOQ>": 30525,
5
+ "<BOV>": 30522,
6
+ "<EOC>": 30528,
7
+ "<EOK>": 30530,
8
+ "<EOQ>": 30526,
9
+ "<EOV>": 30524,
10
+ "<SOV>": 30523
11
+ }
context_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<BOV>",
4
+ "<SOV>",
5
+ "<EOV>",
6
+ "<BOQ>",
7
+ "<EOQ>",
8
+ "<BOC>",
9
+ "<EOC>",
10
+ "<BOK>",
11
+ "<EOK>"
12
+ ],
13
+ "cls_token": {
14
+ "content": "[CLS]",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "mask_token": {
21
+ "content": "[MASK]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "pad_token": {
28
+ "content": "[PAD]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ },
34
+ "sep_token": {
35
+ "content": "[SEP]",
36
+ "lstrip": false,
37
+ "normalized": false,
38
+ "rstrip": false,
39
+ "single_word": false
40
+ },
41
+ "unk_token": {
42
+ "content": "[UNK]",
43
+ "lstrip": false,
44
+ "normalized": false,
45
+ "rstrip": false,
46
+ "single_word": false
47
+ }
48
+ }
context_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
context_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "30522": {
44
+ "content": "<BOV>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "30523": {
52
+ "content": "<SOV>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "30524": {
60
+ "content": "<EOV>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "30525": {
68
+ "content": "<BOQ>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "30526": {
76
+ "content": "<EOQ>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "30527": {
84
+ "content": "<BOC>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "30528": {
92
+ "content": "<EOC>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "30529": {
100
+ "content": "<BOK>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "30530": {
108
+ "content": "<EOK>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ }
115
+ },
116
+ "additional_special_tokens": [
117
+ "<BOV>",
118
+ "<SOV>",
119
+ "<EOV>",
120
+ "<BOQ>",
121
+ "<EOQ>",
122
+ "<BOC>",
123
+ "<EOC>",
124
+ "<BOK>",
125
+ "<EOK>"
126
+ ],
127
+ "auto_map": {
128
+ "AutoTokenizer": [
129
+ "tokenization_flmr.FLMRContextEncoderTokenizer",
130
+ null
131
+ ]
132
+ },
133
+ "clean_up_tokenization_spaces": true,
134
+ "cls_token": "[CLS]",
135
+ "do_basic_tokenize": true,
136
+ "do_lower_case": true,
137
+ "doc_maxlen": 512,
138
+ "mask_token": "[MASK]",
139
+ "model_max_length": 1000000000000000019884624838656,
140
+ "never_split": null,
141
+ "pad_token": "[PAD]",
142
+ "sep_token": "[SEP]",
143
+ "strip_accents": null,
144
+ "tokenize_chinese_chars": true,
145
+ "tokenizer_class": "FLMRContextEncoderTokenizer",
146
+ "unk_token": "[UNK]"
147
+ }
context_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
flmr_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
3
+ """
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+ def get_rank():
10
+ return dist.get_rank()
11
+
12
+
13
+ def get_world_size():
14
+ return dist.get_world_size()
15
+
16
+
17
+ def get_default_group():
18
+ return dist.group.WORLD
19
+
20
+
21
+ # TODO: The masking below might also be applicable in the kNN part
22
+ def colbert_score_reduce(scores_padded, D_mask):
23
+ # print('D_mask', D_mask.shape, D_mask)
24
+ D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
25
+ # print('D_padding', D_padding.shape, D_padding)
26
+ # print(D_padding[0].tolist())
27
+ scores_padded[D_padding] = -9999
28
+ scores = scores_padded.max(1).values
29
+
30
+ return scores.sum(-1)
31
+
32
+
33
+ def colbert_score(Q, D_padded, D_mask, use_gpu=False):
34
+ """
35
+ Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
36
+ If Q.size(0) is 1, the matrix will be compared with all passages.
37
+ Otherwise, each query matrix will be compared against the *aligned* passage.
38
+
39
+ EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
40
+ """
41
+ if use_gpu:
42
+ Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
43
+ assert Q.dim() == 3, Q.size()
44
+ assert D_padded.dim() == 3, D_padded.size()
45
+ assert Q.size(0) in [1, D_padded.size(0)]
46
+
47
+ scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
48
+
49
+ return colbert_score_reduce(scores, D_mask)
50
+
51
+
52
+ def _sort_by_length(ids, mask, bsize, *args):
53
+ if ids.size(0) <= bsize:
54
+ return ids, mask, torch.arange(ids.size(0))
55
+
56
+ indices = mask.sum(-1).sort().indices
57
+ reverse_indices = indices.sort().indices
58
+
59
+ return_array = [ids[indices], mask[indices]]
60
+ for arg in args:
61
+ if isinstance(arg, torch.Tensor):
62
+ return_array.append(arg[indices])
63
+ else:
64
+ # arg is a list, and we want to sort the list according to indices
65
+ return_array.append([arg[i] for i in indices])
66
+
67
+ return *return_array, reverse_indices
68
+
69
+
70
+ def _split_into_batches(ids, mask, bsize, *args):
71
+ batches = []
72
+ for offset in range(0, ids.size(0), bsize):
73
+ batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
74
+ for arg in args:
75
+ batch.append(arg[offset : offset + bsize])
76
+ batches.append(batch)
77
+ return batches
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1364a0f3a77b48c8f58dd73bbcde64096c5651787d6aeb62d43563f474be4b22
3
+ size 828103856
modeling_flmr.py ADDED
@@ -0,0 +1,1502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 FLMR Authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch FLMR model for Knowledge-intensive Visual Question Answering."""
16
+
17
+
18
+ import copy
19
+ import os
20
+ import pathlib
21
+ import string
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ from torch import Tensor, nn
28
+ from torch.utils.cpp_extension import load
29
+
30
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.models.bert.modeling_bert import BertModel
40
+ from transformers.models.clip import CLIPVisionModel
41
+ from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
42
+ from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
43
+ from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
44
+ from .flmr_utils import (
45
+ colbert_score,
46
+ colbert_score_reduce,
47
+ get_rank,
48
+ get_world_size,
49
+ )
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "FLMRConfig"
55
+ _CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L"
56
+
57
+
58
+ FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [
59
+ "LinWeizheDragon/PreFLMR_ViT-L",
60
+ "LinWeizheDragon/FLMR",
61
+ # See all FLMR models at https://huggingface.co/models?filter=flmr
62
+ ]
63
+
64
+
65
+ ##########
66
+ # Outputs
67
+ ##########
68
+
69
+
70
+ @dataclass
71
+ class FLMRContextEncoderOutput(ModelOutput):
72
+ """
73
+ Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`].
74
+
75
+ Args:
76
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
77
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation.
78
+ This output can be used to embed questions for nearest neighbors queries with query embeddings.
79
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
80
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
81
+ This output is to be used to embed contexts for late-interaction retrieval with query embeddings.
82
+ context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`):
83
+ The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation.
84
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
85
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
86
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
87
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
88
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
89
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
90
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
91
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
92
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
93
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
94
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
95
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
96
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
97
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
98
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
99
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
100
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
101
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
102
+ """
103
+
104
+ pooler_output: torch.FloatTensor
105
+ late_interaction_output: torch.FloatTensor = None
106
+ context_mask: torch.FloatTensor = None
107
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
108
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
109
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
110
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
111
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
112
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
113
+
114
+
115
+ @dataclass
116
+ class FLMRQueryEncoderOutput(ModelOutput):
117
+ """
118
+ Class for outputs of the `query()` function of [`FLMRModelForRetrieval.query()`].
119
+
120
+ Args:
121
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
122
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation.
123
+ This output can be used to embed questions for nearest neighbors queries with context embeddings.
124
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
125
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
126
+ This output is to be used to embed questions for late-interaction retrieval with context embeddings.
127
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
128
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
129
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
130
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
131
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
132
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
133
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
134
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
135
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
136
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
137
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
138
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
139
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
140
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
141
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
142
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
143
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
144
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
145
+ """
146
+
147
+ pooler_output: torch.FloatTensor
148
+ late_interaction_output: torch.FloatTensor = None
149
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
150
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
151
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
152
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
153
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
154
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
155
+
156
+
157
+ @dataclass
158
+ class FLMRModelForRetrievalOutput(ModelOutput):
159
+ """
160
+ Class for outputs of [`FLMRModelForRetrieval.query()`].
161
+
162
+ Args:
163
+ loss (`torch.FloatTensor`):
164
+ contrastive loss of the input queries and positive and negative examples. This output is to be used in model training.
165
+ scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`):
166
+ The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples.
167
+ in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
168
+ The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training.
169
+ query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
170
+ The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query.
171
+ context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
172
+ The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context.
173
+ query_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
174
+ Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
175
+ query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
176
+ Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
177
+ context_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
178
+ Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
179
+ context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
180
+ Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
181
+ """
182
+
183
+ loss: torch.FloatTensor
184
+ scores: torch.FloatTensor = None
185
+ in_batch_negative_loss: torch.FloatTensor = None
186
+ query_late_interaction_output: torch.FloatTensor = None
187
+ context_late_interaction_output: torch.FloatTensor = None
188
+ query_attentions: Optional[Tuple[Tuple[Tensor]]] = None
189
+ query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
190
+ context_attentions: Optional[Tuple[Tuple[Tensor]]] = None
191
+ context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
192
+
193
+
194
+ class FLMRPreTrainedModel(PreTrainedModel):
195
+ def _init_weights(self, module):
196
+ """Initialize the weights"""
197
+ if isinstance(module, nn.Linear):
198
+ # Slightly different from the TF version which uses truncated_normal for initialization
199
+ # cf https://github.com/pytorch/pytorch/pull/5617
200
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
201
+ if module.bias is not None:
202
+ module.bias.data.zero_()
203
+ elif isinstance(module, nn.Embedding):
204
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
205
+ if module.padding_idx is not None:
206
+ module.weight.data[module.padding_idx].zero_()
207
+ elif isinstance(module, nn.LayerNorm):
208
+ module.bias.data.zero_()
209
+ module.weight.data.fill_(1.0)
210
+
211
+
212
+ ##################
213
+ # PreTrainedModel
214
+ ##################
215
+
216
+
217
+ class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel):
218
+ """
219
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
220
+ models.
221
+ """
222
+
223
+ config_class = FLMRConfig
224
+ load_tf_weights = None
225
+ base_model_prefix = "flmr"
226
+
227
+
228
+ ###############
229
+ # Actual Models
230
+ ###############
231
+
232
+
233
+ FLMR_START_DOCSTRING = r"""
234
+
235
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
236
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
237
+ etc.)
238
+
239
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
240
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
241
+ and behavior.
242
+
243
+ Parameters:
244
+ config ([`FLMRConfig`]): Model configuration class with all the parameters of the model.
245
+ Initializing with a config file does not load the weights associated with the model, only the
246
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
247
+ query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query.
248
+ The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
249
+ context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context.
250
+ The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
251
+ """
252
+
253
+
254
+ FLMR_MODEL_INPUTS_DOCSTRING = r"""
255
+ Args:
256
+ query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
257
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
258
+ formatted with [CLS] and Q marker tokens as follows:
259
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
260
+
261
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
262
+ rather than the left.
263
+
264
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
265
+ [`PreTrainedTokenizer.__call__`] for details.
266
+
267
+ [What are input IDs?](../glossary#input-ids)
268
+ query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
269
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
270
+
271
+ - 1 for tokens that are **not masked**,
272
+ - 0 for tokens that are **masked**.
273
+
274
+ [What are attention masks?](../glossary#attention-mask)
275
+ query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
276
+ Pixel values. Pixel values can be obtained using
277
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
278
+ query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
279
+ Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
280
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
281
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
282
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
283
+ formatted with [CLS] and D marker tokens as follows:
284
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
285
+
286
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
287
+ rather than the left.
288
+
289
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
290
+ [`PreTrainedTokenizer.__call__`] for details.
291
+
292
+ [What are input IDs?](../glossary#input-ids)
293
+
294
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
295
+
296
+ context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
297
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
298
+
299
+ - 1 for tokens that are **not masked**,
300
+ - 0 for tokens that are **masked**.
301
+
302
+ [What are attention masks?](../glossary#attention-mask)
303
+
304
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
305
+ context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
306
+ Pixel values. Pixel values can be obtained using
307
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
308
+ context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
309
+ Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
310
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
311
+ use_in_batch_negatives (`bool`, *optional*):
312
+ Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training.
313
+ in_batch_negatives_from_all_gpus (`bool`, *optional*):
314
+ Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training.
315
+ num_negative_examples (`int`, *optional*):
316
+ The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`.
317
+ query_concat_output_from_vision_encoder (`bool`, *optional*):
318
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
319
+ query_concat_output_from_text_encoder (`bool`, *optional*):
320
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
321
+
322
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
323
+ context_concat_output_from_vision_encoder (`bool`, *optional*):
324
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
325
+
326
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
327
+ context_concat_output_from_text_encoder (`bool`, *optional*):
328
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
329
+ return_dict (`bool`, *optional*):
330
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
331
+ output_attentions (`bool`, *optional*):
332
+ Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned
333
+ tensors for more detail.
334
+ output_hidden_states (`bool`, *optional*):
335
+ Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail.
336
+ """
337
+
338
+
339
+ FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r"""
340
+ Args:
341
+ input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
342
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
343
+ formatted with [CLS] and Q marker tokens as follows:
344
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
345
+
346
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
347
+ rather than the left.
348
+
349
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
350
+ [`PreTrainedTokenizer.__call__`] for details.
351
+
352
+ [What are input IDs?](../glossary#input-ids)
353
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
354
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
355
+
356
+ - 1 for tokens that are **not masked**,
357
+ - 0 for tokens that are **masked**.
358
+
359
+ [What are attention masks?](../glossary#attention-mask)
360
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
361
+ Pixel values. Pixel values can be obtained using
362
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
363
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
364
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
365
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
366
+ concat_output_from_vision_encoder (`bool`, *optional*):
367
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
368
+ concat_output_from_text_encoder (`bool`, *optional*):
369
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
370
+
371
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
372
+ """
373
+
374
+
375
+ FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r"""
376
+ Args:
377
+ input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
378
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
379
+ formatted with [CLS] and D marker tokens as follows:
380
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
381
+
382
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
383
+ rather than the left.
384
+
385
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
386
+ [`PreTrainedTokenizer.__call__`] for details.
387
+
388
+ [What are input IDs?](../glossary#input-ids)
389
+
390
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
391
+ attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
392
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
393
+
394
+ - 1 for tokens that are **not masked**,
395
+ - 0 for tokens that are **masked**.
396
+
397
+ [What are attention masks?](../glossary#attention-mask)
398
+
399
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
400
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
401
+ Pixel values. Pixel values can be obtained using
402
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
403
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
404
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
405
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel
406
+ .__call__`] for details.
407
+ concat_output_from_vision_encoder (`bool`, *optional*):
408
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
409
+
410
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
411
+ concat_output_from_text_encoder (`bool`, *optional*):
412
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
413
+ keep_dims (`bool`, *optional*):
414
+ Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training.
415
+ return_mask (`bool`, *optional*):
416
+ Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training.
417
+ """
418
+
419
+
420
+ FLMR_TEXT_ENCODERS_START_DOCSTRING = r"""
421
+
422
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
423
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
424
+ etc.)
425
+
426
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
427
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
428
+ and behavior.
429
+
430
+ Parameters:
431
+ config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model.
432
+ Initializing with a config file does not load the weights associated with the model, only the
433
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
434
+ """
435
+
436
+
437
+ # Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR
438
+ FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r"""
439
+ Args:
440
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
441
+ Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be
442
+ formatted with [CLS] and [SEP] tokens as follows:
443
+
444
+ (a) For sequence pairs (for a pair title+text for example):
445
+
446
+ ```
447
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
448
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
449
+ ```
450
+
451
+ (b) For single sequences (for a question for example):
452
+
453
+ ```
454
+ tokens: [CLS] the dog is hairy . [SEP]
455
+ token_type_ids: 0 0 0 0 0 0 0
456
+ ```
457
+
458
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
459
+ rather than the left.
460
+
461
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
462
+ [`PreTrainedTokenizer.__call__`] for details.
463
+
464
+ [What are input IDs?](../glossary#input-ids)
465
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
466
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
467
+
468
+ - 1 for tokens that are **not masked**,
469
+ - 0 for tokens that are **masked**.
470
+
471
+ [What are attention masks?](../glossary#attention-mask)
472
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
473
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
474
+ 1]`:
475
+
476
+ - 0 corresponds to a *sentence A* token,
477
+ - 1 corresponds to a *sentence B* token.
478
+
479
+ [What are token type IDs?](../glossary#token-type-ids)
480
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
481
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
482
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
483
+ model's internal embedding lookup matrix.
484
+ output_attentions (`bool`, *optional*):
485
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
486
+ tensors for more detail.
487
+ output_hidden_states (`bool`, *optional*):
488
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
489
+ more detail.
490
+ return_dict (`bool`, *optional*):
491
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
492
+ """
493
+
494
+ FLMR_VISION_ENCODERS_START_DOCSTRING = r"""
495
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
496
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
497
+ etc.)
498
+
499
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
500
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
501
+ and behavior.
502
+
503
+ Parameters:
504
+ config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model.
505
+ Initializing with a config file does not load the weights associated with the model, only the
506
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
507
+ """
508
+
509
+ # Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR
510
+ FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r"""
511
+ Args:
512
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
513
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
514
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
515
+ output_attentions (`bool`, *optional*):
516
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
517
+ tensors for more detail.
518
+ output_hidden_states (`bool`, *optional*):
519
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
520
+ more detail.
521
+ return_dict (`bool`, *optional*):
522
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
523
+ """
524
+
525
+
526
+ class FLMRMultiLayerPerceptron(nn.Module):
527
+ """
528
+ A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model.
529
+ """
530
+
531
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
532
+ return self.model(x)
533
+
534
+ def __init__(self, sizes, bias=True, act=nn.Tanh):
535
+ super(FLMRMultiLayerPerceptron, self).__init__()
536
+ layers = []
537
+ for i in range(len(sizes) - 1):
538
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
539
+ if i < len(sizes) - 2:
540
+ layers.append(act())
541
+ self.model = nn.Sequential(*layers)
542
+
543
+
544
+ @add_start_docstrings(
545
+ "The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ",
546
+ FLMR_START_DOCSTRING,
547
+ )
548
+ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
549
+ _keys_to_ignore_on_load_unexpected = [r"cls"]
550
+ main_input_name = "query_input_ids"
551
+ _tied_weights_keys = [] # Added dynamically at initialization depending on the architecture
552
+
553
+ def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None):
554
+ super().__init__(config)
555
+ self.config = config
556
+ self.vision_model_version = config.vision_model_version
557
+
558
+ self.context_text_encoder = FLMRTextModel(config.text_config)
559
+ self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)
560
+
561
+ self.query_tokenizer = query_tokenizer
562
+ self.context_tokenizer = context_tokenizer
563
+
564
+ if self.query_tokenizer is None:
565
+ logger.warning(
566
+ "query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
567
+ )
568
+ from transformers import FLMRQueryEncoderTokenizer
569
+
570
+ # initialize a FLMRQueryEncoderTokenizer
571
+ self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")
572
+
573
+ if self.context_tokenizer is None:
574
+ logger.warning(
575
+ "context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
576
+ )
577
+ from transformers import FLMRContextEncoderTokenizer
578
+
579
+ # initialize a FLMRContextEncoderTokenizer
580
+ self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")
581
+
582
+ self.mapping_network_prefix_length = self.config.mapping_network_prefix_length
583
+ self.vision_encoder_embedding_size = self.config.vision_config.hidden_size
584
+ self.text_encoder_embedding_size = self.config.text_config.hidden_size
585
+ self.late_interaction_embedding_size = self.config.dim
586
+
587
+ self.context_vision_projection = FLMRMultiLayerPerceptron(
588
+ (
589
+ self.vision_encoder_embedding_size,
590
+ (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
591
+ self.late_interaction_embedding_size * self.mapping_network_prefix_length,
592
+ )
593
+ )
594
+
595
+ if self.config.use_vision_encoder:
596
+ self.context_vision_encoder = FLMRVisionModel(config.vision_config)
597
+
598
+ if self.config.use_transformer_mapping_network:
599
+ # This is a PreFLMR style model
600
+ transformer_mapping_config_base = self.config.transformer_mapping_config_base
601
+ try:
602
+ from transformers import BertConfig
603
+ from transformers.models.bert.modeling_bert import BertEncoder
604
+ except Exception as e:
605
+ raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}")
606
+
607
+ transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
608
+
609
+ assert (
610
+ self.config.text_config.hidden_size == transformer_mapping_config.hidden_size
611
+ ), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match."
612
+ # shallow transformer
613
+ transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers
614
+ # add cross attention
615
+ transformer_mapping_config.is_decoder = True
616
+ transformer_mapping_config.add_cross_attention = True
617
+
618
+ # The linear layer from vision encoder to transformer input
619
+ self.transformer_mapping_input_linear = nn.Linear(
620
+ self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size
621
+ )
622
+
623
+ # The transformer encoder
624
+ self.transformer_mapping_network = BertEncoder(transformer_mapping_config)
625
+
626
+ # The linear layer from transformer output to FLMR dim
627
+ self.transformer_mapping_output_linear = nn.Linear(
628
+ transformer_mapping_config.hidden_size, self.late_interaction_embedding_size
629
+ )
630
+
631
+ if self.config.separate_query_and_context_text_encoder:
632
+ self.query_text_encoder = copy.deepcopy(self.context_text_encoder)
633
+ self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear)
634
+ else:
635
+ self.query_text_encoder = self.context_text_encoder
636
+ self.query_text_encoder_linear = self.context_text_encoder_linear
637
+ self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
638
+
639
+ if self.config.separate_query_and_context_vision_encoder:
640
+ self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
641
+ self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
642
+ else:
643
+ self.query_vision_encoder = self.context_vision_encoder
644
+ self.query_vision_projection = self.context_vision_projection
645
+ self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
646
+
647
+ if self.config.load_cpu_extension:
648
+ try:
649
+ FLMRModelForRetrieval.try_load_torch_extensions()
650
+ except Exception as e:
651
+ raise(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}")
652
+
653
+ if self.config.mask_punctuation:
654
+ self.skiplist = {
655
+ w: True
656
+ for symbol in string.punctuation
657
+ for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
658
+ }
659
+
660
+ if self.config.mask_instruction_token is not None:
661
+ self.mask_instruction = True
662
+ # obtain the token id of the instruction token
663
+ self.instruction_token_id = self.query_tokenizer.encode(
664
+ self.config.mask_instruction_token, add_special_tokens=False
665
+ )[0]
666
+ else:
667
+ self.mask_instruction = False
668
+
669
+ self.loss_fn = torch.nn.CrossEntropyLoss()
670
+
671
+ # Initialize weights and apply final processing
672
+ self.post_init()
673
+
674
+ @property
675
+ def use_gpu(self):
676
+ return self.device.type == "cuda"
677
+
678
+ @classmethod
679
+ def from_pretrained(self, name_or_path, **kwargs):
680
+ obj = super().from_pretrained(name_or_path, **kwargs)
681
+ return obj
682
+
683
+ @classmethod
684
+ def try_load_torch_extensions(cls):
685
+ if hasattr(cls, "loaded_extensions"):
686
+ return
687
+
688
+ logger.info(
689
+ "Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
690
+ )
691
+ segmented_maxsim_cpp = load(
692
+ name="segmented_maxsim_cpp",
693
+ sources=[
694
+ os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"),
695
+ ],
696
+ extra_cflags=["-O3"],
697
+ verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
698
+ )
699
+ cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp
700
+
701
+ cls.loaded_extensions = True
702
+
703
+ def query_mask(self, input_ids, skiplist):
704
+ if not self.mask_instruction:
705
+ return self.mask(input_ids, skiplist)
706
+
707
+ # find the position of end of instruction in input_ids
708
+ # mask the tokens before the position
709
+ sep_id = self.instruction_token_id
710
+ sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
711
+ # if any of the positions is lower than 1, set to 1
712
+ for i, x in enumerate(sep_positions):
713
+ if x < 1:
714
+ sep_positions[i] = 1
715
+ logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}")
716
+ mask = [
717
+ [
718
+ (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
719
+ for index, x in enumerate(d)
720
+ ]
721
+ for seq_index, d in enumerate(input_ids.cpu().tolist())
722
+ ]
723
+ return mask
724
+
725
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING)
726
+ @replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
727
+ def forward(
728
+ self,
729
+ query_input_ids: Optional[torch.Tensor] = None,
730
+ query_attention_mask: Optional[torch.Tensor] = None,
731
+ query_pixel_values: Optional[torch.Tensor] = None,
732
+ query_image_features: Optional[torch.Tensor] = None,
733
+ context_input_ids: Optional[torch.Tensor] = None,
734
+ context_attention_mask: Optional[torch.Tensor] = None,
735
+ context_pixel_values: Optional[torch.Tensor] = None,
736
+ context_image_features: Optional[torch.Tensor] = None,
737
+ use_in_batch_negatives: bool = True,
738
+ in_batch_negatives_from_all_gpus: bool = False,
739
+ num_negative_examples: int = 1,
740
+ query_concat_output_from_vision_encoder: Optional[bool] = None,
741
+ query_concat_output_from_text_encoder: Optional[bool] = None,
742
+ context_concat_output_from_vision_encoder: Optional[bool] = None,
743
+ context_concat_output_from_text_encoder: Optional[bool] = None,
744
+ return_dict: bool = None,
745
+ output_attentions: bool = None,
746
+ output_hidden_states: bool = None,
747
+ ) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]:
748
+ r"""
749
+ Return:
750
+
751
+ Examples:
752
+
753
+ ```python
754
+ >>> import torch
755
+ >>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor
756
+
757
+ >>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L"
758
+ >>> image_processor_name = "openai/clip-vit-large-patch14"
759
+ >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer")
760
+ >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer")
761
+
762
+ >>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path,
763
+ query_tokenizer=query_tokenizer,
764
+ context_tokenizer=context_tokenizer,
765
+ )
766
+ >>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
767
+
768
+ >>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"])
769
+ >>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.",
770
+ "Paris is the capital of France.", "Beijing is the capital of China."])
771
+ >>> Q_pixel_values = torch.zeros(2, 3, 224, 224)
772
+ >>> inputs = dict(
773
+ query_input_ids=Q_encoding['input_ids'],
774
+ query_attention_mask=Q_encoding['attention_mask'],
775
+ query_pixel_values=Q_pixel_values,
776
+ context_input_ids=D_encoding['input_ids'],
777
+ context_attention_mask=D_encoding['attention_mask'],
778
+ use_in_batch_negatives=True,
779
+ )
780
+
781
+ >>> model.forward(**inputs)
782
+ FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16,
783
+ grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562],
784
+ [39.4375, 48.4062]], device='cuda:0', dtype=torch.float16,
785
+ grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...)
786
+ ```
787
+ """
788
+
789
+ if query_concat_output_from_vision_encoder is None:
790
+ query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
791
+
792
+ if query_concat_output_from_text_encoder is None:
793
+ query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
794
+
795
+ if context_concat_output_from_vision_encoder is None:
796
+ context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
797
+
798
+ if context_concat_output_from_text_encoder is None:
799
+ context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
800
+
801
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
802
+ output_hidden_states = (
803
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
804
+ )
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ query_outputs = self.query(
808
+ input_ids=query_input_ids,
809
+ attention_mask=query_attention_mask,
810
+ pixel_values=query_pixel_values,
811
+ image_features=query_image_features,
812
+ concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
813
+ concat_output_from_text_encoder=query_concat_output_from_text_encoder,
814
+ output_attentions=output_attentions,
815
+ output_hidden_states=output_hidden_states,
816
+ )
817
+ Q = query_outputs.late_interaction_output
818
+
819
+ context_outputs = self.doc(
820
+ input_ids=context_input_ids,
821
+ attention_mask=context_attention_mask,
822
+ pixel_values=context_pixel_values,
823
+ image_features=context_image_features,
824
+ concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
825
+ concat_output_from_text_encoder=context_concat_output_from_text_encoder,
826
+ keep_dims=True,
827
+ return_mask=True,
828
+ output_attentions=output_attentions,
829
+ output_hidden_states=output_hidden_states,
830
+ )
831
+ D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask
832
+
833
+ # Gather tensors from other GPUs
834
+ if in_batch_negatives_from_all_gpus:
835
+ Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask)
836
+ # Repeat each query encoding for every corresponding document.
837
+ Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous()
838
+
839
+ scores = self.score(Q_duplicated, D, D_mask)
840
+
841
+ # Use contrastive learning
842
+ batch_size = query_input_ids.shape[0]
843
+ scores = scores.view(-1, num_negative_examples + 1)
844
+ labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
845
+ loss = self.loss_fn(scores, labels)
846
+
847
+ if use_in_batch_negatives:
848
+ ib_loss = self.compute_ib_loss_new(Q, D, D_mask)
849
+ else:
850
+ ib_loss = None
851
+
852
+ if output_attentions:
853
+ query_attentions = (
854
+ query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None,
855
+ query_outputs.vision_encoder_attentions
856
+ if query_outputs.vision_encoder_attentions is not None
857
+ else None,
858
+ query_outputs.transformer_mapping_network_attentions
859
+ if query_outputs.transformer_mapping_network_attentions is not None
860
+ else None,
861
+ )
862
+ context_attentions = (
863
+ context_outputs.text_encoder_attentions
864
+ if context_outputs.text_encoder_attentions is not None
865
+ else None,
866
+ context_outputs.vision_encoder_attentions
867
+ if context_outputs.vision_encoder_attentions is not None
868
+ else None,
869
+ context_outputs.transformer_mapping_network_attentions
870
+ if context_outputs.transformer_mapping_network_attentions is not None
871
+ else None,
872
+ )
873
+ else:
874
+ query_attentions = None
875
+ context_attentions = None
876
+
877
+ if output_hidden_states:
878
+ query_hidden_states = (
879
+ query_outputs.text_encoder_hidden_states
880
+ if query_outputs.text_encoder_hidden_states is not None
881
+ else None,
882
+ query_outputs.vision_encoder_hidden_states
883
+ if query_outputs.vision_encoder_hidden_states is not None
884
+ else None,
885
+ query_outputs.transformer_mapping_network_hidden_states
886
+ if query_outputs.transformer_mapping_network_hidden_states is not None
887
+ else None,
888
+ )
889
+ context_hidden_states = (
890
+ context_outputs.text_encoder_hidden_states
891
+ if context_outputs.text_encoder_hidden_states is not None
892
+ else None,
893
+ context_outputs.vision_encoder_hidden_states
894
+ if context_outputs.vision_encoder_hidden_states is not None
895
+ else None,
896
+ context_outputs.transformer_mapping_network_hidden_states
897
+ if context_outputs.transformer_mapping_network_hidden_states is not None
898
+ else None,
899
+ )
900
+ else:
901
+ query_hidden_states = None
902
+ context_hidden_states = None
903
+
904
+ if not return_dict:
905
+ if output_attentions and output_hidden_states:
906
+ return (
907
+ loss,
908
+ scores,
909
+ ib_loss,
910
+ query_outputs.late_interaction_output,
911
+ context_outputs.late_interaction_output,
912
+ query_attentions,
913
+ query_hidden_states,
914
+ context_attentions,
915
+ context_hidden_states,
916
+ )
917
+ elif output_attentions:
918
+ return (
919
+ loss,
920
+ scores,
921
+ ib_loss,
922
+ query_outputs.late_interaction_output,
923
+ context_outputs.late_interaction_output,
924
+ query_attentions,
925
+ context_attentions,
926
+ )
927
+ elif output_hidden_states:
928
+ return (
929
+ loss,
930
+ scores,
931
+ ib_loss,
932
+ query_outputs.late_interaction_output,
933
+ context_outputs.late_interaction_output,
934
+ query_hidden_states,
935
+ context_hidden_states,
936
+ )
937
+ else:
938
+ return (
939
+ loss,
940
+ scores,
941
+ ib_loss,
942
+ query_outputs.late_interaction_output,
943
+ context_outputs.late_interaction_output,
944
+ )
945
+
946
+ return FLMRModelForRetrievalOutput(
947
+ loss=loss,
948
+ scores=scores,
949
+ in_batch_negative_loss=ib_loss,
950
+ query_late_interaction_output=query_outputs.late_interaction_output,
951
+ context_late_interaction_output=context_outputs.late_interaction_output,
952
+ query_attentions=query_attentions if output_attentions else None,
953
+ query_hidden_states=query_hidden_states if output_hidden_states else None,
954
+ context_attentions=context_attentions if output_attentions else None,
955
+ context_hidden_states=context_hidden_states if output_hidden_states else None,
956
+ )
957
+
958
+ def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
959
+ # Q: batch_size x q_len x dim
960
+ # D: batch_size*n_docs x i_len x dim
961
+ # D_mask: batch_size*n_docs x i_len x dim
962
+ # 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x q_len x dim
963
+ # = batch_size x batch_size*n_docs x i_len x q_len
964
+
965
+ scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten(
966
+ 0, 1
967
+ ) # query-major unsqueeze
968
+ scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1))
969
+
970
+ in_batch_scores = scores.reshape(Q.size(0), -1)
971
+
972
+ batch_size = Q.shape[0]
973
+ batch_size_with_pos_and_neg = D.shape[0]
974
+ num_pos_and_neg = batch_size_with_pos_and_neg // batch_size
975
+
976
+ # batch_size x dim matmul dim x (num_pos+num_neg)*batch_size
977
+ # --> batch_size x (num_pos+num_neg)*batch_size
978
+ in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device)
979
+ step = num_pos_and_neg
980
+ for i in range(batch_size):
981
+ in_batch_labels[i, step * i] = 1
982
+ # print('in_batch_labels', in_batch_labels)
983
+ in_batch_labels = torch.argmax(in_batch_labels, dim=1)
984
+ # print('in_batch_labels', in_batch_labels)
985
+
986
+ loss = self.loss_fn(in_batch_scores, in_batch_labels)
987
+
988
+ return loss
989
+
990
+ def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask):
991
+ # print("get rank", get_rank())
992
+ # print("get world size", get_world_size())
993
+ # Gather embeddings from other GPUs
994
+ n_nodes = get_world_size()
995
+ if n_nodes == 1:
996
+ return query_embeddings, item_embeddings, item_mask
997
+ # Create placeholder to hold embeddings passed from other ranks
998
+ global_query_embeddings_placeholder = [
999
+ torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device)
1000
+ for _ in range(n_nodes)
1001
+ ]
1002
+ global_item_embeddings_placeholder = [
1003
+ torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device)
1004
+ for _ in range(n_nodes)
1005
+ ]
1006
+ global_item_mask_placeholder = [
1007
+ torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes)
1008
+ ]
1009
+ dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach())
1010
+ dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach())
1011
+ dist.all_gather(global_item_mask_placeholder, item_mask.detach())
1012
+
1013
+ global_query_embeddings = []
1014
+ global_item_embeddings = []
1015
+ global_item_mask = []
1016
+ # print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings)
1017
+ # print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings)
1018
+ # input()
1019
+ current_rank = get_rank()
1020
+ for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder):
1021
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1022
+ if rank_index != current_rank:
1023
+ global_query_embeddings.append(remote_q_embeddings)
1024
+ else:
1025
+ global_query_embeddings.append(query_embeddings)
1026
+
1027
+ for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder):
1028
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1029
+ if rank_index != current_rank:
1030
+ global_item_embeddings.append(remote_item_embeddings)
1031
+ else:
1032
+ global_item_embeddings.append(item_embeddings)
1033
+
1034
+ for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder):
1035
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1036
+ if rank_index != current_rank:
1037
+ global_item_mask.append(remote_item_mask)
1038
+ else:
1039
+ global_item_mask.append(item_mask)
1040
+
1041
+ # Replace the previous variables with gathered tensors
1042
+ query_embeddings = torch.cat(global_query_embeddings)
1043
+ item_embeddings = torch.cat(global_item_embeddings)
1044
+ item_mask = torch.cat(global_item_mask)
1045
+
1046
+ return query_embeddings, item_embeddings, item_mask
1047
+
1048
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING)
1049
+ @replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC)
1050
+ def query(
1051
+ self,
1052
+ input_ids: torch.Tensor,
1053
+ attention_mask: torch.Tensor,
1054
+ pixel_values: Optional[torch.Tensor] = None,
1055
+ image_features: Optional[torch.Tensor] = None,
1056
+ concat_output_from_vision_encoder: Optional[bool] = None,
1057
+ concat_output_from_text_encoder: Optional[bool] = None,
1058
+ output_attentions: Optional[bool] = None,
1059
+ output_hidden_states: Optional[bool] = None,
1060
+ ):
1061
+ r"""
1062
+ Returns:
1063
+
1064
+ """
1065
+
1066
+ if concat_output_from_vision_encoder is None:
1067
+ concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
1068
+
1069
+ if concat_output_from_text_encoder is None:
1070
+ concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
1071
+
1072
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1073
+ output_hidden_states = (
1074
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1075
+ )
1076
+
1077
+ input_modality = []
1078
+ if pixel_values is not None or image_features is not None:
1079
+ input_modality.append("image")
1080
+ if input_ids is not None and attention_mask is not None:
1081
+ input_modality.append("text")
1082
+
1083
+ text_encoder_outputs = None
1084
+ vision_encoder_outputs = None
1085
+ transformer_mapping_outputs = None
1086
+
1087
+ if "image" in input_modality:
1088
+ assert (
1089
+ pixel_values is not None or image_features is not None
1090
+ ), "pixel_values or image_features must be provided if image modality is used"
1091
+ assert (
1092
+ pixel_values is None or image_features is None
1093
+ ), "pixel_values and image_features cannot be provided at the same time"
1094
+
1095
+ if "text" in input_modality:
1096
+ assert (
1097
+ input_ids is not None and attention_mask is not None
1098
+ ), "input_ids and attention_mask must be provided if text modality is used"
1099
+ # Forward the text encoder
1100
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1101
+ text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask)
1102
+ text_encoder_hidden_states = text_encoder_outputs[0]
1103
+ text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states)
1104
+ mask = torch.tensor(self.query_mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
1105
+
1106
+ text_embeddings = text_embeddings * mask
1107
+
1108
+ if "image" in input_modality:
1109
+ if pixel_values is not None:
1110
+ batch_size = pixel_values.shape[0]
1111
+ # Forward the vision encoder
1112
+ pixel_values = pixel_values.to(self.device)
1113
+ if len(pixel_values.shape) == 5:
1114
+ # Multiple ROIs are provided
1115
+ # merge the first two dimensions
1116
+ pixel_values = pixel_values.reshape(
1117
+ -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
1118
+ )
1119
+ vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True)
1120
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1121
+
1122
+ if image_features is not None:
1123
+ batch_size = image_features.shape[0]
1124
+ vision_embeddings = image_features.to(self.device)
1125
+
1126
+ # Forward the vision projection / mapping network
1127
+ vision_embeddings = self.query_vision_projection(vision_embeddings)
1128
+ vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size)
1129
+
1130
+ if self.config.use_transformer_mapping_network:
1131
+ # select the second last layer
1132
+ vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
1133
+ # transformer_mapping
1134
+ transformer_mapping_input_features = self.transformer_mapping_input_linear(
1135
+ vision_second_last_layer_hidden_states
1136
+ )
1137
+
1138
+ # Cross attention only attends to the first 32 tokens
1139
+ encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
1140
+ cross_attention_length = self.config.transformer_mapping_cross_attention_length
1141
+ if text_encoder_hidden_states.shape[1] > cross_attention_length:
1142
+ text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
1143
+ encoder_mask = encoder_mask[:, :cross_attention_length]
1144
+
1145
+ # Obtain cross attention mask
1146
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1))
1147
+ # Pass through the transformer mapping
1148
+ transformer_mapping_outputs = self.transformer_mapping_network(
1149
+ transformer_mapping_input_features,
1150
+ encoder_hidden_states=text_encoder_hidden_states,
1151
+ encoder_attention_mask=encoder_extended_attention_mask,
1152
+ )
1153
+ transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
1154
+ # Convert the dimension to FLMR dim
1155
+ transformer_mapping_output_features = self.transformer_mapping_output_linear(
1156
+ transformer_mapping_output_features
1157
+ )
1158
+ # Merge with the vision embeddings
1159
+ vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)
1160
+
1161
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1162
+ Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
1163
+ elif concat_output_from_vision_encoder:
1164
+ Q = vision_embeddings
1165
+ elif concat_output_from_text_encoder:
1166
+ Q = text_embeddings
1167
+
1168
+ vision_encoder_attentions = (
1169
+ vision_encoder_outputs.attentions
1170
+ if vision_encoder_outputs is not None
1171
+ and hasattr(vision_encoder_outputs, "attentions")
1172
+ and output_attentions
1173
+ else None
1174
+ )
1175
+ vision_encoder_hidden_states = (
1176
+ vision_encoder_outputs.hidden_states
1177
+ if vision_encoder_outputs is not None
1178
+ and hasattr(vision_encoder_outputs, "hidden_states")
1179
+ and output_hidden_states
1180
+ else None
1181
+ )
1182
+ text_encoder_attentions = (
1183
+ text_encoder_outputs.attentions
1184
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1185
+ else None
1186
+ )
1187
+ text_encoder_hidden_states = (
1188
+ text_encoder_outputs.hidden_states
1189
+ if text_encoder_outputs is not None
1190
+ and hasattr(text_encoder_outputs, "hidden_states")
1191
+ and output_hidden_states
1192
+ else None
1193
+ )
1194
+ transformer_mapping_network_attentions = (
1195
+ transformer_mapping_outputs.attentions
1196
+ if transformer_mapping_outputs is not None
1197
+ and hasattr(transformer_mapping_outputs, "attentions")
1198
+ and output_attentions
1199
+ else None
1200
+ )
1201
+ transformer_mapping_network_hidden_states = (
1202
+ transformer_mapping_outputs.hidden_states
1203
+ if transformer_mapping_outputs is not None
1204
+ and hasattr(transformer_mapping_outputs, "hidden_states")
1205
+ and output_hidden_states
1206
+ else None
1207
+ )
1208
+
1209
+ return FLMRQueryEncoderOutput(
1210
+ pooler_output=Q[:, 0, :],
1211
+ late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
1212
+ vision_encoder_attentions=vision_encoder_attentions,
1213
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1214
+ text_encoder_attentions=text_encoder_attentions,
1215
+ text_encoder_hidden_states=text_encoder_hidden_states,
1216
+ transformer_mapping_network_attentions=transformer_mapping_network_attentions,
1217
+ transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
1218
+ )
1219
+
1220
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING)
1221
+ @replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
1222
+ def doc(
1223
+ self,
1224
+ input_ids: torch.Tensor,
1225
+ attention_mask: torch.Tensor,
1226
+ pixel_values: Optional[torch.Tensor] = None,
1227
+ image_features: Optional[torch.Tensor] = None,
1228
+ concat_output_from_vision_encoder: Optional[bool] = None,
1229
+ concat_output_from_text_encoder: Optional[bool] = None,
1230
+ keep_dims: Optional[bool] = True,
1231
+ return_mask: Optional[bool] = True,
1232
+ output_attentions: Optional[bool] = None,
1233
+ output_hidden_states: Optional[bool] = None,
1234
+ ):
1235
+ r"""
1236
+ Returns:
1237
+
1238
+ """
1239
+ assert keep_dims in [True, False]
1240
+
1241
+ if concat_output_from_vision_encoder is None:
1242
+ concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
1243
+
1244
+ if concat_output_from_text_encoder is None:
1245
+ concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
1246
+
1247
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1248
+ output_hidden_states = (
1249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1250
+ )
1251
+
1252
+ input_modality = []
1253
+ if pixel_values is not None or image_features is not None:
1254
+ input_modality.append("image")
1255
+ if input_ids is not None and attention_mask is not None:
1256
+ input_modality.append("text")
1257
+
1258
+ text_encoder_outputs = None
1259
+ vision_encoder_outputs = None
1260
+
1261
+ if "image" in input_modality:
1262
+ assert (
1263
+ pixel_values is not None or image_features is not None
1264
+ ), "pixel_values or image_features must be provided if image modality is used"
1265
+ assert (
1266
+ pixel_values is None or image_features is None
1267
+ ), "pixel_values and image_features cannot be provided at the same time"
1268
+
1269
+ if "text" in input_modality:
1270
+ assert (
1271
+ input_ids is not None and attention_mask is not None
1272
+ ), "input_ids and attention_mask must be provided if text modality is used"
1273
+ # Forward the text encoder
1274
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1275
+ text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask)
1276
+ text_embeddings = text_encoder_outputs[0]
1277
+ text_embeddings = self.context_text_encoder_linear(text_embeddings)
1278
+
1279
+ mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
1280
+ text_embeddings = text_embeddings * mask
1281
+
1282
+ if "image" in input_modality:
1283
+ if pixel_values is not None:
1284
+ # Forward the vision encoder
1285
+ pixel_values = pixel_values.to(self.device)
1286
+ vision_encoder_outputs = self.context_vision_encoder(pixel_values)
1287
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1288
+
1289
+ if image_features is not None:
1290
+ vision_embeddings = image_features.to(self.device)
1291
+
1292
+ batch_size = vision_embeddings.shape[0]
1293
+
1294
+ # Forward the vision projection / mapping network
1295
+ vision_embeddings = self.context_vision_projection(vision_embeddings)
1296
+ vision_embeddings = vision_embeddings.view(
1297
+ -1, self.mapping_network_prefix_length, self.late_interaction_embedding_size
1298
+ )
1299
+
1300
+ image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device)
1301
+
1302
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1303
+ # Note: vision embeddings must be in the front since the ColBERT engine only indexes embeddings up to number of 1's in the mask
1304
+ # TODO: fix the engine to support masks with discontinuous 0 and 1.
1305
+ D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1306
+ # concatenate the mask
1307
+ mask = torch.cat([mask, image_mask], dim=1)
1308
+ elif concat_output_from_vision_encoder:
1309
+ D = vision_embeddings
1310
+ mask = image_mask
1311
+ elif concat_output_from_text_encoder:
1312
+ D = text_embeddings
1313
+ mask = mask
1314
+
1315
+ D = torch.nn.functional.normalize(D, p=2, dim=2)
1316
+
1317
+ if self.use_gpu:
1318
+ D = D.half()
1319
+
1320
+ if keep_dims is False:
1321
+ D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
1322
+ D = [d[mask[idx]] for idx, d in enumerate(D)]
1323
+
1324
+ vision_encoder_attentions = (
1325
+ vision_encoder_outputs.attentions
1326
+ if vision_encoder_outputs is not None
1327
+ and hasattr(vision_encoder_outputs, "attentions")
1328
+ and output_attentions
1329
+ else None
1330
+ )
1331
+ vision_encoder_hidden_states = (
1332
+ vision_encoder_outputs.hidden_states
1333
+ if vision_encoder_outputs is not None
1334
+ and hasattr(vision_encoder_outputs, "hidden_states")
1335
+ and output_hidden_states
1336
+ else None
1337
+ )
1338
+ text_encoder_attentions = (
1339
+ text_encoder_outputs.attentions
1340
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1341
+ else None
1342
+ )
1343
+ text_encoder_hidden_states = (
1344
+ text_encoder_outputs.hidden_states
1345
+ if text_encoder_outputs is not None
1346
+ and hasattr(text_encoder_outputs, "hidden_states")
1347
+ and output_hidden_states
1348
+ else None
1349
+ )
1350
+
1351
+ return FLMRContextEncoderOutput(
1352
+ pooler_output=D[:, 0, :],
1353
+ late_interaction_output=D,
1354
+ context_mask=mask.bool() if return_mask else None,
1355
+ vision_encoder_attentions=vision_encoder_attentions,
1356
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1357
+ text_encoder_attentions=text_encoder_attentions,
1358
+ text_encoder_hidden_states=text_encoder_hidden_states,
1359
+ )
1360
+
1361
+ def score(self, Q, D_padded, D_mask):
1362
+ # assert self.colbert_config.similarity == 'cosine'
1363
+ # if self.colbert_config.similarity == 'l2':
1364
+ # assert self.colbert_config.interaction == 'colbert'
1365
+ # return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
1366
+ return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu)
1367
+
1368
+ def mask(self, input_ids, skiplist):
1369
+ mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
1370
+ return mask
1371
+
1372
+
1373
+ @add_start_docstrings(
1374
+ "The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.",
1375
+ FLMR_TEXT_ENCODERS_START_DOCSTRING,
1376
+ )
1377
+ class FLMRTextModel(FLMRPreTrainedModel):
1378
+ base_model_prefix = "bert_model"
1379
+ config_class = FLMRTextConfig
1380
+
1381
+ def __init__(self, config: FLMRTextConfig, *args, **kwargs):
1382
+ super().__init__(config)
1383
+ self.bert_model = BertModel(config, add_pooling_layer=True)
1384
+ if self.bert_model.config.hidden_size <= 0:
1385
+ raise ValueError("Encoder hidden_size can't be zero")
1386
+ self.projection_dim = config.projection_dim
1387
+ if self.projection_dim > 0:
1388
+ self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
1389
+ # Initialize weights and apply final processing
1390
+ self.post_init()
1391
+
1392
+ @add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING)
1393
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig)
1394
+ def forward(
1395
+ self,
1396
+ input_ids: Optional[Tensor] = None,
1397
+ attention_mask: Optional[Tensor] = None,
1398
+ token_type_ids: Optional[Tensor] = None,
1399
+ inputs_embeds: Optional[Tensor] = None,
1400
+ output_attentions: bool = None,
1401
+ output_hidden_states: bool = None,
1402
+ return_dict: bool = None,
1403
+ ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
1404
+ r"""
1405
+ Returns:
1406
+
1407
+ """
1408
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1409
+ output_hidden_states = (
1410
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1411
+ )
1412
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1413
+
1414
+ outputs = self.bert_model(
1415
+ input_ids=input_ids,
1416
+ attention_mask=attention_mask,
1417
+ token_type_ids=token_type_ids,
1418
+ inputs_embeds=inputs_embeds,
1419
+ output_attentions=output_attentions,
1420
+ output_hidden_states=output_hidden_states,
1421
+ return_dict=return_dict,
1422
+ )
1423
+ sequence_output = outputs[0]
1424
+ pooled_output = sequence_output[:, 0, :]
1425
+
1426
+ if self.projection_dim > 0:
1427
+ pooled_output = self.encode_proj(pooled_output)
1428
+
1429
+ if not return_dict:
1430
+ return (sequence_output, pooled_output) + outputs[2:]
1431
+
1432
+ return BaseModelOutputWithPooling(
1433
+ last_hidden_state=sequence_output,
1434
+ pooler_output=pooled_output,
1435
+ hidden_states=outputs.hidden_states,
1436
+ attentions=outputs.attentions,
1437
+ )
1438
+
1439
+ @property
1440
+ def embeddings_size(self) -> int:
1441
+ if self.projection_dim > 0:
1442
+ return self.encode_proj.out_features
1443
+ return self.bert_model.config.hidden_size
1444
+
1445
+
1446
+ @add_start_docstrings(
1447
+ "The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.",
1448
+ FLMR_VISION_ENCODERS_START_DOCSTRING,
1449
+ )
1450
+ class FLMRVisionModel(FLMRPreTrainedModel):
1451
+ base_model_prefix = "vision_model"
1452
+ config_class = FLMRVisionConfig
1453
+ main_input_name = "pixel_values"
1454
+ _no_split_modules = ["CLIPEncoderLayer"]
1455
+
1456
+ def __init__(self, config: FLMRVisionConfig):
1457
+ super().__init__(config)
1458
+ self.vision_model = CLIPVisionModel(config)
1459
+ self.post_init()
1460
+
1461
+ def get_input_embeddings(self) -> nn.Module:
1462
+ return self.vision_model.vision_model.embeddings.patch_embedding
1463
+
1464
+ @add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING)
1465
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig)
1466
+ def forward(
1467
+ self,
1468
+ pixel_values: Optional[torch.FloatTensor] = None,
1469
+ output_attentions: Optional[bool] = None,
1470
+ output_hidden_states: Optional[bool] = None,
1471
+ return_dict: Optional[bool] = None,
1472
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1473
+ r"""
1474
+ Returns:
1475
+
1476
+ Examples:
1477
+
1478
+ ```python
1479
+ >>> from PIL import Image
1480
+ >>> import requests
1481
+ >>> from transformers import AutoProcessor, FLMRVisionModel
1482
+
1483
+ >>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1484
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1485
+
1486
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1487
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1488
+
1489
+ >>> inputs = processor(images=image, return_tensors="pt")
1490
+
1491
+ >>> outputs = model(**inputs)
1492
+ >>> last_hidden_state = outputs.last_hidden_state
1493
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1494
+ ```"""
1495
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1496
+
1497
+ return self.vision_model(
1498
+ pixel_values=pixel_values,
1499
+ output_attentions=output_attentions,
1500
+ output_hidden_states=output_hidden_states,
1501
+ return_dict=return_dict,
1502
+ )
query_tokenizer/added_tokens.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<BOC>": 30527,
3
+ "<BOK>": 30529,
4
+ "<BOQ>": 30525,
5
+ "<BOV>": 30522,
6
+ "<EOC>": 30528,
7
+ "<EOK>": 30530,
8
+ "<EOQ>": 30526,
9
+ "<EOV>": 30524,
10
+ "<SOV>": 30523
11
+ }
query_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<BOV>",
4
+ "<SOV>",
5
+ "<EOV>",
6
+ "<BOQ>",
7
+ "<EOQ>",
8
+ "<BOC>",
9
+ "<EOC>",
10
+ "<BOK>",
11
+ "<EOK>"
12
+ ],
13
+ "cls_token": {
14
+ "content": "[CLS]",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "mask_token": {
21
+ "content": "[MASK]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "pad_token": {
28
+ "content": "[PAD]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ },
34
+ "sep_token": {
35
+ "content": "[SEP]",
36
+ "lstrip": false,
37
+ "normalized": false,
38
+ "rstrip": false,
39
+ "single_word": false
40
+ },
41
+ "unk_token": {
42
+ "content": "[UNK]",
43
+ "lstrip": false,
44
+ "normalized": false,
45
+ "rstrip": false,
46
+ "single_word": false
47
+ }
48
+ }
query_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
query_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "30522": {
44
+ "content": "<BOV>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "30523": {
52
+ "content": "<SOV>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "30524": {
60
+ "content": "<EOV>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "30525": {
68
+ "content": "<BOQ>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "30526": {
76
+ "content": "<EOQ>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "30527": {
84
+ "content": "<BOC>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "30528": {
92
+ "content": "<EOC>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "30529": {
100
+ "content": "<BOK>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "30530": {
108
+ "content": "<EOK>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ }
115
+ },
116
+ "additional_special_tokens": [
117
+ "<BOV>",
118
+ "<SOV>",
119
+ "<EOV>",
120
+ "<BOQ>",
121
+ "<EOQ>",
122
+ "<BOC>",
123
+ "<EOC>",
124
+ "<BOK>",
125
+ "<EOK>"
126
+ ],
127
+ "attend_to_mask_tokens": false,
128
+ "auto_map": {
129
+ "AutoTokenizer": [
130
+ "tokenization_flmr.FLMRQueryEncoderTokenizer",
131
+ null
132
+ ]
133
+ },
134
+ "clean_up_tokenization_spaces": true,
135
+ "cls_token": "[CLS]",
136
+ "do_basic_tokenize": true,
137
+ "do_lower_case": true,
138
+ "mask_token": "[MASK]",
139
+ "model_max_length": 1000000000000000019884624838656,
140
+ "never_split": null,
141
+ "pad_token": "[PAD]",
142
+ "query_maxlen": 512,
143
+ "sep_token": "[SEP]",
144
+ "strip_accents": null,
145
+ "tokenize_chinese_chars": true,
146
+ "tokenizer_class": "FLMRQueryEncoderTokenizer",
147
+ "unk_token": "[UNK]"
148
+ }
query_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
segmented_maxsim.cpp ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <pthread.h>
2
+ #include <torch/extension.h>
3
+
4
+ #include <algorithm>
5
+ #include <numeric>
6
+
7
+ typedef struct {
8
+ int tid;
9
+ int nthreads;
10
+
11
+ int ndocs;
12
+ int ndoc_vectors;
13
+ int nquery_vectors;
14
+
15
+ int64_t* lengths;
16
+ float* scores;
17
+ int64_t* offsets;
18
+
19
+ float* max_scores;
20
+ } max_args_t;
21
+
22
+ void* max(void* args) {
23
+ max_args_t* max_args = (max_args_t*)args;
24
+
25
+ int ndocs_per_thread =
26
+ std::ceil(((float)max_args->ndocs) / max_args->nthreads);
27
+ int start = max_args->tid * ndocs_per_thread;
28
+ int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
29
+
30
+ auto max_scores_offset =
31
+ max_args->max_scores + (start * max_args->nquery_vectors);
32
+ auto scores_offset =
33
+ max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
34
+
35
+ for (int i = start; i < end; i++) {
36
+ for (int j = 0; j < max_args->lengths[i]; j++) {
37
+ std::transform(max_scores_offset,
38
+ max_scores_offset + max_args->nquery_vectors,
39
+ scores_offset, max_scores_offset,
40
+ [](float a, float b) { return std::max(a, b); });
41
+ scores_offset += max_args->nquery_vectors;
42
+ }
43
+ max_scores_offset += max_args->nquery_vectors;
44
+ }
45
+
46
+ return NULL;
47
+ }
48
+
49
+ torch::Tensor segmented_maxsim(const torch::Tensor scores,
50
+ const torch::Tensor lengths) {
51
+ auto lengths_a = lengths.data_ptr<int64_t>();
52
+ auto scores_a = scores.data_ptr<float>();
53
+ auto ndocs = lengths.size(0);
54
+ auto ndoc_vectors = scores.size(0);
55
+ auto nquery_vectors = scores.size(1);
56
+ auto nthreads = at::get_num_threads();
57
+
58
+ torch::Tensor max_scores =
59
+ torch::zeros({ndocs, nquery_vectors}, scores.options());
60
+
61
+ int64_t offsets[ndocs + 1];
62
+ offsets[0] = 0;
63
+ std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
64
+
65
+ pthread_t threads[nthreads];
66
+ max_args_t args[nthreads];
67
+
68
+ for (int i = 0; i < nthreads; i++) {
69
+ args[i].tid = i;
70
+ args[i].nthreads = nthreads;
71
+
72
+ args[i].ndocs = ndocs;
73
+ args[i].ndoc_vectors = ndoc_vectors;
74
+ args[i].nquery_vectors = nquery_vectors;
75
+
76
+ args[i].lengths = lengths_a;
77
+ args[i].scores = scores_a;
78
+ args[i].offsets = offsets;
79
+
80
+ args[i].max_scores = max_scores.data_ptr<float>();
81
+
82
+ int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
83
+ if (rc) {
84
+ fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
85
+ }
86
+ }
87
+
88
+ for (int i = 0; i < nthreads; i++) {
89
+ pthread_join(threads[i], NULL);
90
+ }
91
+
92
+ return max_scores.sum(1);
93
+ }
94
+
95
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
96
+ m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
97
+ }
tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
tokenization_flmr_fast.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from transformers.utils import logging
19
+ from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
20
+ from .tokenization_flmr import FLMRContextEncoderTokenizer, FLMRQueryEncoderTokenizer
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
26
+
27
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
28
+ "vocab_file": {
29
+ "LinWeizheDragon/PreFLMR_ViT-L": (
30
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
31
+ ),
32
+ "LinWeizheDragon/FLMR": (
33
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
34
+ ),
35
+ },
36
+ "tokenizer_file": {
37
+ "LinWeizheDragon/PreFLMR_ViT-L": (
38
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
39
+ ),
40
+ "LinWeizheDragon/FLMR": (
41
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
42
+ ),
43
+ },
44
+ }
45
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
46
+ "vocab_file": {
47
+ "LinWeizheDragon/PreFLMR_ViT-L": (
48
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
49
+ ),
50
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
51
+ },
52
+ "tokenizer_file": {
53
+ "LinWeizheDragon/PreFLMR_ViT-L": (
54
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
55
+ ),
56
+ "LinWeizheDragon/FLMR": (
57
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
58
+ ),
59
+ },
60
+ }
61
+
62
+
63
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
64
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
65
+ "LinWeizheDragon/FLMR": 512,
66
+ }
67
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
68
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
69
+ "LinWeizheDragon/FLMR": 512,
70
+ }
71
+
72
+
73
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
74
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
75
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
76
+ }
77
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
78
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
79
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
80
+ }
81
+
82
+
83
+ class FLMRContextEncoderTokenizerFast(BertTokenizerFast):
84
+ r"""
85
+ Construct a "fast" FLMRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
86
+
87
+ [`FLMRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
88
+ punctuation splitting and wordpiece.
89
+
90
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
91
+ """
92
+
93
+ vocab_files_names = VOCAB_FILES_NAMES
94
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
95
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
96
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
97
+ slow_tokenizer_class = FLMRContextEncoderTokenizer
98
+
99
+
100
+ class FLMRQueryEncoderTokenizerFast(BertTokenizerFast):
101
+ r"""
102
+ Constructs a "fast" FLMRQueryEncoderTokenizer tokenizer (backed by HuggingFace's *tokenizers* library).
103
+
104
+ [`FLMRQueryEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
105
+ punctuation splitting and wordpiece.
106
+
107
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
108
+ """
109
+
110
+ vocab_files_names = VOCAB_FILES_NAMES
111
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
112
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
113
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
114
+ slow_tokenizer_class = FLMRQueryEncoderTokenizer