psamal committed on
Commit
13f969b
·
verified ·
1 Parent(s): a7dc9b3

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
chat_template.jinja ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- set image_count = namespace(value=0) %}
2
+ {%- set video_count = namespace(value=0) %}
3
+ {%- macro render_content(content, do_vision_count, is_system_content=false) %}
4
+ {%- if content is string %}
5
+ {{- content }}
6
+ {%- elif content is iterable and content is not mapping %}
7
+ {%- for item in content %}
8
+ {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
9
+ {%- if is_system_content %}
10
+ {{- raise_exception('System message cannot contain images.') }}
11
+ {%- endif %}
12
+ {%- if do_vision_count %}
13
+ {%- set image_count.value = image_count.value + 1 %}
14
+ {%- endif %}
15
+ {%- if add_vision_id %}
16
+ {{- 'Picture ' ~ image_count.value ~ ': ' }}
17
+ {%- endif %}
18
+ {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
19
+ {%- elif 'video' in item or item.type == 'video' %}
20
+ {%- if is_system_content %}
21
+ {{- raise_exception('System message cannot contain videos.') }}
22
+ {%- endif %}
23
+ {%- if do_vision_count %}
24
+ {%- set video_count.value = video_count.value + 1 %}
25
+ {%- endif %}
26
+ {%- if add_vision_id %}
27
+ {{- 'Video ' ~ video_count.value ~ ': ' }}
28
+ {%- endif %}
29
+ {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
30
+ {%- elif 'text' in item %}
31
+ {{- item.text }}
32
+ {%- else %}
33
+ {{- raise_exception('Unexpected item type in content.') }}
34
+ {%- endif %}
35
+ {%- endfor %}
36
+ {%- elif content is none or content is undefined %}
37
+ {{- '' }}
38
+ {%- else %}
39
+ {{- raise_exception('Unexpected content type.') }}
40
+ {%- endif %}
41
+ {%- endmacro %}
42
+ {%- if not messages %}
43
+ {{- raise_exception('No messages provided.') }}
44
+ {%- endif %}
45
+ {%- if tools and tools is iterable and tools is not mapping %}
46
+ {{- '<|im_start|>system\n' }}
47
+ {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
48
+ {%- for tool in tools %}
49
+ {{- "\n" }}
50
+ {{- tool | tojson }}
51
+ {%- endfor %}
52
+ {{- "\n</tools>" }}
53
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
54
+ {%- if messages[0].role == 'system' %}
55
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
56
+ {%- if content %}
57
+ {{- '\n\n' + content }}
58
+ {%- endif %}
59
+ {%- endif %}
60
+ {{- '<|im_end|>\n' }}
61
+ {%- else %}
62
+ {%- if messages[0].role == 'system' %}
63
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
64
+ {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
65
+ {%- endif %}
66
+ {%- endif %}
67
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
68
+ {%- for message in messages[::-1] %}
69
+ {%- set index = (messages|length - 1) - loop.index0 %}
70
+ {%- if ns.multi_step_tool and message.role == "user" %}
71
+ {%- set content = render_content(message.content, false)|trim %}
72
+ {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
73
+ {%- set ns.multi_step_tool = false %}
74
+ {%- set ns.last_query_index = index %}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if ns.multi_step_tool %}
79
+ {{- raise_exception('No user query found in messages.') }}
80
+ {%- endif %}
81
+ {%- for message in messages %}
82
+ {%- set content = render_content(message.content, true)|trim %}
83
+ {%- if message.role == "system" %}
84
+ {%- if not loop.first %}
85
+ {{- raise_exception('System message must be at the beginning.') }}
86
+ {%- endif %}
87
+ {%- elif message.role == "user" %}
88
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
89
+ {%- elif message.role == "assistant" %}
90
+ {%- set reasoning_content = '' %}
91
+ {%- if message.reasoning_content is string %}
92
+ {%- set reasoning_content = message.reasoning_content %}
93
+ {%- else %}
94
+ {%- if '</think>' in content %}
95
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
96
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
97
+ {%- endif %}
98
+ {%- endif %}
99
+ {%- set reasoning_content = reasoning_content|trim %}
100
+ {%- if loop.index0 > ns.last_query_index %}
101
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
102
+ {%- else %}
103
+ {{- '<|im_start|>' + message.role + '\n' + content }}
104
+ {%- endif %}
105
+ {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
106
+ {%- for tool_call in message.tool_calls %}
107
+ {%- if tool_call.function is defined %}
108
+ {%- set tool_call = tool_call.function %}
109
+ {%- endif %}
110
+ {%- if loop.first %}
111
+ {%- if content|trim %}
112
+ {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
113
+ {%- else %}
114
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
115
+ {%- endif %}
116
+ {%- else %}
117
+ {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
118
+ {%- endif %}
119
+ {%- if tool_call.arguments is defined %}
120
+ {%- for args_name, args_value in tool_call.arguments|items %}
121
+ {{- '<parameter=' + args_name + '>\n' }}
122
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
123
+ {{- args_value }}
124
+ {{- '\n</parameter>\n' }}
125
+ {%- endfor %}
126
+ {%- endif %}
127
+ {{- '</function>\n</tool_call>' }}
128
+ {%- endfor %}
129
+ {%- endif %}
130
+ {{- '<|im_end|>\n' }}
131
+ {%- elif message.role == "tool" %}
132
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
133
+ {{- '<|im_start|>user' }}
134
+ {%- endif %}
135
+ {{- '\n<tool_response>\n' }}
136
+ {{- content }}
137
+ {{- '\n</tool_response>' }}
138
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
139
+ {{- '<|im_end|>\n' }}
140
+ {%- elif loop.last %}
141
+ {{- '<|im_end|>\n' }}
142
+ {%- endif %}
143
+ {%- else %}
144
+ {{- raise_exception('Unexpected message role.') }}
145
+ {%- endif %}
146
+ {%- endfor %}
147
+ {%- if add_generation_prompt %}
148
+ {{- '<|im_start|>assistant\n' }}
149
+ {%- if enable_thinking is defined and enable_thinking is false %}
150
+ {{- '<think>\n\n</think>\n\n' }}
151
+ {%- else %}
152
+ {{- '<think>\n' }}
153
+ {%- endif %}
154
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ColVec1"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_colvec1.ColVec1Config",
7
+ "AutoModel": "model.ColVec1"
8
+ },
9
+ "base_model_name_or_path": "Qwen/Qwen3.5-9B",
10
+ "dtype": "bfloat16",
11
+ "embed_dim": 2560,
12
+ "initializer_range": 0.02,
13
+ "model_type": "colvec1",
14
+ "padding_side": "left",
15
+ "text_hidden_size": 4096,
16
+ "transformers_version": "5.3.0"
17
+ }
configuration_colvec1.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for ColVec1 retrieval model.
3
+ """
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class ColVec1Config(PretrainedConfig):
9
+ """Configuration for the ColVec1 retrieval wrapper."""
10
+
11
+ model_type = "colvec1"
12
+
13
+ def __init__(
14
+ self,
15
+ embed_dim: int = 128,
16
+ text_hidden_size: int = 2560,
17
+ padding_side: str = "left",
18
+ initializer_range: float = 0.02,
19
+ base_model_name_or_path: str = None,
20
+ **kwargs,
21
+ ):
22
+ super().__init__(**kwargs)
23
+ self.embed_dim = embed_dim
24
+ self.text_hidden_size = text_hidden_size
25
+ self.padding_side = padding_side
26
+ self.initializer_range = initializer_range
27
+ self.base_model_name_or_path = base_model_name_or_path
28
+
29
+
30
+ __all__ = ["ColVec1Config"]
model.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ColVec1 - ColVec1 retrieval wrapper for late interaction.
3
+ """
4
+
5
+ import glob
6
+ import json
7
+ import os
8
+ from typing import ClassVar, List, Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers import AutoModelForImageTextToText, PreTrainedModel
13
+
14
+ from .configuration_colvec1 import ColVec1Config
15
+
16
+
17
+ class ColVec1PreTrainedModel(PreTrainedModel):
18
+ """Base class for ColVec1 models."""
19
+
20
+ config_class = ColVec1Config
21
+ base_model_prefix = "colvec1"
22
+ supports_gradient_checkpointing = True
23
+ _tied_weights_keys: ClassVar[List[str]] = []
24
+
25
+
26
+ class ColVec1(ColVec1PreTrainedModel):
27
+ """
28
+ Retrieval model wrapper for ColVec1 checkpoints.
29
+
30
+ It loads the upstream model with `AutoModelForImageTextToText`, then adds
31
+ a projection head to produce L2-normalized retrieval embeddings.
32
+ """
33
+
34
+ main_input_name: ClassVar[str] = "input_ids"
35
+
36
+ def __init__(self, config: ColVec1Config):
37
+ super().__init__(config)
38
+ self.config = config
39
+ self.vlm = None
40
+ self.embedding_proj_layer = nn.Linear(config.text_hidden_size, config.embed_dim)
41
+ self.post_init()
42
+
43
+ def forward(
44
+ self,
45
+ input_ids: torch.LongTensor = None,
46
+ attention_mask: Optional[torch.Tensor] = None,
47
+ pixel_values: Optional[torch.FloatTensor] = None,
48
+ **kwargs,
49
+ ) -> torch.Tensor:
50
+ kwargs.pop("output_hidden_states", None)
51
+ kwargs.pop("return_dict", None)
52
+
53
+ vlm_outputs = self.vlm(
54
+ input_ids=input_ids,
55
+ attention_mask=attention_mask,
56
+ pixel_values=pixel_values,
57
+ output_hidden_states=True,
58
+ return_dict=True,
59
+ **kwargs,
60
+ )
61
+
62
+ if hasattr(vlm_outputs, "hidden_states") and vlm_outputs.hidden_states is not None:
63
+ last_hidden_states = vlm_outputs.hidden_states[-1]
64
+ elif hasattr(vlm_outputs, "last_hidden_state"):
65
+ last_hidden_states = vlm_outputs.last_hidden_state
66
+ else:
67
+ last_hidden_states = vlm_outputs[0]
68
+
69
+ embeddings = self.embedding_proj_layer(
70
+ last_hidden_states.to(self.embedding_proj_layer.weight.dtype)
71
+ )
72
+ embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)
73
+
74
+ if attention_mask is not None:
75
+ embeddings = embeddings * attention_mask.unsqueeze(-1)
76
+
77
+ return embeddings
78
+
79
+ @classmethod
80
+ def from_pretrained(
81
+ cls,
82
+ pretrained_model_name_or_path: str,
83
+ embed_dim: int = 128,
84
+ torch_dtype: torch.dtype = None,
85
+ device_map: str = None,
86
+ attn_impl: str = None,
87
+ **kwargs,
88
+ ):
89
+ # AutoModel may rename torch_dtype -> dtype in newer transformers
90
+ if torch_dtype is None:
91
+ torch_dtype = kwargs.pop("dtype", None)
92
+
93
+ # Pop config early so we can inspect model_type for merged-repo detection.
94
+ # When called via AutoModel.from_pretrained, transformers resolves the config
95
+ # and passes it here as a kwarg;
96
+ config = kwargs.pop("config", None)
97
+ if config is not None and hasattr(config, "embed_dim"):
98
+ embed_dim = config.embed_dim
99
+
100
+ # Detect a merged ColVec1 repo using three strategies in order:
101
+ # 1. config object already provided (Hub path via AutoModel dispatch)
102
+ # 2. local config.json on disk (direct local-path usage)
103
+ # 3. AutoConfig.from_pretrained (direct Hub ID usage without AutoModel)
104
+ _is_merged = (
105
+ config is not None
106
+ and getattr(config, "model_type", None) == "colvec1"
107
+ )
108
+
109
+ if not _is_merged:
110
+ config_path = os.path.join(pretrained_model_name_or_path, "config.json")
111
+ if os.path.exists(config_path):
112
+ with open(config_path) as f:
113
+ raw = json.load(f)
114
+ _is_merged = raw.get("model_type") == "colvec1"
115
+ else:
116
+ # Remote Hub ID: fetch the config to check model_type.
117
+ from transformers import AutoConfig
118
+ try:
119
+ hub_config = AutoConfig.from_pretrained(
120
+ pretrained_model_name_or_path,
121
+ trust_remote_code=kwargs.get("trust_remote_code", True),
122
+ )
123
+ _is_merged = getattr(hub_config, "model_type", None) == "colvec1"
124
+ except Exception:
125
+ pass
126
+
127
+ if _is_merged:
128
+ return cls._load_merged(
129
+ pretrained_model_name_or_path,
130
+ torch_dtype=torch_dtype,
131
+ device_map=device_map,
132
+ attn_impl=attn_impl,
133
+ **kwargs,
134
+ )
135
+
136
+ # --- From-scratch path: load a raw Qwen3.5 VLM and wrap it ---
137
+ # (config was already popped above; rest of the method is unchanged)
138
+ vlm_kwargs = {"trust_remote_code": kwargs.pop("trust_remote_code", True)}
139
+ if torch_dtype is not None:
140
+ vlm_kwargs["torch_dtype"] = torch_dtype
141
+ if device_map is not None:
142
+ vlm_kwargs["device_map"] = device_map
143
+ if attn_impl is not None:
144
+ vlm_kwargs["attn_implementation"] = attn_impl
145
+ if "quantization_config" in kwargs:
146
+ vlm_kwargs["quantization_config"] = kwargs.pop("quantization_config")
147
+
148
+ vlm = AutoModelForImageTextToText.from_pretrained(pretrained_model_name_or_path, **vlm_kwargs)
149
+
150
+ if hasattr(vlm.config, "text_config") and hasattr(vlm.config.text_config, "hidden_size"):
151
+ text_hidden_size = vlm.config.text_config.hidden_size
152
+ else:
153
+ text_hidden_size = getattr(vlm.config, "hidden_size", 2560)
154
+
155
+ model_config = ColVec1Config(
156
+ embed_dim=embed_dim,
157
+ text_hidden_size=text_hidden_size,
158
+ padding_side="left",
159
+ )
160
+ model = cls(model_config)
161
+ model.vlm = vlm
162
+ model.embedding_proj_layer = nn.Linear(model_config.text_hidden_size, model_config.embed_dim)
163
+
164
+ if torch_dtype is not None:
165
+ model.embedding_proj_layer = model.embedding_proj_layer.to(torch_dtype)
166
+
167
+ if hasattr(vlm, "device"):
168
+ model.embedding_proj_layer = model.embedding_proj_layer.to(vlm.device)
169
+
170
+ tied = getattr(vlm, "_tied_weights_keys", None)
171
+ if isinstance(tied, dict):
172
+ model._tied_weights_keys = {f"vlm.{k}": f"vlm.{v}" for k, v in tied.items()}
173
+ elif isinstance(tied, (list, tuple, set)):
174
+ model._tied_weights_keys = [f"vlm.{k}" for k in tied]
175
+ else:
176
+ model._tied_weights_keys = []
177
+
178
+ return model
179
+
180
+ @classmethod
181
+ def _load_merged(
182
+ cls,
183
+ path: str,
184
+ torch_dtype: torch.dtype = None,
185
+ device_map: str = None,
186
+ attn_impl: str = None,
187
+ **kwargs,
188
+ ):
189
+ """Load a merged ColVec1 checkpoint (dense VLM weights + embedding_proj_layer)."""
190
+ from safetensors.torch import load_file
191
+
192
+ # Resolve Hub repo ID to a local cached snapshot directory so all
193
+ # subsequent os.path / glob operations work for both local and remote paths.
194
+ if not os.path.isdir(path):
195
+ from huggingface_hub import snapshot_download
196
+ path = snapshot_download(path)
197
+
198
+ config = ColVec1Config.from_pretrained(path)
199
+ base_name = config.base_model_name_or_path
200
+ if base_name is None:
201
+ raise ValueError(
202
+ f"Merged ColVec1 config at {path} is missing 'base_model_name_or_path'. "
203
+ "This field is required to know which VLM architecture to instantiate."
204
+ )
205
+
206
+ vlm_kwargs = {"trust_remote_code": True}
207
+ if torch_dtype is not None:
208
+ vlm_kwargs["torch_dtype"] = torch_dtype
209
+ if device_map is not None:
210
+ vlm_kwargs["device_map"] = device_map
211
+ if attn_impl is not None:
212
+ vlm_kwargs["attn_implementation"] = attn_impl
213
+
214
+ vlm = AutoModelForImageTextToText.from_pretrained(base_name, **vlm_kwargs)
215
+
216
+ model = cls(config)
217
+ model.vlm = vlm
218
+
219
+ safetensor_files = sorted(glob.glob(os.path.join(path, "model*.safetensors")))
220
+ if not safetensor_files:
221
+ raise FileNotFoundError(f"No model*.safetensors files found in {path}")
222
+
223
+ state_dict = {}
224
+ for sf in safetensor_files:
225
+ state_dict.update(load_file(sf))
226
+
227
+ model.load_state_dict(state_dict, strict=False)
228
+
229
+ if torch_dtype is not None:
230
+ model.embedding_proj_layer = model.embedding_proj_layer.to(torch_dtype)
231
+ if hasattr(vlm, "device"):
232
+ model.embedding_proj_layer = model.embedding_proj_layer.to(vlm.device)
233
+
234
+ tied = getattr(vlm, "_tied_weights_keys", None)
235
+ if isinstance(tied, dict):
236
+ model._tied_weights_keys = {f"vlm.{k}": f"vlm.{v}" for k, v in tied.items()}
237
+ elif isinstance(tied, (list, tuple, set)):
238
+ model._tied_weights_keys = [f"vlm.{k}" for k in tied]
239
+ else:
240
+ model._tied_weights_keys = []
241
+
242
+ return model
243
+
244
+ def tie_weights(self, *args, **kwargs):
245
+ if self.vlm is None:
246
+ # Called during post_init() before the wrapped VLM is attached.
247
+ return None
248
+ try:
249
+ return self.vlm.tie_weights(*args, **kwargs)
250
+ except TypeError:
251
+ return self.vlm.tie_weights()
252
+
253
+ def get_input_embeddings(self):
254
+ return self.vlm.get_input_embeddings()
255
+
256
+ def set_input_embeddings(self, value):
257
+ self.vlm.set_input_embeddings(value)
258
+
259
+ def get_output_embeddings(self):
260
+ return self.vlm.get_output_embeddings()
261
+
262
+ def set_output_embeddings(self, new_embeddings):
263
+ self.vlm.set_output_embeddings(new_embeddings)
264
+
265
+ def resize_token_embeddings(
266
+ self,
267
+ new_num_tokens: Optional[int] = None,
268
+ pad_to_multiple_of: Optional[int] = None,
269
+ mean_resizing: bool = True,
270
+ ) -> nn.Embedding:
271
+ model_embeds = self.vlm.resize_token_embeddings(
272
+ new_num_tokens=new_num_tokens,
273
+ pad_to_multiple_of=pad_to_multiple_of,
274
+ mean_resizing=mean_resizing,
275
+ )
276
+
277
+ if hasattr(self.vlm.config, "text_config"):
278
+ self.vlm.config.text_config.vocab_size = model_embeds.num_embeddings
279
+ if hasattr(self.vlm.config, "vocab_size"):
280
+ self.vlm.config.vocab_size = model_embeds.num_embeddings
281
+ return model_embeds
282
+
283
+ @property
284
+ def device(self):
285
+ return next(self.parameters()).device
286
+
287
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
288
+ if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_enable"):
289
+ self.vlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
290
+
291
+ def gradient_checkpointing_disable(self):
292
+ if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_disable"):
293
+ self.vlm.gradient_checkpointing_disable()
294
+
295
+
296
+ __all__ = ["ColVec1", "ColVec1PreTrainedModel"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df0211664297feed8aee74da5a0c60275512a41f08c7cb9d8a9387598f18903d
3
+ size 18840702256
processor.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ColVec1 processor.
3
+
4
+ Processing utilities for ColVec1, aligned with the ColQwen3 reference implementation.
5
+ """
6
+
7
+ import importlib
8
+ import numpy as np
9
+ from typing import Any, List, Optional, Tuple, Union
10
+ import torch
11
+ from PIL import Image
12
+ from transformers import BatchEncoding
13
+ from transformers.feature_extraction_utils import BatchFeature
14
+ from transformers.image_utils import ImageInput, is_valid_image
15
+ from transformers.processing_utils import AudioInput, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideoInput
16
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
17
+ from transformers.utils import logging
18
+
19
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ try:
24
+ from fast_plaid import search
25
+ except ImportError:
26
+ logger.info(
27
+ "FastPlaid is not installed.If you want to use it:Instal with `pip install --no-deps fast-plaid fastkmeans`"
28
+ )
29
+
30
+
31
+ def get_torch_device(device: str = "auto") -> str:
32
+ """Resolve a torch device string with a simple auto mode."""
33
+ if device == "auto":
34
+ if torch.cuda.is_available():
35
+ device = "cuda:0"
36
+ elif torch.backends.mps.is_available(): # for Apple Silicon
37
+ device = "mps"
38
+ else:
39
+ device = "cpu"
40
+ return device
41
+
42
+
43
+ class ColVec1ProcessorKwargs(ProcessingKwargs, total=False):
44
+ _defaults = {
45
+ "text_kwargs": {
46
+ "padding": "longest",
47
+ },
48
+ "images_kwargs": {
49
+ "data_format": "channels_first",
50
+ "do_convert_rgb": True,
51
+ },
52
+ "videos_kwargs": {
53
+ "return_metadata": True,
54
+ "data_format": "channels_first",
55
+ "do_convert_rgb": True,
56
+ },
57
+ "common_kwargs": {"return_tensors": "pt"},
58
+ }
59
+
60
+
61
+ class ColVec1Processor(ProcessorMixin):
62
+ """
63
+ Constructs a ColVec1 processor which wraps a Qwen3VLProcessor with retrieval-specific helpers.
64
+ """
65
+
66
+ attributes = ["image_processor", "tokenizer", "video_processor"]
67
+ image_processor_class = "AutoImageProcessor"
68
+ video_processor_class = "AutoVideoProcessor"
69
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
70
+
71
+ def __init__(
72
+ self,
73
+ image_processor=None,
74
+ tokenizer=None,
75
+ video_processor=None,
76
+ chat_template=None,
77
+ visual_prompt_prefix: Optional[str] = None,
78
+ visual_prompt_suffix: Optional[str] = None,
79
+ video_prompt_prefix: Optional[str] = None,
80
+ video_prompt_suffix: Optional[str] = None,
81
+ query_prefix: Optional[str] = None,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
85
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
86
+ self.image_token_id = (
87
+ tokenizer.image_token_id
88
+ if getattr(tokenizer, "image_token_id", None)
89
+ else tokenizer.convert_tokens_to_ids(self.image_token)
90
+ )
91
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
92
+ self.video_token_id = (
93
+ tokenizer.video_token_id
94
+ if getattr(tokenizer, "video_token_id", None)
95
+ else tokenizer.convert_tokens_to_ids(self.video_token)
96
+ )
97
+ self.vision_start_token = (
98
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
99
+ )
100
+ self.vision_end_token = (
101
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
102
+ )
103
+ self.vision_start_token_id = (
104
+ tokenizer.vision_start_token_id
105
+ if getattr(tokenizer, "vision_start_token_id", None)
106
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token)
107
+ )
108
+ self.vision_end_token_id = (
109
+ tokenizer.vision_end_token_id
110
+ if getattr(tokenizer, "vision_end_token_id", None)
111
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token)
112
+ )
113
+
114
+ if visual_prompt_prefix is None:
115
+ visual_prompt_prefix = (
116
+ "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image."
117
+ )
118
+ self.visual_prompt_prefix = visual_prompt_prefix
119
+ if visual_prompt_suffix is None:
120
+ visual_prompt_suffix = "<|im_end|><|endoftext|>"
121
+ self.visual_prompt_suffix = visual_prompt_suffix
122
+
123
+ if video_prompt_prefix is None:
124
+ video_prompt_prefix = (
125
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>Describe the video."
126
+ )
127
+ self.video_prompt_prefix = video_prompt_prefix
128
+ if video_prompt_suffix is None:
129
+ video_prompt_suffix = "<|im_end|><|endoftext|>"
130
+ self.video_prompt_suffix = video_prompt_suffix
131
+
132
+ if query_prefix is None:
133
+ query_prefix = ""
134
+ self.query_prefix = query_prefix
135
+ self.tokenizer.padding_side = "left"
136
+
137
+ @classmethod
138
+ def from_pretrained( # type: ignore[override]
139
+ cls,
140
+ *args: Any,
141
+ max_num_visual_tokens: int = 1280,
142
+ **kwargs: Any,
143
+ ) -> "ColVec1Processor":
144
+ instance = super().from_pretrained(
145
+ *args,
146
+ **kwargs,
147
+ )
148
+
149
+ patch_size = getattr(instance.image_processor, "patch_size", None)
150
+ merge_size = getattr(instance.image_processor, "merge_size", None) or getattr(
151
+ instance.image_processor, "spatial_merge_size", None
152
+ )
153
+ if patch_size is None or merge_size is None:
154
+ raise ValueError("Qwen3VL image processor is missing `patch_size` or `merge_size`/`spatial_merge_size`.")
155
+ tile = patch_size * merge_size
156
+ instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
157
+ instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
158
+
159
+ video_patch_size = getattr(instance.video_processor, "patch_size", None)
160
+ video_merge_size = getattr(instance.video_processor, "merge_size", None) or getattr(
161
+ instance.video_processor, "spatial_merge_size", None
162
+ )
163
+ video_temporal_patch_size = getattr(instance.video_processor, "temporal_patch_size", None)
164
+ if video_patch_size is None or video_merge_size is None or video_temporal_patch_size is None:
165
+ raise ValueError(
166
+ "Qwen3VL video processor is missing `patch_size`, `merge_size`/`spatial_merge_size`, or `temporal_patch_size`."
167
+ )
168
+ video_tile = video_patch_size * video_merge_size
169
+ # Include temporal patching so the visual token cap applies across space and time.
170
+ instance.video_processor.max_pixels = max_num_visual_tokens * video_tile * video_tile * video_temporal_patch_size
171
+ instance.video_processor.size["longest_edge"] = instance.video_processor.max_pixels
172
+
173
+ return instance
174
+
175
+ def __call__(
176
+ self,
177
+ images: Optional[ImageInput] = None,
178
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
179
+ audio: Optional[AudioInput] = None,
180
+ videos: Optional[VideoInput] = None,
181
+ **kwargs: Unpack[ColVec1ProcessorKwargs],
182
+ ) -> BatchFeature:
183
+ output_kwargs = self._merge_kwargs(
184
+ ColVec1ProcessorKwargs,
185
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
186
+ **kwargs,
187
+ )
188
+ suffix = output_kwargs["text_kwargs"].pop("suffix", None)
189
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
190
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
191
+
192
+ if images is not None and videos is not None:
193
+ raise ValueError("Provide only one of `images` or `videos`, not both.")
194
+
195
+ # Normalize text inputs
196
+ text_list: list[str] = []
197
+ if text is not None:
198
+ if isinstance(text, str):
199
+ text_list = [text]
200
+ elif isinstance(text, list):
201
+ if len(text) == 0 or not all(isinstance(t, (str, type(None))) for t in text):
202
+ raise ValueError("Text must be a string or a list of strings.")
203
+ text_list = [t or "" for t in text]
204
+ else:
205
+ raise ValueError("Text must be a string or a list of strings")
206
+
207
+ # Normalize image inputs
208
+ image_list: Optional[list[Any]] = None
209
+ if images is not None:
210
+ raw_images = images if isinstance(images, list) else [images]
211
+ image_list = []
212
+ for idx, img_item in enumerate(raw_images):
213
+ if img_item is None:
214
+ image_list.append([])
215
+ elif is_valid_image(img_item):
216
+ image_list.append([img_item])
217
+ elif isinstance(img_item, list):
218
+ if not img_item:
219
+ image_list.append([])
220
+ continue
221
+ for sub_idx, sub_img in enumerate(img_item):
222
+ if not is_valid_image(sub_img):
223
+ raise ValueError(f"Image at position {idx}[{sub_idx}] is not a valid image.")
224
+ image_list.append(list(img_item))
225
+ else:
226
+ raise ValueError("images must be an image, list of images or list of list of images")
227
+
228
+ # Normalize video inputs
229
+ video_list: Optional[list[Any]] = None
230
+ if videos is not None:
231
+ raw_videos = list(videos) if isinstance(videos, (list, tuple)) else [videos]
232
+ video_list = []
233
+ for idx, vid_item in enumerate(raw_videos):
234
+ if vid_item is None:
235
+ video_list.append([])
236
+ elif isinstance(vid_item, list):
237
+ video_list.append(list(vid_item))
238
+ else:
239
+ video_list.append([vid_item])
240
+
241
+ if image_list is None and video_list is None and not text_list:
242
+ raise ValueError("Either text, images or videos must be provided")
243
+
244
+ # Align text length with provided vision inputs when needed
245
+ if image_list is not None:
246
+ if not text_list:
247
+ text_list = [""] * len(image_list)
248
+ elif len(text_list) == 1 and len(image_list) > 1:
249
+ text_list = text_list * len(image_list)
250
+ elif len(text_list) != len(image_list):
251
+ raise ValueError("When providing both images and text, their lengths must match.")
252
+ num_items = len(image_list)
253
+ elif video_list is not None:
254
+ if not text_list:
255
+ text_list = [""] * len(video_list)
256
+ elif len(text_list) == 1 and len(video_list) > 1:
257
+ text_list = text_list * len(video_list)
258
+ elif len(text_list) != len(video_list):
259
+ raise ValueError("When providing both videos and text, their lengths must match.")
260
+ num_items = len(video_list)
261
+ else:
262
+ num_items = len(text_list)
263
+
264
+ if num_items == 0:
265
+ raise ValueError("Either text, images or videos must be provided")
266
+
267
+ prompts: list[str] = []
268
+ query_suffix = suffix if suffix is not None else self.query_augmentation_token * 10
269
+
270
+ for idx in range(num_items):
271
+ extra_text = (text_list[idx] if idx < len(text_list) else "") or ""
272
+ extra_text = extra_text.strip()
273
+ has_image = image_list is not None and len(image_list[idx]) > 0
274
+ has_video = video_list is not None and len(video_list[idx]) > 0
275
+ if has_image and has_video:
276
+ raise ValueError("Provide only one of `images` or `videos` per item.")
277
+
278
+ if has_image:
279
+ prompt = (
280
+ f"{self.visual_prompt_prefix} {extra_text}{self.visual_prompt_suffix}"
281
+ if extra_text
282
+ else f"{self.visual_prompt_prefix}{self.visual_prompt_suffix}"
283
+ )
284
+ prompts.append(prompt)
285
+ elif has_video:
286
+ prompt = (
287
+ f"{self.video_prompt_prefix} {extra_text}{self.video_prompt_suffix}"
288
+ if extra_text
289
+ else f"{self.video_prompt_prefix}{self.video_prompt_suffix}"
290
+ )
291
+ prompts.append(prompt)
292
+ else:
293
+ prompt = self.query_prefix + extra_text + query_suffix
294
+ prompts.append(prompt)
295
+
296
+ # Process images (excluding empty placeholders)
297
+ image_inputs: dict[str, Any] = {}
298
+ image_grid_thw = None
299
+ if image_list is not None:
300
+ normalized_images: list[list[Image.Image]] = []
301
+ for idx, img_group in enumerate(image_list):
302
+ converted_list: list[Image.Image] = []
303
+ for sub_idx, sub_img in enumerate(img_group):
304
+ if not is_valid_image(sub_img):
305
+ raise ValueError(f"Image at position {idx}[{sub_idx}] is not a valid image.")
306
+ converted_list.append(sub_img.convert("RGB") if hasattr(sub_img, "convert") else sub_img)
307
+ normalized_images.append(converted_list)
308
+
309
+ image_inputs = self.image_processor(images=normalized_images, **output_kwargs["images_kwargs"])
310
+ image_grid_thw = image_inputs["image_grid_thw"]
311
+
312
+ # Process videos (excluding empty placeholders)
313
+ videos_inputs: dict[str, Any] = {}
314
+ video_grid_thw = None
315
+ video_metadata = None
316
+ if video_list is not None:
317
+ videos_inputs = self.video_processor(videos=video_list, **output_kwargs["videos_kwargs"])
318
+ video_grid_thw = videos_inputs["video_grid_thw"]
319
+ if "return_metadata" not in output_kwargs["videos_kwargs"]:
320
+ video_metadata = videos_inputs.pop("video_metadata")
321
+ else:
322
+ video_metadata = videos_inputs["video_metadata"]
323
+
324
+ # Expand prompts to match the number of visual tokens
325
+ text_prompts = prompts.copy()
326
+ if image_grid_thw is not None:
327
+ merge_size = getattr(self.image_processor, "merge_size", None) or getattr(
328
+ self.image_processor, "spatial_merge_size", None
329
+ )
330
+ if merge_size is None:
331
+ raise ValueError("Qwen3VL image processor is missing `merge_size`/`spatial_merge_size`.")
332
+ merge_length = merge_size**2
333
+ index = 0
334
+ for i in range(len(text_prompts)):
335
+ while self.image_token in text_prompts[i]:
336
+ if index >= len(image_grid_thw):
337
+ raise ValueError("Number of image tokens does not match provided images.")
338
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
339
+ text_prompts[i] = text_prompts[i].replace(
340
+ self.image_token, "<|placeholder|>" * num_image_tokens, 1
341
+ )
342
+ index += 1
343
+ text_prompts[i] = text_prompts[i].replace("<|placeholder|>", self.image_token)
344
+
345
+ if video_grid_thw is not None:
346
+ merge_size = getattr(self.video_processor, "merge_size", None)
347
+ if merge_size is None:
348
+ raise ValueError("Qwen3VL video processor is missing `merge_size`.")
349
+ merge_length = merge_size**2
350
+ index = 0
351
+ for i in range(len(text_prompts)):
352
+ while self.video_token in text_prompts[i]:
353
+ if video_metadata is None or index >= len(video_metadata):
354
+ raise ValueError("Video metadata is required to build video prompts.")
355
+ metadata = video_metadata[index]
356
+ if metadata.fps is None:
357
+ logger.warning_once(
358
+ "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could "
359
+ "not be inferred. Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
360
+ )
361
+ metadata.fps = 24 if metadata.fps is None else metadata.fps
362
+
363
+ curr_timestamp = self._calculate_timestamps(
364
+ metadata.frames_indices, metadata.fps, self.video_processor.merge_size
365
+ )
366
+ frame_seqlen = int(video_grid_thw[index][1:].prod().item() // merge_length)
367
+ video_placeholder = ""
368
+ for frame_idx in range(int(video_grid_thw[index][0])):
369
+ curr_time = curr_timestamp[frame_idx]
370
+ video_placeholder += f"<{curr_time:.1f} seconds>"
371
+ video_placeholder += (
372
+ self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
373
+ )
374
+
375
+ if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text_prompts[i]:
376
+ text_prompts[i] = text_prompts[i].replace(
377
+ f"{self.vision_start_token}{self.video_token}{self.vision_end_token}",
378
+ video_placeholder,
379
+ 1,
380
+ )
381
+ else:
382
+ text_prompts[i] = text_prompts[i].replace(self.video_token, video_placeholder, 1)
383
+ index += 1
384
+
385
+ text_prompts[i] = text_prompts[i].replace("<|placeholder|>", self.video_token)
386
+
387
+ text_inputs = self.tokenizer(text_prompts, **output_kwargs["text_kwargs"])
388
+ self._check_special_mm_tokens(text_prompts, text_inputs, modalities=["image", "video"])
389
+
390
+ if return_mm_token_type_ids:
391
+ array_ids = np.array(text_inputs["input_ids"])
392
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
393
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
394
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
395
+
396
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
397
+
398
+ def process_images(
399
+ self,
400
+ images: List[Image.Image],
401
+ max_length: Optional[int] = None,
402
+ ) -> Union[BatchFeature, BatchEncoding]:
403
+ images = [image.convert("RGB") for image in images]
404
+ kwargs = dict(
405
+ images=images,
406
+ padding="longest",
407
+ return_tensors="pt",
408
+ return_mm_token_type_ids=True,
409
+ )
410
+ if max_length is not None:
411
+ kwargs["max_length"] = max_length
412
+ kwargs["truncation"] = True
413
+ return self(**kwargs)
414
+
415
+ def process_queries(self, texts: List[str], max_length: Optional[int] = None) -> Union[BatchFeature, BatchEncoding]:
416
+ kwargs = dict(text=texts, return_tensors="pt", padding="longest")
417
+ if max_length is not None:
418
+ kwargs["max_length"] = max_length
419
+ kwargs["truncation"] = True
420
+ return self(**kwargs)
421
+
422
+
423
+ @staticmethod
424
+ def _split_batch_feature(batch_feature: BatchFeature) -> list[BatchFeature]:
425
+ # Split a batched BatchFeature into a list of per-item BatchFeatures.
426
+ length: Optional[int] = None
427
+ for value in batch_feature.values():
428
+ if hasattr(value, "__len__"):
429
+ try:
430
+ length = len(value)
431
+ except Exception:
432
+ continue
433
+ if length is not None:
434
+ break
435
+
436
+ if length is None:
437
+ return [batch_feature]
438
+
439
+ items: list[BatchFeature] = []
440
+ for idx in range(length):
441
+ data = {}
442
+ for key, value in batch_feature.items():
443
+ try:
444
+ data[key] = value[idx]
445
+ except Exception:
446
+ data[key] = value
447
+ items.append(BatchFeature(data=data))
448
+ return items
449
+
450
+ @staticmethod
451
+ def _merge_batch_features(features: list[BatchFeature]) -> BatchFeature:
452
+ if not features:
453
+ return BatchFeature()
454
+
455
+ all_keys = set()
456
+ for feat in features:
457
+ all_keys.update(feat.keys())
458
+
459
+ merged: dict[str, list[Any]] = {key: [] for key in all_keys}
460
+ for feat in features:
461
+ for key in all_keys:
462
+ merged[key].append(feat.get(key))
463
+
464
+ combined: dict[str, Any] = {}
465
+ for key, values in merged.items():
466
+ # Prefer stacking tensors so callers get batched tensors instead of lists
467
+ if all(isinstance(v, torch.Tensor) for v in values):
468
+ try:
469
+ combined[key] = torch.stack(values)
470
+ continue
471
+ except Exception:
472
+ # Fallback to list if shapes are incompatible for stacking
473
+ pass
474
+ combined[key] = values
475
+
476
+ return BatchFeature(data=combined)
477
+
478
+ def score_retrieval(
479
+ self,
480
+ qs: List[torch.Tensor],
481
+ ps: List[torch.Tensor],
482
+ score_batch_size: int = 128,
483
+ device: Optional[Union[str, torch.device]] = None,
484
+ **kwargs,
485
+ ) -> torch.Tensor:
486
+ return self.score_multi_vector(qs, ps, batch_size=score_batch_size, device=device, **kwargs)
487
+
488
+ @staticmethod
489
+ def score_single_vector(
490
+ qs: Union[torch.Tensor, List[torch.Tensor]],
491
+ ps: Union[torch.Tensor, List[torch.Tensor]],
492
+ device: Optional[Union[str, torch.device]] = None,
493
+ ) -> torch.Tensor:
494
+ """
495
+ Compute the dot product score for the given single-vector query and passage embeddings.
496
+ """
497
+ device = device or get_torch_device("auto")
498
+
499
+ if isinstance(qs, list) and isinstance(ps, list):
500
+ if len(qs) == 0:
501
+ raise ValueError("No queries provided")
502
+ if len(ps) == 0:
503
+ raise ValueError("No passages provided")
504
+
505
+ qs = torch.stack(qs).to(device)
506
+ ps = torch.stack(ps).to(device)
507
+ else:
508
+ qs = qs.to(device)
509
+ ps = ps.to(device)
510
+
511
+ scores = torch.einsum("bd,cd->bc", qs, ps)
512
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
513
+
514
+ scores = scores.to(torch.float32)
515
+ return scores
516
+
517
+ @staticmethod
518
+ def score_multi_vector(
519
+ qs: Union[torch.Tensor, List[torch.Tensor]],
520
+ ps: Union[torch.Tensor, List[torch.Tensor]],
521
+ batch_size: int = 128,
522
+ device: Optional[Union[str, torch.device]] = None,
523
+ ) -> torch.Tensor:
524
+ """
525
+ Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
526
+ query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
527
+ image of a document page.
528
+ Because the embedding tensors are multi-vector and can thus have different shapes, they
529
+ should be fed as:
530
+ (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
531
+ (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
532
+ obtained by padding the list of tensors.
533
+ Args:
534
+ qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
535
+ ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
536
+ batch_size (`int`, *optional*): Batch size for computing scores.
537
+ device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
538
+ provided, uses `get_torch_device("auto")`.
539
+ Returns:
540
+ `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
541
+ tensor is saved on the "cpu" device.
542
+ """
543
+ device = device or get_torch_device("auto")
544
+
545
+ if len(qs) == 0:
546
+ raise ValueError("No queries provided")
547
+ if len(ps) == 0:
548
+ raise ValueError("No passages provided")
549
+
550
+ scores_list: List[torch.Tensor] = []
551
+
552
+ for i in range(0, len(qs), batch_size):
553
+ scores_batch = []
554
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
555
+ device
556
+ )
557
+ for j in range(0, len(ps), batch_size):
558
+ ps_batch = torch.nn.utils.rnn.pad_sequence(
559
+ ps[j : j + batch_size], batch_first=True, padding_value=0
560
+ ).to(device)
561
+ scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
562
+ scores_batch = torch.cat(scores_batch, dim=1).cpu()
563
+ scores_list.append(scores_batch)
564
+
565
+ scores = torch.cat(scores_list, dim=0)
566
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
567
+
568
+ scores = scores.to(torch.float32)
569
+ return scores
570
+
571
+ @staticmethod
572
+ def get_topk_plaid(
573
+ qs: Union[torch.Tensor, List[torch.Tensor]],
574
+ plaid_index: "search.FastPlaid",
575
+ k: int = 10,
576
+ batch_size: int = 128,
577
+ device: Optional[Union[str, torch.device]] = None,
578
+ ) -> torch.Tensor:
579
+ """
580
+ Experimental: Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
581
+ query embeddings (`qs`) and passage embeddings encoded in a plaid index. For ColPali, a passage is the
582
+ image of a document page.
583
+ """
584
+ device = device or get_torch_device("auto")
585
+
586
+ if len(qs) == 0:
587
+ raise ValueError("No queries provided")
588
+
589
+ scores_list: List[torch.Tensor] = []
590
+
591
+ for i in range(0, len(qs), batch_size):
592
+ scores_batch = []
593
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
594
+ device
595
+ )
596
+ scores_batch = plaid_index.search(
597
+ queries_embeddings=qs_batch.to(torch.float32),
598
+ top_k=k,
599
+ )
600
+ scores_list.append(scores_batch)
601
+
602
+ return scores_list
603
+
604
+ @staticmethod
605
+ def create_plaid_index(
606
+ ps: Union[torch.Tensor, List[torch.Tensor]],
607
+ device: Optional[Union[str, torch.device]] = None,
608
+ ) -> torch.Tensor:
609
+ """
610
+ Experimental: Create a FastPlaid index from the given passage embeddings.
611
+ Args:
612
+ ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings. Should be a list of tensors,
613
+ where each tensor is of shape (sequence_length_i, embedding_dim).
614
+ device (`Optional[Union[str, torch.device]]`, *optional*): Device to use for computation. If not
615
+ provided, uses `get_torch_device("auto")`.
616
+ """
617
+ if not importlib.util.find_spec("fast_plaid"):
618
+ raise ImportError("FastPlaid is not installed. Please install it with `pip install fast-plaid`.")
619
+
620
+ fast_plaid_index = search.FastPlaid(index="index")
621
+ device = device or get_torch_device("auto")
622
+ fast_plaid_index.create(documents_embeddings=[d.to(device).to(torch.float32) for d in ps])
623
+ return fast_plaid_index
624
+
625
+ def get_n_patches(
626
+ self,
627
+ image_size: Tuple[int, int],
628
+ spatial_merge_size: int,
629
+ ) -> Tuple[int, int]:
630
+ """
631
+ Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
632
+ size (height, width) with the given patch size.
633
+ The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in
634
+ a `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
635
+ """
636
+ patch_size = self.image_processor.patch_size
637
+
638
+ height_new, width_new = smart_resize(
639
+ width=image_size[0],
640
+ height=image_size[1],
641
+ factor=patch_size * self.image_processor.merge_size,
642
+ min_pixels=self.image_processor.size["shortest_edge"],
643
+ max_pixels=self.image_processor.size["longest_edge"],
644
+ )
645
+
646
+ n_patches_x = width_new // patch_size // spatial_merge_size
647
+ n_patches_y = height_new // patch_size // spatial_merge_size
648
+
649
+ return n_patches_x, n_patches_y
650
+
651
+ def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
652
+ return batch_images.input_ids == self.image_token_id
653
+
654
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
655
+ vision_data = {}
656
+ if image_sizes is not None:
657
+ images_kwargs = ColVec1ProcessorKwargs._defaults.get("images_kwargs", {})
658
+ images_kwargs.update(kwargs)
659
+ merge_size = images_kwargs.get("merge_size", None) or getattr(
660
+ self.image_processor, "merge_size", None
661
+ ) or getattr(self.image_processor, "spatial_merge_size", None)
662
+ if merge_size is None:
663
+ raise ValueError("Qwen3VL image processor is missing `merge_size`/`spatial_merge_size`.")
664
+
665
+ num_image_patches = [
666
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
667
+ for image_size in image_sizes
668
+ ]
669
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
670
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
671
+
672
+ video_sizes = kwargs.pop("video_sizes", None)
673
+ if video_sizes is not None:
674
+ videos_kwargs = ColVec1ProcessorKwargs._defaults.get("videos_kwargs", {})
675
+ videos_kwargs.update(kwargs)
676
+ merge_size = videos_kwargs.get("merge_size", None) or getattr(self.video_processor, "merge_size", None)
677
+ if merge_size is None:
678
+ raise ValueError("Qwen3VL video processor is missing `merge_size`.")
679
+
680
+ num_video_patches = [
681
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) for video_size in video_sizes
682
+ ]
683
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
684
+ vision_data.update({"num_video_tokens": num_video_tokens, "num_video_patches": num_video_patches})
685
+
686
+ return MultiModalData(**vision_data)
687
+
688
+ @property
689
+ def model_input_names(self) -> list[str]:
690
+ return [
691
+ "input_ids",
692
+ "attention_mask",
693
+ "pixel_values",
694
+ "image_grid_thw",
695
+ "pixel_values_videos",
696
+ "video_grid_thw",
697
+ ]
698
+
699
+ @property
700
+ def query_augmentation_token(self) -> str:
701
+ return self.tokenizer.pad_token
702
+
703
+ def get_video_mask(self, batch_videos: BatchFeature) -> torch.Tensor:
704
+ return batch_videos.input_ids == self.video_token_id
705
+
706
+ def _calculate_timestamps(
707
+ self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2
708
+ ) -> list[float]:
709
+ if not isinstance(indices, list):
710
+ indices = indices.tolist()
711
+ if len(indices) % merge_size != 0:
712
+ indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
713
+ timestamps = [idx / video_fps for idx in indices]
714
+ timestamps = [
715
+ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
716
+ ]
717
+ return timestamps
718
+
719
+
720
+ __all__ = ["ColVec1Processor", "ColVec1ProcessorKwargs"]
processor_config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_processor": {
3
+ "data_format": "channels_first",
4
+ "do_convert_rgb": true,
5
+ "do_normalize": true,
6
+ "do_rescale": true,
7
+ "do_resize": true,
8
+ "image_mean": [
9
+ 0.5,
10
+ 0.5,
11
+ 0.5
12
+ ],
13
+ "image_processor_type": "Qwen2VLImageProcessorFast",
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "max_pixels": 1310720,
20
+ "merge_size": 2,
21
+ "patch_size": 16,
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "longest_edge": 1310720,
26
+ "shortest_edge": 65536
27
+ },
28
+ "temporal_patch_size": 2
29
+ },
30
+ "processor_class": "ColVec1Processor",
31
+ "query_prefix": "",
32
+ "video_processor": {
33
+ "data_format": "channels_first",
34
+ "default_to_square": true,
35
+ "do_convert_rgb": true,
36
+ "do_normalize": true,
37
+ "do_rescale": true,
38
+ "do_resize": true,
39
+ "do_sample_frames": true,
40
+ "fps": 2,
41
+ "image_mean": [
42
+ 0.5,
43
+ 0.5,
44
+ 0.5
45
+ ],
46
+ "image_std": [
47
+ 0.5,
48
+ 0.5,
49
+ 0.5
50
+ ],
51
+ "max_frames": 768,
52
+ "max_pixels": 2621440,
53
+ "merge_size": 2,
54
+ "min_frames": 4,
55
+ "patch_size": 16,
56
+ "resample": 3,
57
+ "rescale_factor": 0.00392156862745098,
58
+ "return_metadata": false,
59
+ "size": {
60
+ "longest_edge": 2621440,
61
+ "shortest_edge": 4096
62
+ },
63
+ "temporal_patch_size": 2,
64
+ "video_processor_type": "Qwen3VLVideoProcessor"
65
+ },
66
+ "video_prompt_prefix": "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>Describe the video.",
67
+ "video_prompt_suffix": "<|im_end|><|endoftext|>",
68
+ "visual_prompt_prefix": "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
69
+ "visual_prompt_suffix": "<|im_end|><|endoftext|>",
70
+ "auto_map": {
71
+ "AutoProcessor": "processor.ColVec1Processor"
72
+ }
73
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87a7830d63fcf43bf241c3c5242e96e62dd3fdc29224ca26fed8ea333db72de4
3
+ size 19989343
tokenizer_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "audio_bos_token": "<|audio_start|>",
4
+ "audio_eos_token": "<|audio_end|>",
5
+ "audio_token": "<|audio_pad|>",
6
+ "backend": "tokenizers",
7
+ "bos_token": null,
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "image_token": "<|image_pad|>",
12
+ "is_local": false,
13
+ "model_max_length": 262144,
14
+ "model_specific_special_tokens": {
15
+ "audio_bos_token": "<|audio_start|>",
16
+ "audio_eos_token": "<|audio_end|>",
17
+ "audio_token": "<|audio_pad|>",
18
+ "image_token": "<|image_pad|>",
19
+ "video_token": "<|video_pad|>",
20
+ "vision_bos_token": "<|vision_start|>",
21
+ "vision_eos_token": "<|vision_end|>"
22
+ },
23
+ "pad_token": "<|endoftext|>",
24
+ "pretokenize_regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
25
+ "processor_class": "ColVec1Processor",
26
+ "split_special_tokens": false,
27
+ "tokenizer_class": "TokenizersBackend",
28
+ "unk_token": null,
29
+ "video_token": "<|video_pad|>",
30
+ "vision_bos_token": "<|vision_start|>",
31
+ "vision_eos_token": "<|vision_end|>"
32
+ }