Manli committed on
Commit 4974490
1 Parent(s): e2624bb

initial commit

README.md CHANGED
@@ -1,3 +1,86 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: cc-by-nc-4.0
+ language:
+ - en
+ pipeline_tag: image-text-to-text
+ ---
+
+
+ # Model description
+ We are excited to announce the continuation and rebranding of our **BLIP series** into **XGen-MM**, to better align with Salesforce's unified XGen initiative for large foundation models! This rebranding marks a significant step in our ongoing development of cutting-edge multimodal technologies.
+
+ `XGen-MM` is a series of the latest foundational Large Multimodal Models (LMMs) developed by Salesforce AI Research. This series advances upon the successful designs of the `BLIP` series, incorporating fundamental enhancements that ensure a more robust and superior foundation. These models have been trained at scale on high-quality image caption datasets and interleaved image-text data.
+
+ In the v1.1 (08/2024) release, we present a series of XGen-MM models, including:
+ - Base model `xgen-mm-phi3-mini-base-r-v1.1`
+ - Single-image instruct model `xgen-mm-phi3-mini-instruct-r-v1.1`
+ - Multi-image instruct model `xgen-mm-phi3-mini-instruct-multi-r-v1.1`
+ - DPO instruct model `xgen-mm-phi3-mini-instruct-dpo-r-v1.1`
+
+ In addition to the models, we are also releasing a series of datasets for multimodal pre-training, including:
+ - [MINT-1T: Scaling Open-Source Multimodal Data by 10x: A Multimodal Dataset with One Trillion Tokens](https://arxiv.org/abs/2406.11271)
+ - BLIP3-OCR-200M: a dataset with dense OCR annotations.
+ - BLIP3-GROUNDING-50M: a dataset for enhancing the ability to ground semantic concepts in images.
+ - BLIP3-KALE-300M (stay tuned): a large-scale, curated, high-quality caption dataset.
+
+ # Data
+
+
+ # Results
+
+ ### Base model (without instruction tuning)
+
+ ### Instruct model
+
+ ### DPO model
+
+
+ # How to use
+
+ Please check out our [inference notebook](demo.ipynb) for example code to use our model. We also provide an example script for [batch inference](batch_inference.ipynb).
+
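For orientation, here is a minimal loading sketch. The repository id is an assumption for illustration, as is the assumption that the preprocessor config registers `Blip3ImageProcessor` for `AutoImageProcessor`; the `auto_map` in `config.json` wires up the model and config classes via `trust_remote_code`. Treat this as a sketch and follow the notebook for the actual prompt construction and generation call.

```python
# Minimal loading sketch (assumed repo id; see demo.ipynb for the exact prompt format).
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor

repo_id = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1.1"  # assumed id; adjust to the checkpoint you use

# trust_remote_code=True lets the Auto classes resolve to the custom XGenMM classes
# registered in this repo's config.json auto_map.
model = AutoModelForVision2Seq.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
```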
+ # Reproducibility
+
+ Our evaluation is implemented based on [open-compass/VLMEvalKit](https://github.com/open-compass/VLMEvalKit). We will create a PR to that repo to support XGen-MM evaluation.
+
+
+ # Bias, Risks, Limitations, and Ethical Considerations
+ The main data sources are from the internet, including webpages, image stock sites, and curated datasets released by the research community. We have excluded certain data, such as LAION, due to known CSAM concerns.
+ The model may be subject to bias from the original data sources, as well as bias from LLMs and commercial APIs.
+ We strongly recommend that users assess safety and fairness before applying the model to downstream applications.
+
+
+ # License
+
+ Our code and weights are released under the Creative Commons Attribution-NonCommercial 4.0 [LICENSE](LICENSE.txt). Please fill out the form [here](https://forms.gle/ffPc9oZC2ZGeJ1N68) to inquire about commercial use of the model weights.
+
+ # Code acknowledgement
+ Our training code is based on [OpenFlamingo: An open-source framework for training large multimodal models](https://github.com/mlfoundations/open_flamingo), and part of our data preprocessing code is adapted from [LLaVA](https://github.com/haotian-liu/LLaVA).
+ Our evaluation code is based on [VLMEvalKit: Open-source evaluation toolkit of large vision-language models (LVLMs)](https://github.com/open-compass/VLMEvalKit).
+
+ We thank the authors for their open-source implementations.
+
+
+ # Citation
+ ```
+ @misc{xgen_mm_phi3_mini,
+     title={xgen-mm-phi3-mini-instruct Model Card},
+     url={https://huggingface.co/Salesforce/xgen-mm-phi3-mini-instruct-r-v1},
+     author={Salesforce AI Research},
+     month={May},
+     year={2024}
+ }
+ ```
+
+ # Troubleshooting
+
+ 1. If any packages are missing, please consider installing the following:
+
+ ```
+ pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
+ pip install open_clip_torch==2.24.0
+ pip install einops
+ pip install einops-exts
+ pip install transformers==4.41.1
+ ```
added_tokens.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "<pad>": 32011,
+   "<|assistant|>": 32001,
+   "<|endoftext|>": 32000,
+   "<|end|>": 32007,
+   "<|placeholder1|>": 32002,
+   "<|placeholder2|>": 32003,
+   "<|placeholder3|>": 32004,
+   "<|placeholder4|>": 32005,
+   "<|placeholder5|>": 32008,
+   "<|placeholder6|>": 32009,
+   "<|system|>": 32006,
+   "<|user|>": 32010
+ }
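These entries extend the base Phi-3 vocabulary (ids 0-31999) with the chat-role and padding tokens used by the instruct models, which is where the `initial_tokenizer_len` of 32012 in `config.json` comes from. A small sanity-check sketch, assuming the repository id below and that the tokenizer files ship alongside this file:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Salesforce/xgen-mm-phi3-mini-instruct-r-v1.1")  # assumed repo id

print(len(tok))                               # 32012 = 32,000 base ids + the 12 entries above
print(tok.convert_tokens_to_ids("<pad>"))     # 32011
print(tok.convert_tokens_to_ids("<|user|>"))  # 32010
```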
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "architectures": [
+     "XGenMMModelForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_xgenmm.XGenMMConfig",
+     "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
+   },
+   "model_type": "xgenmm",
+   "text_config": {
+     "initial_tokenizer_len": 32012,
+     "model_type": "phi3",
+     "sliding_window": 2047,
+     "torch_dtype": "bfloat16"
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.1",
+   "vision_encoder_config": {
+     "anyres_patch_sampling": true,
+     "image_aspect_ratio": "anyres",
+     "model_type": "xgenmm_vision_encoder"
+   },
+   "vision_tokenizer_config": {
+     "model_type": "xgenmm_vision_tokenizer"
+   }
+ }
configuration_xgenmm.py ADDED
@@ -0,0 +1,157 @@
+ from transformers import PretrainedConfig
+ from transformers import logging
+ from transformers import CONFIG_MAPPING
+
+ logger = logging.get_logger(__name__)
+
+ class XGenMMVisionEncoderConfig(PretrainedConfig):
+     model_type = "xgenmm_vision_encoder"
+
+     def __init__(self,
+                  model_name: str = 'google/siglip-so400m-patch14-384',
+                  **kwargs):
+         self.model_name = model_name
+         super().__init__(**kwargs)
+
+
+ class XGenMMVisionTokenizerConfig(PretrainedConfig):
+     model_type = "xgenmm_vision_tokenizer"
+
+     def __init__(self,
+                  vis_feature_dim: int = 1152,
+                  lang_embedding_dim: int = 3072,
+                  num_vis_tokens: int = 128,
+                  image_aspect_ratio: str = 'none',
+                  **kwargs):
+         self.vis_feature_dim = vis_feature_dim
+         self.lang_embedding_dim = lang_embedding_dim
+         self.num_vis_tokens = num_vis_tokens
+         self.image_aspect_ratio = image_aspect_ratio
+         super().__init__(**kwargs)
+
+
+ class XGenMMConfig(PretrainedConfig):
+     model_type = "xgenmm"
+
+     def __init__(self,
+                  vision_encoder_config: dict = None,
+                  vision_tokenizer_config: dict = None,
+                  text_config: dict = None,
+                  **kwargs):
+
+         if vision_encoder_config is None:
+             vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
+             logger.info("vision_encoder_config is None. Initializing the XGenMMVisionEncoderConfig with default values.")
+
+         if vision_tokenizer_config is None:
+             vision_tokenizer_config = {}
+             logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
+
+         if text_config is None:
+             text_config = {
+                 'initial_tokenizer_len': 32012,
+                 'pad_token_id': 32011,
+                 'bos_token_id': 1,
+                 'eos_token_id': 32000,
+                 'vocab_size': 32064,
+                 'hidden_size': 3072,
+                 'intermediate_size': 8192,
+                 'num_hidden_layers': 32,
+                 'num_attention_heads': 32,
+                 'num_key_value_heads': 32,
+                 'resid_pdrop': 0.0,
+                 'embd_pdrop': 0.0,
+                 'attention_dropout': 0.0,
+                 'hidden_act': 'silu',
+                 'max_position_embeddings': 4096,
+                 'original_max_position_embeddings': 4096,
+                 'initializer_range': 0.02,
+                 'rms_norm_eps': 1e-05,
+                 'use_cache': True,
+                 'rope_theta': 10000.0,
+                 'rope_scaling': None,
+                 'sliding_window': 2047,
+                 'return_dict': True,
+                 'output_hidden_states': False,
+                 'output_attentions': False,
+                 'torchscript': False,
+                 'torch_dtype': 'bfloat16',
+                 'use_bfloat16': False,
+                 'tf_legacy_loss': False,
+                 'pruned_heads': {},
+                 'tie_word_embeddings': False,
+                 'chunk_size_feed_forward': 0,
+                 'is_encoder_decoder': False,
+                 'is_decoder': False,
+                 'cross_attention_hidden_size': None,
+                 'add_cross_attention': False,
+                 'tie_encoder_decoder': False,
+                 'max_length': 20,
+                 'min_length': 0,
+                 'do_sample': False,
+                 'early_stopping': False,
+                 'num_beams': 1,
+                 'num_beam_groups': 1,
+                 'diversity_penalty': 0.0,
+                 'temperature': 1.0,
+                 'top_k': 50,
+                 'top_p': 1.0,
+                 'typical_p': 1.0,
+                 'repetition_penalty': 1.0,
+                 'length_penalty': 1.0,
+                 'no_repeat_ngram_size': 0,
+                 'encoder_no_repeat_ngram_size': 0,
+                 'bad_words_ids': None,
+                 'num_return_sequences': 1,
+                 'output_scores': False,
+                 'return_dict_in_generate': False,
+                 'forced_bos_token_id': None,
+                 'forced_eos_token_id': None,
+                 'remove_invalid_values': False,
+                 'exponential_decay_length_penalty': None,
+                 'suppress_tokens': None,
+                 'begin_suppress_tokens': None,
+                 'finetuning_task': None,
+                 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
+                 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
+                 'tokenizer_class': None,
+                 'prefix': None,
+                 'bos_token_id': 1,
+                 'pad_token_id': 32000,
+                 'eos_token_id': 32000,
+                 'sep_token_id': None,
+                 'decoder_start_token_id': None,
+                 'task_specific_params': None,
+                 'problem_type': None,
+                 'model_type': 'phi3'
+             }
+             logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
+
+         self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
+
+         self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
+
+         text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
+         self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+         for key in ['initial_tokenizer_len', 'pad_token_id']:
+             if key not in self.text_config.to_dict():
+                 raise ValueError(f"The key `{key}` is missing in the text_config.")
+
+         super().__init__(**kwargs)
+
+     @classmethod
+     def from_vision_encoder_vision_tokenizer_text_configs(
+         cls,
+         vision_encoder_config: XGenMMVisionEncoderConfig,
+         vision_tokenizer_config: XGenMMVisionTokenizerConfig,
+         text_config: PretrainedConfig,
+         **kwargs):
+
+         return cls(
+             vision_encoder_config=vision_encoder_config.to_dict(),
+             vision_tokenizer_config=vision_tokenizer_config.to_dict(),
+             text_config=text_config.to_dict(),
+             **kwargs,
+         )
+
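For reference, a small sketch of composing an `XGenMMConfig` from the sub-configurations defined above, assuming a local clone of this repository on the Python path; the values shown are simply this file's defaults, not a recommendation.

```python
# Assumes this repository is cloned locally so configuration_xgenmm.py is importable.
from transformers import CONFIG_MAPPING
from configuration_xgenmm import (
    XGenMMConfig,
    XGenMMVisionEncoderConfig,
    XGenMMVisionTokenizerConfig,
)

vision_encoder_cfg = XGenMMVisionEncoderConfig()      # defaults to google/siglip-so400m-patch14-384
vision_tokenizer_cfg = XGenMMVisionTokenizerConfig()  # 1152-d vision features -> 3072-d LM embeddings, 128 tokens
text_cfg = CONFIG_MAPPING["phi3"](initial_tokenizer_len=32012, pad_token_id=32011)  # minimal Phi-3 text config

config = XGenMMConfig.from_vision_encoder_vision_tokenizer_text_configs(
    vision_encoder_cfg, vision_tokenizer_cfg, text_cfg
)
print(config.model_type, config.text_config.model_type)  # xgenmm phi3
```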
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 32000,
+   "pad_token_id": 32000,
+   "transformers_version": "4.41.1"
+ }
image_processing_blip_3.py ADDED
@@ -0,0 +1,396 @@
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
+ import torchvision.transforms.functional as F
4
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
5
+ CenterCrop, ColorJitter, Grayscale
6
+ import numbers
7
+ import torch
8
+ import ast
9
+ import math
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
+ from transformers.image_utils import ImageInput
14
+ from transformers.utils import TensorType
15
+
16
+ from utils import expand2square
17
+
18
+
19
+ class Blip3ImageProcessor(BaseImageProcessor):
20
+
21
+ def __init__(
22
+ self,
23
+ do_resize: bool = True,
24
+ resize_mode: str = "squash",
25
+ interpolation_mode: str = "bicubic",
26
+ size: Union[Tuple[int, int], List[int]] = None,
27
+ image_mean: Optional[Union[float, List[float]]] = None,
28
+ image_std: Optional[Union[float, List[float]]] = None,
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(**kwargs)
32
+ self.do_resize = do_resize
33
+ self.resize_mode = resize_mode
34
+ self.interpolation_mode = interpolation_mode
35
+ self.size = size if size is not None else (384, 384)
36
+ self.grids = None
37
+
38
+ self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
39
+ self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
40
+
41
+
42
+ @classmethod
43
+ def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
44
+ interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
45
+ if resize_mode == 'longest':
46
+ transforms = [
47
+ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
48
+ CenterCropOrPad(image_size, fill=fill_color)
49
+ ]
50
+ elif resize_mode == 'squash':
51
+ if isinstance(image_size, int):
52
+ image_size = (image_size, image_size)
53
+ transforms = [
54
+ Resize(image_size, interpolation=interpolation_mode),
55
+ ]
56
+ else:
57
+ assert resize_mode == 'shortest'
58
+ if not isinstance(image_size, (tuple, list)):
59
+ image_size = (image_size, image_size)
60
+ if image_size[0] == image_size[1]:
61
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
62
+ transforms = [
63
+ Resize(image_size[0], interpolation=interpolation_mode)
64
+ ]
65
+ else:
66
+ # resize shortest edge to matching target dim for non-square target
67
+ transforms = [ResizeKeepRatio(image_size)]
68
+ transforms += [CenterCrop(image_size)]
69
+ return transforms
70
+
71
+ @classmethod
72
+ def convert_rgb(cls, image):
73
+ return image.convert("RGB")
74
+
75
+
76
+ def _preprocess(self,
77
+ images: ImageInput
78
+ ) -> torch.Tensor:
79
+ transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
80
+ transforms.extend([
81
+ self.convert_rgb,
82
+ ToTensor(),
83
+ Normalize(mean=self.image_mean, std=self.image_std)
84
+ ])
85
+ composed_transforms = Compose(transforms)
86
+ images_tensor = composed_transforms(images)
87
+ return images_tensor
88
+
89
+ def preprocess(self,
90
+ images: ImageInput,
91
+ return_tensors: Optional[Union[str, TensorType]] = None,
92
+ **kwargs) -> BatchFeature:
93
+ if 'image_aspect_ratio' in kwargs:
94
+ image_aspect_ratio = kwargs['image_aspect_ratio']
95
+ else:
96
+ image_aspect_ratio = 'none'
97
+ new_images = []
98
+ if image_aspect_ratio == 'pad':
99
+ for image in images:
100
+ image = expand2square(image, tuple(int(x*255) for x in self.image_mean))
101
+ image = self._preprocess(image)
102
+ new_images.append(image)
103
+ elif image_aspect_ratio == 'anyres':
104
+ for image in images:
105
+ image = process_anyres_image(image, self._preprocess, self.size,
106
+ self.grids)
107
+ new_images.append(image)
108
+ else:
109
+ for image in images:
110
+ image = self._preprocess(image)
111
+ new_images.append(image)
112
+
113
+ if all(x.shape == new_images[0].shape for x in new_images):
114
+ new_images = torch.stack(new_images, dim=0)
115
+ if image_aspect_ratio == 'anyres':
116
+ new_images = BatchFeature(data={"pixel_values": new_images}, tensor_type=return_tensors)
117
+ else:
118
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
119
+
120
+ return new_images
121
+
122
+
123
+ class ResizeKeepRatio:
124
+ """ Resize and Keep Ratio
125
+
126
+ Copy & paste from `timm`
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size,
132
+ longest=0.,
133
+ interpolation=InterpolationMode.BICUBIC,
134
+ random_scale_prob=0.,
135
+ random_scale_range=(0.85, 1.05),
136
+ random_aspect_prob=0.,
137
+ random_aspect_range=(0.9, 1.11)
138
+ ):
139
+ if isinstance(size, (list, tuple)):
140
+ self.size = tuple(size)
141
+ else:
142
+ self.size = (size, size)
143
+ self.interpolation = interpolation
144
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
145
+ self.random_scale_prob = random_scale_prob
146
+ self.random_scale_range = random_scale_range
147
+ self.random_aspect_prob = random_aspect_prob
148
+ self.random_aspect_range = random_aspect_range
149
+
150
+ @staticmethod
151
+ def get_params(
152
+ img,
153
+ target_size,
154
+ longest,
155
+ random_scale_prob=0.,
156
+ random_scale_range=(0.85, 1.05),
157
+ random_aspect_prob=0.,
158
+ random_aspect_range=(0.9, 1.11)
159
+ ):
160
+ """Get parameters
161
+ """
162
+ source_size = img.size[::-1] # h, w
163
+ h, w = source_size
164
+ target_h, target_w = target_size
165
+ ratio_h = h / target_h
166
+ ratio_w = w / target_w
167
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
168
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
169
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
170
+ ratio_factor = (ratio_factor, ratio_factor)
171
+ else:
172
+ ratio_factor = (1., 1.)
173
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
174
+ aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
175
+ ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
176
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
177
+ return size
178
+
179
+ def __call__(self, img):
180
+ """
181
+ Args:
182
+ img (PIL Image): Image to be cropped and resized.
183
+
184
+ Returns:
185
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
186
+ """
187
+ size = self.get_params(
188
+ img, self.size, self.longest,
189
+ self.random_scale_prob, self.random_scale_range,
190
+ self.random_aspect_prob, self.random_aspect_range
191
+ )
192
+ img = F.resize(img, size, self.interpolation)
193
+ return img
194
+
195
+ def __repr__(self):
196
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
197
+ format_string += f', interpolation={self.interpolation}'
198
+ format_string += f', longest={self.longest:.3f})'
199
+ return format_string
200
+
201
+ def _setup_size(size, error_msg):
202
+ if isinstance(size, numbers.Number):
203
+ return int(size), int(size)
204
+
205
+ if isinstance(size, Sequence) and len(size) == 1:
206
+ return size[0], size[0]
207
+
208
+ if len(size) != 2:
209
+ raise ValueError(error_msg)
210
+
211
+ return size
212
+
213
+ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
214
+ """Center crops and/or pads the given image.
215
+ If the image is torch Tensor, it is expected
216
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
217
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
218
+
219
+ Args:
220
+ img (PIL Image or Tensor): Image to be cropped.
221
+ output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
222
+ it is used for both directions.
223
+ fill (int, Tuple[int]): Padding color
224
+
225
+ Returns:
226
+ PIL Image or Tensor: Cropped image.
227
+ """
228
+ if isinstance(output_size, numbers.Number):
229
+ output_size = (int(output_size), int(output_size))
230
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
231
+ output_size = (output_size[0], output_size[0])
232
+
233
+ _, image_height, image_width = F.get_dimensions(img)
234
+ crop_height, crop_width = output_size
235
+
236
+ if crop_width > image_width or crop_height > image_height:
237
+ padding_ltrb = [
238
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
239
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
240
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
241
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
242
+ ]
243
+ img = F.pad(img, padding_ltrb, fill=fill)
244
+ _, image_height, image_width = F.get_dimensions(img)
245
+ if crop_width == image_width and crop_height == image_height:
246
+ return img
247
+
248
+ crop_top = int(round((image_height - crop_height) / 2.0))
249
+ crop_left = int(round((image_width - crop_width) / 2.0))
250
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
251
+
252
+ class CenterCropOrPad(torch.nn.Module):
253
+ """Crops the given image at the center.
254
+ If the image is torch Tensor, it is expected
255
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
256
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
257
+
258
+ Args:
259
+ size (sequence or int): Desired output size of the crop. If size is an
260
+ int instead of sequence like (h, w), a square crop (size, size) is
261
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
262
+ """
263
+
264
+ def __init__(self, size, fill=0):
265
+ super().__init__()
266
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
267
+ self.fill = fill
268
+
269
+ def forward(self, img):
270
+ """
271
+ Args:
272
+ img (PIL Image or Tensor): Image to be cropped.
273
+
274
+ Returns:
275
+ PIL Image or Tensor: Cropped image.
276
+ """
277
+ return center_crop_or_pad(img, self.size, fill=self.fill)
278
+
279
+ def __repr__(self) -> str:
280
+ return f"{self.__class__.__name__}(size={self.size})"
281
+
282
+ def process_anyres_image(image, processor, processor_size, grid_pinpoints):
283
+ """
284
+ Process an image with variable resolutions.
285
+
286
+ Args:
287
+ image (PIL.Image.Image): The input image to be processed.
288
+ processor: The image processor object.
289
+ processor_size (tuple, list): The size of the image processor.
290
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
291
+
292
+ Returns:
293
+ torch.Tensor: A tensor containing the processed image patches.
294
+ """
295
+ # FIXME: determine grid_pinpoints from image sizes.
296
+ if type(grid_pinpoints) is list:
297
+ possible_resolutions = grid_pinpoints
298
+ else:
299
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
300
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
301
+ image_padded = resize_and_pad_image(image, best_resolution)
302
+
303
+ # processor_size = processor.transforms[0].size
304
+ patches = divide_to_patches(image_padded, processor_size[0])
305
+
306
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
307
+
308
+ image_patches = [image_original_resize] + patches
309
+ image_patches = [processor(image_patch)
310
+ for image_patch in image_patches]
311
+ return torch.stack(image_patches, dim=0)
312
+
313
+
314
+ def select_best_resolution(original_size, possible_resolutions):
315
+ """
316
+ Selects the best resolution from a list of possible resolutions based on the original size.
317
+
318
+ Args:
319
+ original_size (tuple): The original size of the image in the format (width, height).
320
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
321
+
322
+ Returns:
323
+ tuple: The best fit resolution in the format (width, height).
324
+ """
325
+ original_width, original_height = original_size
326
+ best_fit = None
327
+ max_effective_resolution = 0
328
+ min_wasted_resolution = float('inf')
329
+
330
+ for width, height in possible_resolutions:
331
+ scale = min(width / original_width, height / original_height)
332
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
333
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
334
+ wasted_resolution = (width * height) - effective_resolution
335
+
336
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
337
+ max_effective_resolution = effective_resolution
338
+ min_wasted_resolution = wasted_resolution
339
+ best_fit = (width, height)
340
+
341
+ return best_fit
342
+
343
+ def resize_and_pad_image(image, target_resolution):
344
+ """
345
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
346
+
347
+ Args:
348
+ image (PIL.Image.Image): The input image.
349
+ target_resolution (tuple): The target resolution (width, height) of the image.
350
+
351
+ Returns:
352
+ PIL.Image.Image: The resized and padded image.
353
+ """
354
+ original_width, original_height = image.size
355
+ target_width, target_height = target_resolution
356
+
357
+ scale_w = target_width / original_width
358
+ scale_h = target_height / original_height
359
+
360
+ if scale_w < scale_h:
361
+ new_width = target_width
362
+ new_height = min(math.ceil(original_height * scale_w), target_height)
363
+ else:
364
+ new_height = target_height
365
+ new_width = min(math.ceil(original_width * scale_h), target_width)
366
+
367
+ # Resize the image
368
+ resized_image = image.resize((new_width, new_height))
369
+
370
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
371
+ paste_x = (target_width - new_width) // 2
372
+ paste_y = (target_height - new_height) // 2
373
+ new_image.paste(resized_image, (paste_x, paste_y))
374
+
375
+ return new_image
376
+
377
+ def divide_to_patches(image, patch_size):
378
+ """
379
+ Divides an image into patches of a specified size.
380
+
381
+ Args:
382
+ image (PIL.Image.Image): The input image.
383
+ patch_size (int): The size of each patch.
384
+
385
+ Returns:
386
+ list: A list of PIL.Image.Image objects representing the patches.
387
+ """
388
+ patches = []
389
+ width, height = image.size
390
+ for i in range(0, height, patch_size):
391
+ for j in range(0, width, patch_size):
392
+ box = (j, i, j + patch_size, i + patch_size)
393
+ patch = image.crop(box)
394
+ patches.append(patch)
395
+
396
+ return patches
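For reference, a rough sketch of the any-resolution helpers defined above (`select_best_resolution`, `resize_and_pad_image`, `divide_to_patches`) run on a dummy image. The candidate grid and the 384-pixel patch size are illustrative assumptions rather than values read from this repository's preprocessor settings, and the import assumes a local clone of this repository (including its `utils.py`) on the Python path.

```python
from PIL import Image

# Assumes this repository is cloned locally and on sys.path (the module also imports utils.py).
from image_processing_blip_3 import select_best_resolution, resize_and_pad_image, divide_to_patches

image = Image.new("RGB", (1024, 683))  # dummy wide image stand-in
candidate_grids = [(384, 384), (384, 768), (768, 384), (768, 768)]  # illustrative grid, not the shipped one

best = select_best_resolution(image.size, candidate_grids)  # grid that preserves the most detail
padded = resize_and_pad_image(image, best)                  # keep aspect ratio, pad the rest with black
patches = divide_to_patches(padded, 384)                    # 384x384 crops in row-major order

print(best, len(patches))
```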
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45173ee95995173da1d019e4a66c5506e55f218380e7b8e31c78d53dbe75ced6
+ size 4962660968
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea05c36d3ff8eb154a24673a2ad4a068b0f35ef53df674501a160e5aa680ea73
+ size 4983112136
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36b1b5c0001ce8a8bc13930b4d613025ad92a708f126948bff2faecd3acae8da
+ size 4983112168
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25c7d1851170158d175cfa105f704daf18dc3f1e2d997a5bc5e8af883b93d2a0
+ size 2508236156
model.safetensors.index.json ADDED
@@ -0,0 +1,725 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 17437028876
4
+ },
5
+ "weight_map": {
6
+ "vlm.lang_model.lm_head.additional_fc.bias": "model-00004-of-00004.safetensors",
7
+ "vlm.lang_model.lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
8
+ "vlm.lang_model.lm_head.bias": "model-00004-of-00004.safetensors",
9
+ "vlm.lang_model.lm_head.weight": "model-00004-of-00004.safetensors",
10
+ "vlm.lang_model.model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
11
+ "vlm.lang_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
12
+ "vlm.lang_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "vlm.lang_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
14
+ "vlm.lang_model.model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
15
+ "vlm.lang_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
16
+ "vlm.lang_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "vlm.lang_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
18
+ "vlm.lang_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "vlm.lang_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "vlm.lang_model.model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
21
+ "vlm.lang_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
22
+ "vlm.lang_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
23
+ "vlm.lang_model.model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
24
+ "vlm.lang_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
25
+ "vlm.lang_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
26
+ "vlm.lang_model.model.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
27
+ "vlm.lang_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "vlm.lang_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
29
+ "vlm.lang_model.model.layers.10.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
30
+ "vlm.lang_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
31
+ "vlm.lang_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
32
+ "vlm.lang_model.model.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
33
+ "vlm.lang_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
34
+ "vlm.lang_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
35
+ "vlm.lang_model.model.layers.11.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
36
+ "vlm.lang_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "vlm.lang_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "vlm.lang_model.model.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
39
+ "vlm.lang_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
40
+ "vlm.lang_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
41
+ "vlm.lang_model.model.layers.12.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
42
+ "vlm.lang_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
43
+ "vlm.lang_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
44
+ "vlm.lang_model.model.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
45
+ "vlm.lang_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "vlm.lang_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
47
+ "vlm.lang_model.model.layers.13.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
48
+ "vlm.lang_model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
49
+ "vlm.lang_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
50
+ "vlm.lang_model.model.layers.14.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
51
+ "vlm.lang_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
52
+ "vlm.lang_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "vlm.lang_model.model.layers.14.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
54
+ "vlm.lang_model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
55
+ "vlm.lang_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
56
+ "vlm.lang_model.model.layers.15.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
57
+ "vlm.lang_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
58
+ "vlm.lang_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
59
+ "vlm.lang_model.model.layers.15.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
60
+ "vlm.lang_model.model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
61
+ "vlm.lang_model.model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
62
+ "vlm.lang_model.model.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
63
+ "vlm.lang_model.model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
64
+ "vlm.lang_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
65
+ "vlm.lang_model.model.layers.16.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
66
+ "vlm.lang_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
67
+ "vlm.lang_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
68
+ "vlm.lang_model.model.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
69
+ "vlm.lang_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
70
+ "vlm.lang_model.model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
71
+ "vlm.lang_model.model.layers.17.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
72
+ "vlm.lang_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
73
+ "vlm.lang_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
74
+ "vlm.lang_model.model.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
75
+ "vlm.lang_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
76
+ "vlm.lang_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
77
+ "vlm.lang_model.model.layers.18.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
78
+ "vlm.lang_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
79
+ "vlm.lang_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
80
+ "vlm.lang_model.model.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
81
+ "vlm.lang_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
82
+ "vlm.lang_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
83
+ "vlm.lang_model.model.layers.19.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
84
+ "vlm.lang_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
85
+ "vlm.lang_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
86
+ "vlm.lang_model.model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
87
+ "vlm.lang_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
88
+ "vlm.lang_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
89
+ "vlm.lang_model.model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
90
+ "vlm.lang_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
91
+ "vlm.lang_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
92
+ "vlm.lang_model.model.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
93
+ "vlm.lang_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
94
+ "vlm.lang_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
95
+ "vlm.lang_model.model.layers.20.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
96
+ "vlm.lang_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
97
+ "vlm.lang_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
98
+ "vlm.lang_model.model.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
99
+ "vlm.lang_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
100
+ "vlm.lang_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
101
+ "vlm.lang_model.model.layers.21.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
102
+ "vlm.lang_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
103
+ "vlm.lang_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
104
+ "vlm.lang_model.model.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
105
+ "vlm.lang_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
106
+ "vlm.lang_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
107
+ "vlm.lang_model.model.layers.22.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
108
+ "vlm.lang_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
109
+ "vlm.lang_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
110
+ "vlm.lang_model.model.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
111
+ "vlm.lang_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
112
+ "vlm.lang_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
113
+ "vlm.lang_model.model.layers.23.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
114
+ "vlm.lang_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
115
+ "vlm.lang_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
116
+ "vlm.lang_model.model.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
117
+ "vlm.lang_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
118
+ "vlm.lang_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
119
+ "vlm.lang_model.model.layers.24.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
120
+ "vlm.lang_model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
121
+ "vlm.lang_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
122
+ "vlm.lang_model.model.layers.25.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
123
+ "vlm.lang_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
124
+ "vlm.lang_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
125
+ "vlm.lang_model.model.layers.25.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
126
+ "vlm.lang_model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
127
+ "vlm.lang_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
128
+ "vlm.lang_model.model.layers.26.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
129
+ "vlm.lang_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
130
+ "vlm.lang_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
131
+ "vlm.lang_model.model.layers.26.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
132
+ "vlm.lang_model.model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
133
+ "vlm.lang_model.model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
134
+ "vlm.lang_model.model.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
135
+ "vlm.lang_model.model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
136
+ "vlm.lang_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
137
+ "vlm.lang_model.model.layers.27.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
138
+ "vlm.lang_model.model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
139
+ "vlm.lang_model.model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
140
+ "vlm.lang_model.model.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
141
+ "vlm.lang_model.model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
142
+ "vlm.lang_model.model.layers.28.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
143
+ "vlm.lang_model.model.layers.28.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
144
+ "vlm.lang_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
145
+ "vlm.lang_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
146
+ "vlm.lang_model.model.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
147
+ "vlm.lang_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
148
+ "vlm.lang_model.model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
149
+ "vlm.lang_model.model.layers.29.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
150
+ "vlm.lang_model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
151
+ "vlm.lang_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
152
+ "vlm.lang_model.model.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
153
+ "vlm.lang_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
154
+ "vlm.lang_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
155
+ "vlm.lang_model.model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
156
+ "vlm.lang_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
157
+ "vlm.lang_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
158
+ "vlm.lang_model.model.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
159
+ "vlm.lang_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
160
+ "vlm.lang_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
161
+ "vlm.lang_model.model.layers.30.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
162
+ "vlm.lang_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
163
+ "vlm.lang_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
164
+ "vlm.lang_model.model.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
165
+ "vlm.lang_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
166
+ "vlm.lang_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
167
+ "vlm.lang_model.model.layers.31.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
168
+ "vlm.lang_model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
169
+ "vlm.lang_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
170
+ "vlm.lang_model.model.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
171
+ "vlm.lang_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
172
+ "vlm.lang_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
173
+ "vlm.lang_model.model.layers.4.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
174
+ "vlm.lang_model.model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
175
+ "vlm.lang_model.model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
176
+ "vlm.lang_model.model.layers.5.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
177
+ "vlm.lang_model.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
178
+ "vlm.lang_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
179
+ "vlm.lang_model.model.layers.5.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
180
+ "vlm.lang_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
181
+ "vlm.lang_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
182
+ "vlm.lang_model.model.layers.6.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
183
+ "vlm.lang_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
184
+ "vlm.lang_model.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
185
+ "vlm.lang_model.model.layers.6.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
186
+ "vlm.lang_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
187
+ "vlm.lang_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
188
+ "vlm.lang_model.model.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
189
+ "vlm.lang_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
190
+ "vlm.lang_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
191
+ "vlm.lang_model.model.layers.7.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
192
+ "vlm.lang_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
193
+ "vlm.lang_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
194
+ "vlm.lang_model.model.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
195
+ "vlm.lang_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
196
+ "vlm.lang_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
197
+ "vlm.lang_model.model.layers.8.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
198
+ "vlm.lang_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
199
+ "vlm.lang_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
200
+ "vlm.lang_model.model.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
201
+ "vlm.lang_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
202
+ "vlm.lang_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
203
+ "vlm.lang_model.model.layers.9.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
204
+ "vlm.lang_model.model.norm.weight": "model-00004-of-00004.safetensors",
205
+ "vlm.vision_encoder.embeddings.patch_embedding.bias": "model-00001-of-00004.safetensors",
206
+ "vlm.vision_encoder.embeddings.patch_embedding.weight": "model-00001-of-00004.safetensors",
207
+ "vlm.vision_encoder.embeddings.position_embedding.weight": "model-00001-of-00004.safetensors",
208
+ "vlm.vision_encoder.encoder.layers.0.layer_norm1.bias": "model-00001-of-00004.safetensors",
209
+ "vlm.vision_encoder.encoder.layers.0.layer_norm1.weight": "model-00001-of-00004.safetensors",
210
+ "vlm.vision_encoder.encoder.layers.0.layer_norm2.bias": "model-00001-of-00004.safetensors",
211
+ "vlm.vision_encoder.encoder.layers.0.layer_norm2.weight": "model-00001-of-00004.safetensors",
212
+ "vlm.vision_encoder.encoder.layers.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
213
+ "vlm.vision_encoder.encoder.layers.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
214
+ "vlm.vision_encoder.encoder.layers.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
215
+ "vlm.vision_encoder.encoder.layers.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
216
+ "vlm.vision_encoder.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
217
+ "vlm.vision_encoder.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
218
+ "vlm.vision_encoder.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
219
+ "vlm.vision_encoder.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
220
+ "vlm.vision_encoder.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
221
+ "vlm.vision_encoder.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
222
+ "vlm.vision_encoder.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
223
+ "vlm.vision_encoder.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
224
+ "vlm.vision_encoder.encoder.layers.1.layer_norm1.bias": "model-00001-of-00004.safetensors",
225
+ "vlm.vision_encoder.encoder.layers.1.layer_norm1.weight": "model-00001-of-00004.safetensors",
226
+ "vlm.vision_encoder.encoder.layers.1.layer_norm2.bias": "model-00001-of-00004.safetensors",
227
+ "vlm.vision_encoder.encoder.layers.1.layer_norm2.weight": "model-00001-of-00004.safetensors",
228
+ "vlm.vision_encoder.encoder.layers.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
229
+ "vlm.vision_encoder.encoder.layers.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
230
+ "vlm.vision_encoder.encoder.layers.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
231
+ "vlm.vision_encoder.encoder.layers.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
232
+ "vlm.vision_encoder.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
233
+ "vlm.vision_encoder.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
234
+ "vlm.vision_encoder.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
235
+ "vlm.vision_encoder.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
236
+ "vlm.vision_encoder.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
237
+ "vlm.vision_encoder.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
238
+ "vlm.vision_encoder.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
239
+ "vlm.vision_encoder.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
240
+ "vlm.vision_encoder.encoder.layers.10.layer_norm1.bias": "model-00001-of-00004.safetensors",
241
+ "vlm.vision_encoder.encoder.layers.10.layer_norm1.weight": "model-00001-of-00004.safetensors",
242
+ "vlm.vision_encoder.encoder.layers.10.layer_norm2.bias": "model-00001-of-00004.safetensors",
243
+ "vlm.vision_encoder.encoder.layers.10.layer_norm2.weight": "model-00001-of-00004.safetensors",
244
+ "vlm.vision_encoder.encoder.layers.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
245
+ "vlm.vision_encoder.encoder.layers.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
246
+ "vlm.vision_encoder.encoder.layers.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
247
+ "vlm.vision_encoder.encoder.layers.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
248
+ "vlm.vision_encoder.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
249
+ "vlm.vision_encoder.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
250
+ "vlm.vision_encoder.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
251
+ "vlm.vision_encoder.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
252
+ "vlm.vision_encoder.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
253
+ "vlm.vision_encoder.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
254
+ "vlm.vision_encoder.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
255
+ "vlm.vision_encoder.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
256
+ "vlm.vision_encoder.encoder.layers.11.layer_norm1.bias": "model-00001-of-00004.safetensors",
257
+ "vlm.vision_encoder.encoder.layers.11.layer_norm1.weight": "model-00001-of-00004.safetensors",
258
+ "vlm.vision_encoder.encoder.layers.11.layer_norm2.bias": "model-00001-of-00004.safetensors",
259
+ "vlm.vision_encoder.encoder.layers.11.layer_norm2.weight": "model-00001-of-00004.safetensors",
260
+ "vlm.vision_encoder.encoder.layers.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
261
+ "vlm.vision_encoder.encoder.layers.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
262
+ "vlm.vision_encoder.encoder.layers.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
263
+ "vlm.vision_encoder.encoder.layers.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
264
+ "vlm.vision_encoder.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
265
+ "vlm.vision_encoder.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
266
+ "vlm.vision_encoder.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
267
+ "vlm.vision_encoder.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
268
+ "vlm.vision_encoder.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
269
+ "vlm.vision_encoder.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
270
+ "vlm.vision_encoder.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
271
+ "vlm.vision_encoder.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
272
+ "vlm.vision_encoder.encoder.layers.12.layer_norm1.bias": "model-00001-of-00004.safetensors",
273
+ "vlm.vision_encoder.encoder.layers.12.layer_norm1.weight": "model-00001-of-00004.safetensors",
274
+ "vlm.vision_encoder.encoder.layers.12.layer_norm2.bias": "model-00001-of-00004.safetensors",
275
+ "vlm.vision_encoder.encoder.layers.12.layer_norm2.weight": "model-00001-of-00004.safetensors",
276
+ "vlm.vision_encoder.encoder.layers.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
277
+ "vlm.vision_encoder.encoder.layers.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
278
+ "vlm.vision_encoder.encoder.layers.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
279
+ "vlm.vision_encoder.encoder.layers.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
280
+ "vlm.vision_encoder.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
281
+ "vlm.vision_encoder.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
282
+ "vlm.vision_encoder.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
283
+ "vlm.vision_encoder.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
284
+ "vlm.vision_encoder.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
285
+ "vlm.vision_encoder.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
286
+ "vlm.vision_encoder.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
287
+ "vlm.vision_encoder.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
288
+ "vlm.vision_encoder.encoder.layers.13.layer_norm1.bias": "model-00001-of-00004.safetensors",
289
+ "vlm.vision_encoder.encoder.layers.13.layer_norm1.weight": "model-00001-of-00004.safetensors",
290
+ "vlm.vision_encoder.encoder.layers.13.layer_norm2.bias": "model-00001-of-00004.safetensors",
291
+ "vlm.vision_encoder.encoder.layers.13.layer_norm2.weight": "model-00001-of-00004.safetensors",
292
+ "vlm.vision_encoder.encoder.layers.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
293
+ "vlm.vision_encoder.encoder.layers.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
294
+ "vlm.vision_encoder.encoder.layers.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
295
+ "vlm.vision_encoder.encoder.layers.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
296
+ "vlm.vision_encoder.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
297
+ "vlm.vision_encoder.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
298
+ "vlm.vision_encoder.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
299
+ "vlm.vision_encoder.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
300
+ "vlm.vision_encoder.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
301
+ "vlm.vision_encoder.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
302
+ "vlm.vision_encoder.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
303
+ "vlm.vision_encoder.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
304
+ "vlm.vision_encoder.encoder.layers.14.layer_norm1.bias": "model-00001-of-00004.safetensors",
305
+ "vlm.vision_encoder.encoder.layers.14.layer_norm1.weight": "model-00001-of-00004.safetensors",
306
+ "vlm.vision_encoder.encoder.layers.14.layer_norm2.bias": "model-00001-of-00004.safetensors",
307
+ "vlm.vision_encoder.encoder.layers.14.layer_norm2.weight": "model-00001-of-00004.safetensors",
308
+ "vlm.vision_encoder.encoder.layers.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
309
+ "vlm.vision_encoder.encoder.layers.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
310
+ "vlm.vision_encoder.encoder.layers.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
311
+ "vlm.vision_encoder.encoder.layers.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
312
+ "vlm.vision_encoder.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
313
+ "vlm.vision_encoder.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
314
+ "vlm.vision_encoder.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
315
+ "vlm.vision_encoder.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
316
+ "vlm.vision_encoder.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
317
+ "vlm.vision_encoder.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
318
+ "vlm.vision_encoder.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
319
+ "vlm.vision_encoder.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
320
+ "vlm.vision_encoder.encoder.layers.15.layer_norm1.bias": "model-00001-of-00004.safetensors",
321
+ "vlm.vision_encoder.encoder.layers.15.layer_norm1.weight": "model-00001-of-00004.safetensors",
322
+ "vlm.vision_encoder.encoder.layers.15.layer_norm2.bias": "model-00001-of-00004.safetensors",
323
+ "vlm.vision_encoder.encoder.layers.15.layer_norm2.weight": "model-00001-of-00004.safetensors",
324
+ "vlm.vision_encoder.encoder.layers.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
325
+ "vlm.vision_encoder.encoder.layers.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
326
+ "vlm.vision_encoder.encoder.layers.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
327
+ "vlm.vision_encoder.encoder.layers.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
328
+ "vlm.vision_encoder.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
329
+ "vlm.vision_encoder.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
330
+ "vlm.vision_encoder.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
331
+ "vlm.vision_encoder.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
332
+ "vlm.vision_encoder.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
333
+ "vlm.vision_encoder.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
334
+ "vlm.vision_encoder.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
335
+ "vlm.vision_encoder.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
336
+ "vlm.vision_encoder.encoder.layers.16.layer_norm1.bias": "model-00001-of-00004.safetensors",
337
+ "vlm.vision_encoder.encoder.layers.16.layer_norm1.weight": "model-00001-of-00004.safetensors",
338
+ "vlm.vision_encoder.encoder.layers.16.layer_norm2.bias": "model-00001-of-00004.safetensors",
339
+ "vlm.vision_encoder.encoder.layers.16.layer_norm2.weight": "model-00001-of-00004.safetensors",
340
+ "vlm.vision_encoder.encoder.layers.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
341
+ "vlm.vision_encoder.encoder.layers.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
342
+ "vlm.vision_encoder.encoder.layers.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
343
+ "vlm.vision_encoder.encoder.layers.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
344
+ "vlm.vision_encoder.encoder.layers.16.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
345
+ "vlm.vision_encoder.encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
346
+ "vlm.vision_encoder.encoder.layers.16.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
347
+ "vlm.vision_encoder.encoder.layers.16.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
348
+ "vlm.vision_encoder.encoder.layers.16.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
349
+ "vlm.vision_encoder.encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
350
+ "vlm.vision_encoder.encoder.layers.16.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
351
+ "vlm.vision_encoder.encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
352
+ "vlm.vision_encoder.encoder.layers.17.layer_norm1.bias": "model-00001-of-00004.safetensors",
353
+ "vlm.vision_encoder.encoder.layers.17.layer_norm1.weight": "model-00001-of-00004.safetensors",
354
+ "vlm.vision_encoder.encoder.layers.17.layer_norm2.bias": "model-00001-of-00004.safetensors",
355
+ "vlm.vision_encoder.encoder.layers.17.layer_norm2.weight": "model-00001-of-00004.safetensors",
356
+ "vlm.vision_encoder.encoder.layers.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
357
+ "vlm.vision_encoder.encoder.layers.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
358
+ "vlm.vision_encoder.encoder.layers.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
359
+ "vlm.vision_encoder.encoder.layers.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
360
+ "vlm.vision_encoder.encoder.layers.17.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
361
+ "vlm.vision_encoder.encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
362
+ "vlm.vision_encoder.encoder.layers.17.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
363
+ "vlm.vision_encoder.encoder.layers.17.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
364
+ "vlm.vision_encoder.encoder.layers.17.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
365
+ "vlm.vision_encoder.encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
366
+ "vlm.vision_encoder.encoder.layers.17.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
367
+ "vlm.vision_encoder.encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
368
+ "vlm.vision_encoder.encoder.layers.18.layer_norm1.bias": "model-00001-of-00004.safetensors",
369
+ "vlm.vision_encoder.encoder.layers.18.layer_norm1.weight": "model-00001-of-00004.safetensors",
370
+ "vlm.vision_encoder.encoder.layers.18.layer_norm2.bias": "model-00001-of-00004.safetensors",
371
+ "vlm.vision_encoder.encoder.layers.18.layer_norm2.weight": "model-00001-of-00004.safetensors",
372
+ "vlm.vision_encoder.encoder.layers.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
373
+ "vlm.vision_encoder.encoder.layers.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
374
+ "vlm.vision_encoder.encoder.layers.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
375
+ "vlm.vision_encoder.encoder.layers.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
376
+ "vlm.vision_encoder.encoder.layers.18.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
377
+ "vlm.vision_encoder.encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
378
+ "vlm.vision_encoder.encoder.layers.18.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
379
+ "vlm.vision_encoder.encoder.layers.18.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
380
+ "vlm.vision_encoder.encoder.layers.18.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
381
+ "vlm.vision_encoder.encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
382
+ "vlm.vision_encoder.encoder.layers.18.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
383
+ "vlm.vision_encoder.encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
384
+ "vlm.vision_encoder.encoder.layers.19.layer_norm1.bias": "model-00001-of-00004.safetensors",
385
+ "vlm.vision_encoder.encoder.layers.19.layer_norm1.weight": "model-00001-of-00004.safetensors",
386
+ "vlm.vision_encoder.encoder.layers.19.layer_norm2.bias": "model-00001-of-00004.safetensors",
387
+ "vlm.vision_encoder.encoder.layers.19.layer_norm2.weight": "model-00001-of-00004.safetensors",
388
+ "vlm.vision_encoder.encoder.layers.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
389
+ "vlm.vision_encoder.encoder.layers.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
390
+ "vlm.vision_encoder.encoder.layers.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
391
+ "vlm.vision_encoder.encoder.layers.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
392
+ "vlm.vision_encoder.encoder.layers.19.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
393
+ "vlm.vision_encoder.encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
394
+ "vlm.vision_encoder.encoder.layers.19.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
395
+ "vlm.vision_encoder.encoder.layers.19.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
396
+ "vlm.vision_encoder.encoder.layers.19.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
397
+ "vlm.vision_encoder.encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
398
+ "vlm.vision_encoder.encoder.layers.19.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
399
+ "vlm.vision_encoder.encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
400
+ "vlm.vision_encoder.encoder.layers.2.layer_norm1.bias": "model-00001-of-00004.safetensors",
401
+ "vlm.vision_encoder.encoder.layers.2.layer_norm1.weight": "model-00001-of-00004.safetensors",
402
+ "vlm.vision_encoder.encoder.layers.2.layer_norm2.bias": "model-00001-of-00004.safetensors",
403
+ "vlm.vision_encoder.encoder.layers.2.layer_norm2.weight": "model-00001-of-00004.safetensors",
404
+ "vlm.vision_encoder.encoder.layers.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
405
+ "vlm.vision_encoder.encoder.layers.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
406
+ "vlm.vision_encoder.encoder.layers.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
407
+ "vlm.vision_encoder.encoder.layers.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
408
+ "vlm.vision_encoder.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
409
+ "vlm.vision_encoder.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
410
+ "vlm.vision_encoder.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
411
+ "vlm.vision_encoder.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
412
+ "vlm.vision_encoder.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
413
+ "vlm.vision_encoder.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
414
+ "vlm.vision_encoder.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
415
+ "vlm.vision_encoder.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
416
+ "vlm.vision_encoder.encoder.layers.20.layer_norm1.bias": "model-00001-of-00004.safetensors",
417
+ "vlm.vision_encoder.encoder.layers.20.layer_norm1.weight": "model-00001-of-00004.safetensors",
418
+ "vlm.vision_encoder.encoder.layers.20.layer_norm2.bias": "model-00001-of-00004.safetensors",
419
+ "vlm.vision_encoder.encoder.layers.20.layer_norm2.weight": "model-00001-of-00004.safetensors",
420
+ "vlm.vision_encoder.encoder.layers.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
421
+ "vlm.vision_encoder.encoder.layers.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
422
+ "vlm.vision_encoder.encoder.layers.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
423
+ "vlm.vision_encoder.encoder.layers.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
424
+ "vlm.vision_encoder.encoder.layers.20.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
425
+ "vlm.vision_encoder.encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
426
+ "vlm.vision_encoder.encoder.layers.20.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
427
+ "vlm.vision_encoder.encoder.layers.20.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
428
+ "vlm.vision_encoder.encoder.layers.20.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
429
+ "vlm.vision_encoder.encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
430
+ "vlm.vision_encoder.encoder.layers.20.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
431
+ "vlm.vision_encoder.encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
432
+ "vlm.vision_encoder.encoder.layers.21.layer_norm1.bias": "model-00001-of-00004.safetensors",
433
+ "vlm.vision_encoder.encoder.layers.21.layer_norm1.weight": "model-00001-of-00004.safetensors",
434
+ "vlm.vision_encoder.encoder.layers.21.layer_norm2.bias": "model-00001-of-00004.safetensors",
435
+ "vlm.vision_encoder.encoder.layers.21.layer_norm2.weight": "model-00001-of-00004.safetensors",
436
+ "vlm.vision_encoder.encoder.layers.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
437
+ "vlm.vision_encoder.encoder.layers.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
438
+ "vlm.vision_encoder.encoder.layers.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
439
+ "vlm.vision_encoder.encoder.layers.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
440
+ "vlm.vision_encoder.encoder.layers.21.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
441
+ "vlm.vision_encoder.encoder.layers.21.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
442
+ "vlm.vision_encoder.encoder.layers.21.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
443
+ "vlm.vision_encoder.encoder.layers.21.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
444
+ "vlm.vision_encoder.encoder.layers.21.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
445
+ "vlm.vision_encoder.encoder.layers.21.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
446
+ "vlm.vision_encoder.encoder.layers.21.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
447
+ "vlm.vision_encoder.encoder.layers.21.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
448
+ "vlm.vision_encoder.encoder.layers.22.layer_norm1.bias": "model-00001-of-00004.safetensors",
449
+ "vlm.vision_encoder.encoder.layers.22.layer_norm1.weight": "model-00001-of-00004.safetensors",
450
+ "vlm.vision_encoder.encoder.layers.22.layer_norm2.bias": "model-00001-of-00004.safetensors",
451
+ "vlm.vision_encoder.encoder.layers.22.layer_norm2.weight": "model-00001-of-00004.safetensors",
452
+ "vlm.vision_encoder.encoder.layers.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
453
+ "vlm.vision_encoder.encoder.layers.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
454
+ "vlm.vision_encoder.encoder.layers.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
455
+ "vlm.vision_encoder.encoder.layers.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
456
+ "vlm.vision_encoder.encoder.layers.22.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
457
+ "vlm.vision_encoder.encoder.layers.22.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
458
+ "vlm.vision_encoder.encoder.layers.22.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
459
+ "vlm.vision_encoder.encoder.layers.22.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
460
+ "vlm.vision_encoder.encoder.layers.22.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
461
+ "vlm.vision_encoder.encoder.layers.22.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
462
+ "vlm.vision_encoder.encoder.layers.22.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
463
+ "vlm.vision_encoder.encoder.layers.22.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
464
+ "vlm.vision_encoder.encoder.layers.23.layer_norm1.bias": "model-00001-of-00004.safetensors",
465
+ "vlm.vision_encoder.encoder.layers.23.layer_norm1.weight": "model-00001-of-00004.safetensors",
466
+ "vlm.vision_encoder.encoder.layers.23.layer_norm2.bias": "model-00001-of-00004.safetensors",
467
+ "vlm.vision_encoder.encoder.layers.23.layer_norm2.weight": "model-00001-of-00004.safetensors",
468
+ "vlm.vision_encoder.encoder.layers.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
469
+ "vlm.vision_encoder.encoder.layers.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
470
+ "vlm.vision_encoder.encoder.layers.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
471
+ "vlm.vision_encoder.encoder.layers.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
472
+ "vlm.vision_encoder.encoder.layers.23.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
473
+ "vlm.vision_encoder.encoder.layers.23.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
474
+ "vlm.vision_encoder.encoder.layers.23.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
475
+ "vlm.vision_encoder.encoder.layers.23.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
476
+ "vlm.vision_encoder.encoder.layers.23.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
477
+ "vlm.vision_encoder.encoder.layers.23.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
478
+ "vlm.vision_encoder.encoder.layers.23.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
479
+ "vlm.vision_encoder.encoder.layers.23.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
480
+ "vlm.vision_encoder.encoder.layers.24.layer_norm1.bias": "model-00001-of-00004.safetensors",
481
+ "vlm.vision_encoder.encoder.layers.24.layer_norm1.weight": "model-00001-of-00004.safetensors",
482
+ "vlm.vision_encoder.encoder.layers.24.layer_norm2.bias": "model-00001-of-00004.safetensors",
483
+ "vlm.vision_encoder.encoder.layers.24.layer_norm2.weight": "model-00001-of-00004.safetensors",
484
+ "vlm.vision_encoder.encoder.layers.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
485
+ "vlm.vision_encoder.encoder.layers.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
486
+ "vlm.vision_encoder.encoder.layers.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
487
+ "vlm.vision_encoder.encoder.layers.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
488
+ "vlm.vision_encoder.encoder.layers.24.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
489
+ "vlm.vision_encoder.encoder.layers.24.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
490
+ "vlm.vision_encoder.encoder.layers.24.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
491
+ "vlm.vision_encoder.encoder.layers.24.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
492
+ "vlm.vision_encoder.encoder.layers.24.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
493
+ "vlm.vision_encoder.encoder.layers.24.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
494
+ "vlm.vision_encoder.encoder.layers.24.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
495
+ "vlm.vision_encoder.encoder.layers.24.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
496
+ "vlm.vision_encoder.encoder.layers.25.layer_norm1.bias": "model-00001-of-00004.safetensors",
497
+ "vlm.vision_encoder.encoder.layers.25.layer_norm1.weight": "model-00001-of-00004.safetensors",
498
+ "vlm.vision_encoder.encoder.layers.25.layer_norm2.bias": "model-00001-of-00004.safetensors",
499
+ "vlm.vision_encoder.encoder.layers.25.layer_norm2.weight": "model-00001-of-00004.safetensors",
500
+ "vlm.vision_encoder.encoder.layers.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
501
+ "vlm.vision_encoder.encoder.layers.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
502
+ "vlm.vision_encoder.encoder.layers.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
503
+ "vlm.vision_encoder.encoder.layers.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
504
+ "vlm.vision_encoder.encoder.layers.25.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
505
+ "vlm.vision_encoder.encoder.layers.25.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
506
+ "vlm.vision_encoder.encoder.layers.25.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
507
+ "vlm.vision_encoder.encoder.layers.25.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
508
+ "vlm.vision_encoder.encoder.layers.25.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
509
+ "vlm.vision_encoder.encoder.layers.25.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
510
+ "vlm.vision_encoder.encoder.layers.25.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
511
+ "vlm.vision_encoder.encoder.layers.25.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
512
+ "vlm.vision_encoder.encoder.layers.26.layer_norm1.bias": "model-00001-of-00004.safetensors",
513
+ "vlm.vision_encoder.encoder.layers.26.layer_norm1.weight": "model-00001-of-00004.safetensors",
514
+ "vlm.vision_encoder.encoder.layers.26.layer_norm2.bias": "model-00001-of-00004.safetensors",
515
+ "vlm.vision_encoder.encoder.layers.26.layer_norm2.weight": "model-00001-of-00004.safetensors",
516
+ "vlm.vision_encoder.encoder.layers.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
517
+ "vlm.vision_encoder.encoder.layers.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
518
+ "vlm.vision_encoder.encoder.layers.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
519
+ "vlm.vision_encoder.encoder.layers.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
520
+ "vlm.vision_encoder.encoder.layers.26.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
521
+ "vlm.vision_encoder.encoder.layers.26.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
522
+ "vlm.vision_encoder.encoder.layers.26.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
523
+ "vlm.vision_encoder.encoder.layers.26.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
524
+ "vlm.vision_encoder.encoder.layers.26.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
525
+ "vlm.vision_encoder.encoder.layers.26.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
526
+ "vlm.vision_encoder.encoder.layers.26.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
527
+ "vlm.vision_encoder.encoder.layers.26.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
528
+ "vlm.vision_encoder.encoder.layers.3.layer_norm1.bias": "model-00001-of-00004.safetensors",
529
+ "vlm.vision_encoder.encoder.layers.3.layer_norm1.weight": "model-00001-of-00004.safetensors",
530
+ "vlm.vision_encoder.encoder.layers.3.layer_norm2.bias": "model-00001-of-00004.safetensors",
531
+ "vlm.vision_encoder.encoder.layers.3.layer_norm2.weight": "model-00001-of-00004.safetensors",
532
+ "vlm.vision_encoder.encoder.layers.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
533
+ "vlm.vision_encoder.encoder.layers.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
534
+ "vlm.vision_encoder.encoder.layers.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
535
+ "vlm.vision_encoder.encoder.layers.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
536
+ "vlm.vision_encoder.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
537
+ "vlm.vision_encoder.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
538
+ "vlm.vision_encoder.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
539
+ "vlm.vision_encoder.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
540
+ "vlm.vision_encoder.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
541
+ "vlm.vision_encoder.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
542
+ "vlm.vision_encoder.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
543
+ "vlm.vision_encoder.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
544
+ "vlm.vision_encoder.encoder.layers.4.layer_norm1.bias": "model-00001-of-00004.safetensors",
545
+ "vlm.vision_encoder.encoder.layers.4.layer_norm1.weight": "model-00001-of-00004.safetensors",
546
+ "vlm.vision_encoder.encoder.layers.4.layer_norm2.bias": "model-00001-of-00004.safetensors",
547
+ "vlm.vision_encoder.encoder.layers.4.layer_norm2.weight": "model-00001-of-00004.safetensors",
548
+ "vlm.vision_encoder.encoder.layers.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
549
+ "vlm.vision_encoder.encoder.layers.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
550
+ "vlm.vision_encoder.encoder.layers.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
551
+ "vlm.vision_encoder.encoder.layers.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
552
+ "vlm.vision_encoder.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
553
+ "vlm.vision_encoder.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
554
+ "vlm.vision_encoder.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
555
+ "vlm.vision_encoder.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
556
+ "vlm.vision_encoder.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
557
+ "vlm.vision_encoder.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
558
+ "vlm.vision_encoder.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
559
+ "vlm.vision_encoder.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
560
+ "vlm.vision_encoder.encoder.layers.5.layer_norm1.bias": "model-00001-of-00004.safetensors",
561
+ "vlm.vision_encoder.encoder.layers.5.layer_norm1.weight": "model-00001-of-00004.safetensors",
562
+ "vlm.vision_encoder.encoder.layers.5.layer_norm2.bias": "model-00001-of-00004.safetensors",
563
+ "vlm.vision_encoder.encoder.layers.5.layer_norm2.weight": "model-00001-of-00004.safetensors",
564
+ "vlm.vision_encoder.encoder.layers.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
565
+ "vlm.vision_encoder.encoder.layers.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
566
+ "vlm.vision_encoder.encoder.layers.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
567
+ "vlm.vision_encoder.encoder.layers.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
568
+ "vlm.vision_encoder.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
569
+ "vlm.vision_encoder.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
570
+ "vlm.vision_encoder.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
571
+ "vlm.vision_encoder.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
572
+ "vlm.vision_encoder.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
573
+ "vlm.vision_encoder.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
574
+ "vlm.vision_encoder.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
575
+ "vlm.vision_encoder.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
576
+ "vlm.vision_encoder.encoder.layers.6.layer_norm1.bias": "model-00001-of-00004.safetensors",
577
+ "vlm.vision_encoder.encoder.layers.6.layer_norm1.weight": "model-00001-of-00004.safetensors",
578
+ "vlm.vision_encoder.encoder.layers.6.layer_norm2.bias": "model-00001-of-00004.safetensors",
579
+ "vlm.vision_encoder.encoder.layers.6.layer_norm2.weight": "model-00001-of-00004.safetensors",
580
+ "vlm.vision_encoder.encoder.layers.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
581
+ "vlm.vision_encoder.encoder.layers.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
582
+ "vlm.vision_encoder.encoder.layers.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
583
+ "vlm.vision_encoder.encoder.layers.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
584
+ "vlm.vision_encoder.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
585
+ "vlm.vision_encoder.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
586
+ "vlm.vision_encoder.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
587
+ "vlm.vision_encoder.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
588
+ "vlm.vision_encoder.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
589
+ "vlm.vision_encoder.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
590
+ "vlm.vision_encoder.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
591
+ "vlm.vision_encoder.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
592
+ "vlm.vision_encoder.encoder.layers.7.layer_norm1.bias": "model-00001-of-00004.safetensors",
593
+ "vlm.vision_encoder.encoder.layers.7.layer_norm1.weight": "model-00001-of-00004.safetensors",
594
+ "vlm.vision_encoder.encoder.layers.7.layer_norm2.bias": "model-00001-of-00004.safetensors",
595
+ "vlm.vision_encoder.encoder.layers.7.layer_norm2.weight": "model-00001-of-00004.safetensors",
596
+ "vlm.vision_encoder.encoder.layers.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
597
+ "vlm.vision_encoder.encoder.layers.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
598
+ "vlm.vision_encoder.encoder.layers.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
599
+ "vlm.vision_encoder.encoder.layers.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
600
+ "vlm.vision_encoder.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
601
+ "vlm.vision_encoder.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
602
+ "vlm.vision_encoder.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
603
+ "vlm.vision_encoder.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
604
+ "vlm.vision_encoder.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
605
+ "vlm.vision_encoder.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
606
+ "vlm.vision_encoder.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
607
+ "vlm.vision_encoder.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
608
+ "vlm.vision_encoder.encoder.layers.8.layer_norm1.bias": "model-00001-of-00004.safetensors",
609
+ "vlm.vision_encoder.encoder.layers.8.layer_norm1.weight": "model-00001-of-00004.safetensors",
610
+ "vlm.vision_encoder.encoder.layers.8.layer_norm2.bias": "model-00001-of-00004.safetensors",
611
+ "vlm.vision_encoder.encoder.layers.8.layer_norm2.weight": "model-00001-of-00004.safetensors",
612
+ "vlm.vision_encoder.encoder.layers.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
613
+ "vlm.vision_encoder.encoder.layers.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
614
+ "vlm.vision_encoder.encoder.layers.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
615
+ "vlm.vision_encoder.encoder.layers.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
616
+ "vlm.vision_encoder.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
617
+ "vlm.vision_encoder.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
618
+ "vlm.vision_encoder.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
619
+ "vlm.vision_encoder.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
620
+ "vlm.vision_encoder.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
621
+ "vlm.vision_encoder.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
622
+ "vlm.vision_encoder.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
623
+ "vlm.vision_encoder.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
624
+ "vlm.vision_encoder.encoder.layers.9.layer_norm1.bias": "model-00001-of-00004.safetensors",
625
+ "vlm.vision_encoder.encoder.layers.9.layer_norm1.weight": "model-00001-of-00004.safetensors",
626
+ "vlm.vision_encoder.encoder.layers.9.layer_norm2.bias": "model-00001-of-00004.safetensors",
627
+ "vlm.vision_encoder.encoder.layers.9.layer_norm2.weight": "model-00001-of-00004.safetensors",
628
+ "vlm.vision_encoder.encoder.layers.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
629
+ "vlm.vision_encoder.encoder.layers.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
630
+ "vlm.vision_encoder.encoder.layers.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
631
+ "vlm.vision_encoder.encoder.layers.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
632
+ "vlm.vision_encoder.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
633
+ "vlm.vision_encoder.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
634
+ "vlm.vision_encoder.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
635
+ "vlm.vision_encoder.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
636
+ "vlm.vision_encoder.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
637
+ "vlm.vision_encoder.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
638
+ "vlm.vision_encoder.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
639
+ "vlm.vision_encoder.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
640
+ "vlm.vision_encoder.head.attention.in_proj_bias": "model-00001-of-00004.safetensors",
641
+ "vlm.vision_encoder.head.attention.in_proj_weight": "model-00001-of-00004.safetensors",
642
+ "vlm.vision_encoder.head.attention.out_proj.bias": "model-00001-of-00004.safetensors",
643
+ "vlm.vision_encoder.head.attention.out_proj.weight": "model-00001-of-00004.safetensors",
644
+ "vlm.vision_encoder.head.layernorm.bias": "model-00001-of-00004.safetensors",
645
+ "vlm.vision_encoder.head.layernorm.weight": "model-00001-of-00004.safetensors",
646
+ "vlm.vision_encoder.head.mlp.fc1.bias": "model-00001-of-00004.safetensors",
647
+ "vlm.vision_encoder.head.mlp.fc1.weight": "model-00001-of-00004.safetensors",
648
+ "vlm.vision_encoder.head.mlp.fc2.bias": "model-00001-of-00004.safetensors",
649
+ "vlm.vision_encoder.head.mlp.fc2.weight": "model-00001-of-00004.safetensors",
650
+ "vlm.vision_encoder.head.probe": "model-00001-of-00004.safetensors",
651
+ "vlm.vision_encoder.post_layernorm.bias": "model-00001-of-00004.safetensors",
652
+ "vlm.vision_encoder.post_layernorm.weight": "model-00001-of-00004.safetensors",
653
+ "vlm.vision_tokenizer.latents": "model-00001-of-00004.safetensors",
654
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.bias": "model-00001-of-00004.safetensors",
655
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.weight": "model-00001-of-00004.safetensors",
656
+ "vlm.vision_tokenizer.layers.0.0.norm_media.bias": "model-00001-of-00004.safetensors",
657
+ "vlm.vision_tokenizer.layers.0.0.norm_media.weight": "model-00001-of-00004.safetensors",
658
+ "vlm.vision_tokenizer.layers.0.0.to_kv.weight": "model-00001-of-00004.safetensors",
659
+ "vlm.vision_tokenizer.layers.0.0.to_out.weight": "model-00001-of-00004.safetensors",
660
+ "vlm.vision_tokenizer.layers.0.0.to_q.weight": "model-00001-of-00004.safetensors",
661
+ "vlm.vision_tokenizer.layers.0.1.0.bias": "model-00001-of-00004.safetensors",
662
+ "vlm.vision_tokenizer.layers.0.1.0.weight": "model-00001-of-00004.safetensors",
663
+ "vlm.vision_tokenizer.layers.0.1.1.weight": "model-00001-of-00004.safetensors",
664
+ "vlm.vision_tokenizer.layers.0.1.3.weight": "model-00001-of-00004.safetensors",
665
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.bias": "model-00001-of-00004.safetensors",
666
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.weight": "model-00001-of-00004.safetensors",
667
+ "vlm.vision_tokenizer.layers.1.0.norm_media.bias": "model-00001-of-00004.safetensors",
668
+ "vlm.vision_tokenizer.layers.1.0.norm_media.weight": "model-00001-of-00004.safetensors",
669
+ "vlm.vision_tokenizer.layers.1.0.to_kv.weight": "model-00001-of-00004.safetensors",
670
+ "vlm.vision_tokenizer.layers.1.0.to_out.weight": "model-00001-of-00004.safetensors",
671
+ "vlm.vision_tokenizer.layers.1.0.to_q.weight": "model-00001-of-00004.safetensors",
672
+ "vlm.vision_tokenizer.layers.1.1.0.bias": "model-00001-of-00004.safetensors",
673
+ "vlm.vision_tokenizer.layers.1.1.0.weight": "model-00001-of-00004.safetensors",
674
+ "vlm.vision_tokenizer.layers.1.1.1.weight": "model-00001-of-00004.safetensors",
675
+ "vlm.vision_tokenizer.layers.1.1.3.weight": "model-00001-of-00004.safetensors",
676
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.bias": "model-00001-of-00004.safetensors",
677
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.weight": "model-00001-of-00004.safetensors",
678
+ "vlm.vision_tokenizer.layers.2.0.norm_media.bias": "model-00001-of-00004.safetensors",
679
+ "vlm.vision_tokenizer.layers.2.0.norm_media.weight": "model-00001-of-00004.safetensors",
680
+ "vlm.vision_tokenizer.layers.2.0.to_kv.weight": "model-00001-of-00004.safetensors",
681
+ "vlm.vision_tokenizer.layers.2.0.to_out.weight": "model-00001-of-00004.safetensors",
682
+ "vlm.vision_tokenizer.layers.2.0.to_q.weight": "model-00001-of-00004.safetensors",
683
+ "vlm.vision_tokenizer.layers.2.1.0.bias": "model-00001-of-00004.safetensors",
684
+ "vlm.vision_tokenizer.layers.2.1.0.weight": "model-00001-of-00004.safetensors",
685
+ "vlm.vision_tokenizer.layers.2.1.1.weight": "model-00001-of-00004.safetensors",
686
+ "vlm.vision_tokenizer.layers.2.1.3.weight": "model-00001-of-00004.safetensors",
687
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.bias": "model-00001-of-00004.safetensors",
688
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.weight": "model-00001-of-00004.safetensors",
689
+ "vlm.vision_tokenizer.layers.3.0.norm_media.bias": "model-00001-of-00004.safetensors",
690
+ "vlm.vision_tokenizer.layers.3.0.norm_media.weight": "model-00001-of-00004.safetensors",
691
+ "vlm.vision_tokenizer.layers.3.0.to_kv.weight": "model-00001-of-00004.safetensors",
692
+ "vlm.vision_tokenizer.layers.3.0.to_out.weight": "model-00001-of-00004.safetensors",
693
+ "vlm.vision_tokenizer.layers.3.0.to_q.weight": "model-00001-of-00004.safetensors",
694
+ "vlm.vision_tokenizer.layers.3.1.0.bias": "model-00001-of-00004.safetensors",
695
+ "vlm.vision_tokenizer.layers.3.1.0.weight": "model-00001-of-00004.safetensors",
696
+ "vlm.vision_tokenizer.layers.3.1.1.weight": "model-00001-of-00004.safetensors",
697
+ "vlm.vision_tokenizer.layers.3.1.3.weight": "model-00001-of-00004.safetensors",
698
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.bias": "model-00001-of-00004.safetensors",
699
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.weight": "model-00001-of-00004.safetensors",
700
+ "vlm.vision_tokenizer.layers.4.0.norm_media.bias": "model-00001-of-00004.safetensors",
701
+ "vlm.vision_tokenizer.layers.4.0.norm_media.weight": "model-00001-of-00004.safetensors",
702
+ "vlm.vision_tokenizer.layers.4.0.to_kv.weight": "model-00001-of-00004.safetensors",
703
+ "vlm.vision_tokenizer.layers.4.0.to_out.weight": "model-00001-of-00004.safetensors",
704
+ "vlm.vision_tokenizer.layers.4.0.to_q.weight": "model-00001-of-00004.safetensors",
705
+ "vlm.vision_tokenizer.layers.4.1.0.bias": "model-00001-of-00004.safetensors",
706
+ "vlm.vision_tokenizer.layers.4.1.0.weight": "model-00001-of-00004.safetensors",
707
+ "vlm.vision_tokenizer.layers.4.1.1.weight": "model-00001-of-00004.safetensors",
708
+ "vlm.vision_tokenizer.layers.4.1.3.weight": "model-00001-of-00004.safetensors",
709
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.bias": "model-00001-of-00004.safetensors",
710
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.weight": "model-00001-of-00004.safetensors",
711
+ "vlm.vision_tokenizer.layers.5.0.norm_media.bias": "model-00001-of-00004.safetensors",
712
+ "vlm.vision_tokenizer.layers.5.0.norm_media.weight": "model-00001-of-00004.safetensors",
713
+ "vlm.vision_tokenizer.layers.5.0.to_kv.weight": "model-00001-of-00004.safetensors",
714
+ "vlm.vision_tokenizer.layers.5.0.to_out.weight": "model-00001-of-00004.safetensors",
715
+ "vlm.vision_tokenizer.layers.5.0.to_q.weight": "model-00001-of-00004.safetensors",
716
+ "vlm.vision_tokenizer.layers.5.1.0.bias": "model-00001-of-00004.safetensors",
717
+ "vlm.vision_tokenizer.layers.5.1.0.weight": "model-00001-of-00004.safetensors",
718
+ "vlm.vision_tokenizer.layers.5.1.1.weight": "model-00001-of-00004.safetensors",
719
+ "vlm.vision_tokenizer.layers.5.1.3.weight": "model-00001-of-00004.safetensors",
720
+ "vlm.vision_tokenizer.norm.bias": "model-00001-of-00004.safetensors",
721
+ "vlm.vision_tokenizer.norm.weight": "model-00001-of-00004.safetensors",
722
+ "vlm.vision_tokenizer.projection.bias": "model-00001-of-00004.safetensors",
723
+ "vlm.vision_tokenizer.projection.weight": "model-00001-of-00004.safetensors"
724
+ }
725
+ }
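Note: the listing above is the tail of `model.safetensors.index.json`, whose `weight_map` records which shard stores each parameter. As a minimal sketch (not part of this commit; the local checkpoint path is an assumed example), a single tensor can be located and loaded through the index like this:

```python
# Sketch: resolve one tensor from the sharded checkpoint via the index file.
import json
from safetensors import safe_open

ckpt_dir = "./xgen-mm-phi3-mini-instruct-r-v1.1"  # hypothetical local clone
with open(f"{ckpt_dir}/model.safetensors.index.json") as f:
    index = json.load(f)

name = "vlm.vision_tokenizer.projection.weight"
shard = index["weight_map"][name]  # e.g. "model-00001-of-00004.safetensors"
with safe_open(f"{ckpt_dir}/{shard}", framework="pt") as f:
    tensor = f.get_tensor(name)
print(name, tuple(tensor.shape))
```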
modeling_xgenmm.py ADDED
@@ -0,0 +1,102 @@
1
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
2
+ import torch
3
+ import open_clip
4
+ from typing import List, Optional, Tuple, Union
5
+ from utils import check_embedding_fns
6
+ from vlm import PerceiverResampler, XGenMMPerceiver
7
+ from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
8
+
9
+ class XGenMMVisionEncoder(PreTrainedModel):
10
+ main_input_name = "pixel_values"
11
+ config_class = XGenMMVisionEncoderConfig
12
+
13
+ def __init__(self, config: XGenMMVisionEncoderConfig):
14
+ super().__init__(config)
15
+ if config.model_name != 'google/siglip-so400m-patch14-384':
16
+ raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
17
+ self.model = AutoModel.from_pretrained(config.model_name)
18
+
19
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
+ # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
21
+ return self.model.encode_image(pixel_values)
22
+
23
+
24
+ # vision tokenizer
25
+ class XGenMMVisionTokenizer(PreTrainedModel):
26
+ config_class = XGenMMVisionTokenizerConfig
27
+ def __init__(self, config: XGenMMVisionTokenizerConfig):
28
+ super().__init__(config)
29
+ self.model = PerceiverResampler(
30
+ dim=config.vis_feature_dim,
31
+ dim_inner=config.lang_embedding_dim,
32
+ num_latents=config.num_vis_tokens,
33
+ )
34
+
35
+ def forward(self,
36
+ vision_features: torch.Tensor,
37
+ vision_attn_masks: torch.Tensor):
38
+ return self.model(vision_features, vision_attn_masks)
39
+
40
+ # XGenMM model
41
+ class XGenMMModelForConditionalGeneration(PreTrainedModel):
42
+ config_class = XGenMMConfig
43
+
44
+ def __init__(self, config: XGenMMConfig):
45
+ super().__init__(config)
46
+
47
+ # vision encoder initialization
48
+ vision_encoder = AutoModel.from_pretrained(config.vision_encoder_config.model_name).vision_model
49
+
50
+ # language model initialization
51
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
52
+ check_embedding_fns(language_model)
53
+ # Update _tied_weights_keys using the base model used.
54
+ if language_model._tied_weights_keys is not None:
55
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
56
+
57
+ # vision tokenizer initialization
58
+ if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
59
+ overwrite = language_model.get_input_embeddings().weight.shape[1]
60
+ config.vision_tokenizer_config.lang_embedding_dim = overwrite
61
+ print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
62
+
63
+ vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
64
+
65
+ self.vlm = XGenMMPerceiver(
66
+ vision_encoder=vision_encoder,
67
+ vision_tokenizer=vision_tokenizer,
68
+ lang_model=language_model,
69
+ initial_tokenizer_len = config.text_config.initial_tokenizer_len,
70
+ pad_token_id = config.text_config.pad_token_id,
71
+ image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
72
+ )
73
+ # Initialize weights and apply final processing
74
+ self.post_init()
75
+
76
+ @torch.no_grad()
77
+ def generate(
78
+ self,
79
+ pixel_values: torch.FloatTensor,
80
+ input_ids: Optional[torch.LongTensor] = None,
81
+ attention_mask: Optional[torch.LongTensor] = None,
82
+ **generate_kwargs,
83
+ ) -> torch.LongTensor:
84
+ self.vlm = self.vlm.eval()
85
+ return self.vlm.generate(
86
+ vision_x = pixel_values,
87
+ lang_x = input_ids,
88
+ attention_mask = attention_mask,
89
+ **generate_kwargs)
90
+
91
+ def update_special_tokens(self, tokenizer):
92
+ tokenizer.add_special_tokens(
93
+ {"additional_special_tokens": list(self.vlm.special_tokens.values())}
94
+ )
95
+ self.vlm.lang_model.config.vocab_size = len(tokenizer)
96
+ self.vlm.set_special_token_ids(
97
+ {
98
+ v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
99
+ }
100
+ )
101
+ return tokenizer
102
+
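The classes above are loaded through `trust_remote_code`. Below is a minimal loading sketch, assuming the checkpoint is published under a `Salesforce/xgen-mm-...` repo id and that `<image>` is the image placeholder token; the inference notebook (demo.ipynb) remains the authoritative end-to-end example.

```python
# Minimal sketch of loading XGenMMModelForConditionalGeneration via remote code.
# Repo id, image token, and the exact image-processor call are assumptions.
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor

repo = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1.1"  # assumed repo id
model = AutoModelForVision2Seq.from_pretrained(repo, trust_remote_code=True)  # auto-class routing assumed
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, use_fast=False)
image_processor = AutoImageProcessor.from_pretrained(repo, trust_remote_code=True)

# Register the VLM's special tokens (image placeholders etc.) on the tokenizer.
tokenizer = model.update_special_tokens(tokenizer)

image = Image.open("./test_samples/images/000adfe5b817011c.jpg").convert("RGB")
pixel_values = image_processor([image], return_tensors="pt")["pixel_values"]  # exact call/shape per demo.ipynb
prompt = "<|user|>\n<image> Please provide a short description of this image.<|end|>\n<|assistant|>\n"
lang_x = tokenizer([prompt], return_tensors="pt")

out = model.generate(
    pixel_values=pixel_values,
    input_ids=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=64,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```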
preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_blip_3.Blip3ImageProcessor"
4
+ },
5
+ "do_resize": true,
6
+ "grids": null,
7
+ "image_mean": [
8
+ 0.5,
9
+ 0.5,
10
+ 0.5
11
+ ],
12
+ "image_processor_type": "Blip3ImageProcessor",
13
+ "image_std": [
14
+ 0.5,
15
+ 0.5,
16
+ 0.5
17
+ ],
18
+ "interpolation_mode": "bicubic",
19
+ "resize_mode": "squash",
20
+ "size": [
21
+ 384,
22
+ 384
23
+ ]
24
+ }
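For reference, the config above amounts to: squash-resize to 384×384 with bicubic interpolation, then normalize with mean/std 0.5 per channel. A rough torchvision equivalent follows (the shipped `Blip3ImageProcessor` is the source of truth):

```python
# Approximate equivalent of the preprocessor config; "squash" resizes both sides
# to 384 without preserving aspect ratio.
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),                                            # scales to [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # maps to [-1, 1]
])

pixel_values = preprocess(Image.open("./test_samples/images/COCO_val2014_000000486568.jpg").convert("RGB"))
print(pixel_values.shape)  # torch.Size([3, 384, 384])
```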
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
test_samples/few_shots.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "example_1": {
3
+ "image_path": "./test_samples/images/COCO_val2014_000000486568.jpg",
4
+ "instruction": "A short description of this image in one sentence:",
5
+ "output": "A man in a suit holding something in his office."
6
+ },
7
+ "example_2": {
8
+ "image_path": "./test_samples/images/COCO_val2014_000000176466.jpg",
9
+ "instruction": "A short description of this image in one sentence:",
10
+ "output": "The young girl is standing by the fire hydrant in curlers."
11
+ },
12
+ "example_3": {
13
+ "image_path": "./test_samples/images/COCO_val2014_000000392640.jpg",
14
+ "instruction": "A short description of this image in one sentence:",
15
+ "output": "A man with a skateboard that is jumping in the air."
16
+ },
17
+ "example_4": {
18
+ "image_path": "./test_samples/images/COCO_val2014_000000267408.jpg",
19
+ "instruction": "A short description of this image in one sentence:",
20
+ "output": "A few people looking at a television that's next to a laptop."
21
+ }
22
+ }
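These records are intended for in-context (few-shot) prompting of the base model. A hypothetical way to fold them into a single prompt is sketched below; the `<image>` placeholder token and overall layout should follow the inference notebook.

```python
# Hypothetical few-shot prompt assembly from the JSON above; "<image>" as the
# placeholder token is an assumption.
import json

with open("./test_samples/few_shots.json") as f:
    shots = json.load(f)

image_paths, parts = [], []
for shot in shots.values():
    image_paths.append(shot["image_path"])
    parts.append(f"<image> {shot['instruction']} {shot['output']}")

# Leave the final query unanswered so the model completes it.
image_paths.append("./test_samples/images/000adfe5b817011c.jpg")
parts.append("<image> A short description of this image in one sentence:")
prompt = "\n".join(parts)
print(prompt)
```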
test_samples/images/000adfe5b817011c.jpg ADDED
test_samples/images/COCO_val2014_000000176466.jpg ADDED
test_samples/images/COCO_val2014_000000267408.jpg ADDED
test_samples/images/COCO_val2014_000000392640.jpg ADDED
test_samples/images/COCO_val2014_000000486568.jpg ADDED
test_samples/zero_shot.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "image_path": "./test_samples/images/000adfe5b817011c.jpg",
3
+ "instruction": "Please provide a short description of this image:"
4
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,138 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": true,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "32000": {
30
+ "content": "<|endoftext|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<|assistant|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": true,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "32002": {
46
+ "content": "<|placeholder1|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": true,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "32003": {
54
+ "content": "<|placeholder2|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": true,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "32004": {
62
+ "content": "<|placeholder3|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": true,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "32005": {
70
+ "content": "<|placeholder4|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": true,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "32006": {
78
+ "content": "<|system|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": true,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "32007": {
86
+ "content": "<|end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": true,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "32008": {
94
+ "content": "<|placeholder5|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": true,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "32009": {
102
+ "content": "<|placeholder6|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": true,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "32010": {
110
+ "content": "<|user|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": true,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "32011": {
118
+ "content": "<pad>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ }
125
+ },
126
+ "bos_token": "<s>",
127
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
128
+ "clean_up_tokenization_spaces": false,
129
+ "eos_token": "<|endoftext|>",
130
+ "legacy": false,
131
+ "model_max_length": 4096,
132
+ "pad_token": "<pad>",
133
+ "padding_side": "left",
134
+ "sp_model_kwargs": {},
135
+ "tokenizer_class": "LlamaTokenizer",
136
+ "unk_token": "<unk>",
137
+ "use_default_system_prompt": false
138
+ }
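The `chat_template` above follows the Phi-3 conversation format. A quick check of what it renders (repo id assumed; any tokenizer built from this config behaves the same):

```python
# Render a one-turn conversation with the chat template defined above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Salesforce/xgen-mm-phi3-mini-instruct-r-v1.1", use_fast=False)
messages = [{"role": "user", "content": "<image> What is shown in this picture?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|user|>
# <image> What is shown in this picture?<|end|>
# <|assistant|>
```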
utils.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import ast
3
+ import math
4
+ from PIL import Image
5
+ from packaging.version import Version
6
+
7
+ def has_fn(model, fn_name):
8
+ """Check if model has a function fn_name"""
9
+ return callable(getattr(model, fn_name, None))
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+ def num_params(module, filter_to_trainable=False):
15
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
+ if filter_to_trainable:
17
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
+ else:
19
+ return sum(p.numel() for p in module.parameters())
20
+
21
+ def hasattr_recursive(obj, att):
22
+ """
23
+ Check if obj has nested attribute
24
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
+ """
26
+ if att == "":
27
+ return True
28
+ i = att.find(".")
29
+ if i < 0:
30
+ return hasattr(obj, att)
31
+ else:
32
+ try:
33
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
+ except AttributeError:
35
+ return False
36
+
37
+ def getattr_recursive(obj, att):
38
+ """
39
+ Return nested attribute of obj
40
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
+ """
42
+ if att == "":
43
+ return obj
44
+ i = att.find(".")
45
+ if i < 0:
46
+ return getattr(obj, att)
47
+ else:
48
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
+
50
+
51
+ def setattr_recursive(obj, att, val):
52
+ """
53
+ Set nested attribute of obj
54
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
+ """
56
+ if "." in att:
57
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
+ setattr(obj, att.split(".")[-1], val)
59
+
60
+
61
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
+ """
63
+ Stack a list of tensors with padding on one side
64
+ Args:
65
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
+ padding_value (int, optional): Value to pad with. Defaults to 0.
67
+ padding_side (str, optional): Side to pad on. Defaults to "right".
68
+ Returns:
69
+ torch.Tensor: Stacked tensors
70
+ """
71
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
+ padded_tensors = []
73
+ for tensor in list_of_tensors:
74
+ num_tokens = tensor.size(0)
75
+ if len(tensor.size()) == 1:
76
+ padding = torch.full(
77
+ (max_tokens - num_tokens,),
78
+ padding_value,
79
+ dtype=tensor.dtype,
80
+ device=tensor.device,
81
+ )
82
+ else:
83
+ padding = torch.full(
84
+ (max_tokens - num_tokens, tensor.size(1)),
85
+ padding_value,
86
+ dtype=tensor.dtype,
87
+ device=tensor.device,
88
+ )
89
+ padded_tensor = (
90
+ torch.cat((tensor, padding), dim=0)
91
+ if padding_side == "right"
92
+ else torch.cat((padding, tensor), dim=0)
93
+ )
94
+ padded_tensors.append(padded_tensor)
95
+ return torch.stack(padded_tensors)
96
+
97
+
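A quick, hypothetical usage sketch of `stack_with_padding` above (values are illustrative): sequences of different lengths are padded with `padding_value` on the chosen side before stacking.

```python
import torch

# stack_with_padding is the helper defined above in this utils module.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

# Right padding (the default): shorter rows are padded at the end.
out = stack_with_padding([a, b], padding_value=0, padding_side="right")
# out == tensor([[1, 2, 3],
#                [4, 5, 0]])
```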
98
+ def check_embedding_fns(lang_model):
99
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
+ if not has_fn(lang_model, "get_input_embeddings"):
101
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
+ lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
105
+ else:
106
+ raise ValueError(
107
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
+ )
109
+
110
+ if not has_fn(lang_model, "set_input_embeddings"):
111
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
+ lang_model, "transformer.wte", x
114
+ )
115
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
+ lang_model, "model.decoder.embed_tokens", x
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
+ )
123
+
124
+ if not has_fn(lang_model, "get_output_embeddings"):
125
+ if hasattr_recursive(lang_model, "lm_head"):
126
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
+ else:
128
+ raise ValueError(
129
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
+ )
131
+
132
+ if not has_fn(lang_model, "set_output_embeddings"):
133
+ if hasattr_recursive(lang_model, "lm_head"):
134
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
+ lang_model, "lm_head", x
136
+ )
137
+ else:
138
+ raise ValueError(
139
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
+ )
141
+
142
+
143
+ def has_fn(model, fn_name):
144
+ """Check if model has a function fn_name"""
145
+ return callable(getattr(model, fn_name, None))
146
+
147
+
148
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
+ #
150
+ # Licensed under the Apache License, Version 2.0 (the "License");
151
+ # you may not use this file except in compliance with the License.
152
+ # You may obtain a copy of the License at
153
+ #
154
+ # http://www.apache.org/licenses/LICENSE-2.0
155
+ #
156
+ # Unless required by applicable law or agreed to in writing, software
157
+ # distributed under the License is distributed on an "AS IS" BASIS,
158
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
+ # See the License for the specific language governing permissions and
160
+ # limitations under the License.
161
+
162
+ def unpad_image(tensor, original_size, keep_original_shape=False):
163
+ """
164
+ Unpads a PyTorch tensor of a padded and resized image.
165
+
166
+ Args:
167
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
+ original_size (tuple): The original size of the image (width, height).
169
+ keep_original_shape (bool, optional): If True, return the tensor unchanged together with an attention mask over the valid (non-padded) region instead of cropping.
170
+ Returns:
171
+ (torch.Tensor, torch.Tensor or None): The unpadded image tensor (or the original tensor if keep_original_shape) and the optional attention mask.
172
+ """
173
+ original_width, original_height = original_size
174
+ current_height, current_width = tensor.shape[1:]
175
+
176
+ original_aspect_ratio = original_width / original_height
177
+ current_aspect_ratio = current_width / current_height
178
+
179
+ if original_aspect_ratio > current_aspect_ratio:
180
+ scale_factor = current_width / original_width
181
+ new_height = int(original_height * scale_factor)
182
+ padding = (current_height - new_height) // 2
183
+ if keep_original_shape:
184
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
+ attention_mask[:padding, :] = 0
186
+ attention_mask[current_height - padding:, :] = 0
187
+ return tensor, attention_mask
188
+ else:
189
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
+ return unpadded_tensor, None
191
+ else:
192
+ scale_factor = current_height / original_height
193
+ new_width = int(original_width * scale_factor)
194
+ padding = (current_width - new_width) // 2
195
+ if keep_original_shape:
196
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
+ attention_mask[:, :padding] = 0
198
+ attention_mask[:, current_width - padding:] = 0
199
+ return tensor, attention_mask
200
+ else:
201
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
+ return unpadded_tensor, None
203
+
204
+
205
+ def select_best_resolution(original_size, possible_resolutions):
206
+ """
207
+ Selects the best resolution from a list of possible resolutions based on the original size.
208
+
209
+ Args:
210
+ original_size (tuple): The original size of the image in the format (width, height).
211
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
+
213
+ Returns:
214
+ tuple: The best fit resolution in the format (width, height).
215
+ """
216
+ original_width, original_height = original_size
217
+ best_fit = None
218
+ max_effective_resolution = 0
219
+ min_wasted_resolution = float('inf')
220
+
221
+ for width, height in possible_resolutions:
222
+ scale = min(width / original_width, height / original_height)
223
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
+ wasted_resolution = (width * height) - effective_resolution
226
+
227
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
+ max_effective_resolution = effective_resolution
229
+ min_wasted_resolution = wasted_resolution
230
+ best_fit = (width, height)
231
+
232
+ return best_fit
233
+
234
+
235
+ def resize_and_pad_image(image, target_resolution):
236
+ """
237
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
238
+
239
+ Args:
240
+ image (PIL.Image.Image): The input image.
241
+ target_resolution (tuple): The target resolution (width, height) of the image.
242
+
243
+ Returns:
244
+ PIL.Image.Image: The resized and padded image.
245
+ """
246
+ original_width, original_height = image.size
247
+ target_width, target_height = target_resolution
248
+
249
+ scale_w = target_width / original_width
250
+ scale_h = target_height / original_height
251
+
252
+ if scale_w < scale_h:
253
+ new_width = target_width
254
+ new_height = min(math.ceil(original_height * scale_w), target_height)
255
+ else:
256
+ new_height = target_height
257
+ new_width = min(math.ceil(original_width * scale_h), target_width)
258
+
259
+ # Resize the image
260
+ resized_image = image.resize((new_width, new_height))
261
+
262
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
+ paste_x = (target_width - new_width) // 2
264
+ paste_y = (target_height - new_height) // 2
265
+ new_image.paste(resized_image, (paste_x, paste_y))
266
+
267
+ return new_image
268
+
269
+
270
+ def divide_to_patches(image, patch_size):
271
+ """
272
+ Divides an image into patches of a specified size.
273
+
274
+ Args:
275
+ image (PIL.Image.Image): The input image.
276
+ patch_size (int): The size of each patch.
277
+
278
+ Returns:
279
+ list: A list of PIL.Image.Image objects representing the patches.
280
+ """
281
+ patches = []
282
+ width, height = image.size
283
+ for i in range(0, height, patch_size):
284
+ for j in range(0, width, patch_size):
285
+ box = (j, i, j + patch_size, i + patch_size)
286
+ patch = image.crop(box)
287
+ patches.append(patch)
288
+
289
+ return patches
290
+
291
+
292
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
+ """
294
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
+
296
+ Args:
297
+ image_size (tuple): The size of the input image in the format (width, height).
298
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
299
+ patch_size (int): The size of each image patch.
300
+
301
+ Returns:
302
+ tuple: The shape of the image patch grid in the format (width, height).
303
+ """
304
+ if type(grid_pinpoints) is list:
305
+ possible_resolutions = grid_pinpoints
306
+ else:
307
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
308
+ width, height = select_best_resolution(image_size, possible_resolutions)
309
+ return width // patch_size, height // patch_size
310
+
311
+
312
+ def process_anyres_image(image, processor, grid_pinpoints):
313
+ """
314
+ Process an image with variable resolutions.
315
+
316
+ Args:
317
+ image (PIL.Image.Image): The input image to be processed.
318
+ processor: The image processor object.
319
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
320
+
321
+ Returns:
322
+ torch.Tensor: A tensor containing the processed image patches.
323
+ """
324
+ # FIXME: determine grid_pinpoints from image sizes.
325
+ if type(grid_pinpoints) is list:
326
+ possible_resolutions = grid_pinpoints
327
+ else:
328
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
329
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
330
+ image_padded = resize_and_pad_image(image, best_resolution)
331
+
332
+ processor_size = processor.transforms[0].size
333
+ patches = divide_to_patches(image_padded, processor_size[0])
334
+
335
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
+
337
+ image_patches = [image_original_resize] + patches
338
+ image_patches = [processor(image_patch)
339
+ for image_patch in image_patches]
340
+ return torch.stack(image_patches, dim=0)
341
+
342
+
343
+ def expand2square(pil_img, background_color):
344
+ width, height = pil_img.size
345
+ if width == height:
346
+ return pil_img
347
+ elif width > height:
348
+ result = Image.new(pil_img.mode, (width, width), background_color)
349
+ result.paste(pil_img, (0, (width - height) // 2))
350
+ return result
351
+ else:
352
+ result = Image.new(pil_img.mode, (height, height), background_color)
353
+ result.paste(pil_img, ((height - width) // 2, 0))
354
+ return result
355
+
356
+
357
+ def process_images(images, image_processor, model_cfg):
358
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
+ new_images = []
360
+ if image_aspect_ratio == 'pad':
361
+ for image in images:
362
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
+ image = image_processor(image)
364
+ new_images.append(image)
365
+ elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
+ base_img_size = image_processor.transforms[0].size[0]
367
+ for image in images:
368
+ image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
+ [base_img_size*2,base_img_size],
370
+ [base_img_size*2,base_img_size*2],
371
+ [base_img_size*3,base_img_size],
372
+ [base_img_size,base_img_size*3]])
373
+
374
+ # Debug any res inference by only using 672x672.
375
+ # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
+ new_images.append(image)
377
+ else:
378
+ return image_processor(images)
379
+ if all(x.shape == new_images[0].shape for x in new_images):
380
+ new_images = torch.stack(new_images, dim=0)
381
+ return new_images
382
+
383
+
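Taken together, the helpers above implement the "anyres" preprocessing: `select_best_resolution` picks the grid that preserves the most effective resolution, `resize_and_pad_image` fits the image into that grid, `divide_to_patches` cuts it into encoder-sized tiles, and `process_anyres_image` stacks those tiles behind a globally resized view. Below is a minimal sketch of that flow; the 384-pixel base size, the grid list, and the torchvision-style processor are illustrative assumptions, not the released configuration.

```python
from PIL import Image
from torchvision import transforms

# Assumes utils.py (above) is importable from the working directory.
from utils import (
    select_best_resolution, resize_and_pad_image,
    divide_to_patches, process_anyres_image,
)

processor = transforms.Compose([
    transforms.Resize((384, 384)),   # first transform must expose .size, as process_anyres_image assumes
    transforms.ToTensor(),
])

image = Image.new("RGB", (1024, 640))            # dummy image standing in for a real input
base = processor.transforms[0].size[0]           # 384
grids = [[base, base * 2], [base * 2, base], [base * 2, base * 2]]

best = select_best_resolution(image.size, grids)   # grid that wastes the least resolution
padded = resize_and_pad_image(image, best)         # aspect-ratio-preserving resize plus black padding
patches = divide_to_patches(padded, base)          # encoder-sized tiles
pixel_values = process_anyres_image(image, processor, grids)
# pixel_values: (len(patches) + 1, 3, 384, 384) -- the extra entry is the globally resized image
```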
vlm.py ADDED
@@ -0,0 +1,1314 @@
1
+
2
+ import torch
3
+ from torch import einsum, nn
4
+ from einops import rearrange, repeat
5
+ from einops_exts import rearrange_many
6
+ from einops import rearrange
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch.nn.functional as F
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from dataclasses import dataclass
11
+ from transformers import CLIPVisionModel
12
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
13
+
14
+ import transformers
15
+ from packaging.version import Version
16
+
17
+ from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
18
+
19
+
20
+ class VisionTokenizer(nn.Module):
21
+ def __init__(self, dim_media, num_tokens_per_media):
22
+ super().__init__()
23
+ self.dim_media = dim_media
24
+ self.num_tokens_per_media = num_tokens_per_media
25
+
26
+ class PerceiverAttention(nn.Module):
27
+ def __init__(self, *, dim, dim_head=64, heads=8):
28
+ super().__init__()
29
+ self.scale = dim_head**-0.5
30
+ self.heads = heads
31
+ inner_dim = dim_head * heads
32
+
33
+ self.norm_media = nn.LayerNorm(dim)
34
+ self.norm_latents = nn.LayerNorm(dim)
35
+
36
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
37
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
38
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
39
+
40
+ def forward(self, x, latents, vision_attn_masks=None):
41
+ """
42
+ Args:
43
+ x (torch.Tensor): image features
44
+ shape (b, T, n1, D)
45
+ latents (torch.Tensor): latent features
46
+ shape (b, T, n2, D)
47
+ """
48
+ x = self.norm_media(x)
49
+ latents = self.norm_latents(latents)
50
+
51
+ h = self.heads
52
+
53
+ q = self.to_q(latents)
54
+ kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
55
+ if vision_attn_masks is not None:
56
+ vision_attn_masks = torch.cat((vision_attn_masks,
57
+ torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
58
+ dim=-1)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
+ q = q * self.scale
62
+
63
+ # attention
64
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
65
+ # Apply vision attention mask here.
66
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
67
+ if vision_attn_masks is not None:
68
+ attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
69
+ vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
70
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
71
+ sim += attn_bias
72
+
73
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
74
+ attn = sim.softmax(dim=-1)
75
+
76
+
77
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
78
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
79
+ return self.to_out(out)
80
+
81
+
82
+ def FeedForward(dim, mult=4):
83
+ inner_dim = int(dim * mult)
84
+ return nn.Sequential(
85
+ nn.LayerNorm(dim),
86
+ nn.Linear(dim, inner_dim, bias=False),
87
+ nn.GELU(),
88
+ nn.Linear(inner_dim, dim, bias=False),
89
+ )
90
+
91
+
92
+ class PerceiverResampler(VisionTokenizer):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ dim,
97
+ dim_inner=None,
98
+ depth=6,
99
+ dim_head=96,
100
+ heads=16,
101
+ num_latents=128,
102
+ max_num_media=None,
103
+ max_num_frames=None,
104
+ ff_mult=4,
105
+ ):
106
+ """
107
+ Perceiver module which takes in image features and outputs image tokens.
108
+ Args:
109
+ dim (int): dimension of the incoming image features
110
+ dim_inner (int, optional): final dimension to project the incoming image features to;
111
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
112
+ depth (int, optional): number of layers. Defaults to 6.
113
+ dim_head (int, optional): dimension of each head. Defaults to 64.
114
+ heads (int, optional): number of heads. Defaults to 8.
115
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
116
+ also corresponds to number of tokens per sequence to output. Defaults to 64.
117
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
118
+ and keep positional embeddings for. If None, no positional embeddings are used.
119
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
120
+ and keep positional embeddings for. If None, no positional embeddings are used.
121
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
122
+ """
123
+ if dim_inner is not None:
124
+ projection = nn.Linear(dim, dim_inner)
125
+ else:
126
+ projection = None
127
+ dim_inner = dim
128
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
129
+ self.projection = projection
130
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
131
+
132
+ # positional embeddings
133
+ self.frame_embs = (
134
+ nn.Parameter(torch.randn(max_num_frames, dim))
135
+ if exists(max_num_frames)
136
+ else None
137
+ )
138
+ self.media_time_embs = (
139
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
140
+ if exists(max_num_media)
141
+ else None
142
+ )
143
+
144
+ self.layers = nn.ModuleList([])
145
+ for _ in range(depth):
146
+ self.layers.append(
147
+ nn.ModuleList(
148
+ [
149
+ PerceiverAttention(
150
+ dim=dim, dim_head=dim_head, heads=heads
151
+ ),
152
+ FeedForward(dim=dim, mult=ff_mult),
153
+ ]
154
+ )
155
+ )
156
+
157
+ self.norm = nn.LayerNorm(dim)
158
+
159
+ def forward(self, x, vision_attn_masks=None):
160
+ """
161
+ Args:
162
+ x (torch.Tensor): image features
163
+ shape (b, T, F, v, D)
164
+ vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
165
+ shape (b, v)
166
+ Returns:
167
+ shape (b, T, n, D) where n is self.num_latents
168
+ """
169
+ b, T, F, v = x.shape[:4]
170
+
171
+ # frame and media time embeddings
172
+ if exists(self.frame_embs):
173
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
174
+ x = x + frame_embs
175
+ x = rearrange(
176
+ x, "b T F v d -> b T (F v) d"
177
+ ) # flatten the frame and spatial dimensions
178
+ if exists(self.media_time_embs):
179
+ x = x + self.media_time_embs[:T]
180
+
181
+ # blocks
182
+ latents = self.latents
183
+ latents = repeat(latents, "n d -> b T n d", b=b, T=T)
184
+ for attn, ff in self.layers:
185
+ latents = attn(x, latents, vision_attn_masks) + latents
186
+ latents = ff(latents) + latents
187
+
188
+ if exists(self.projection):
189
+ return self.projection(self.norm(latents))
190
+ else:
191
+ return self.norm(latents)
192
+
193
+
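Shape-wise, the resampler above compresses a variable number of vision features per image into a fixed `num_latents` tokens, optionally projected to `dim_inner`. A minimal sketch with illustrative dimensions (not necessarily the released configuration):

```python
import torch

# Illustrative sizes only: 729 patch features of width 1152 per image.
resampler = PerceiverResampler(
    dim=1152, dim_inner=3072, depth=6, dim_head=96, heads=16, num_latents=128
)

vision_feats = torch.randn(2, 1, 1, 729, 1152)   # (B, T_img, F, v, D)
tokens = resampler(vision_feats)                 # (B, T_img, num_latents, dim_inner) == (2, 1, 128, 3072)
```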
194
+ class DecoupledEmbedding(nn.Embedding):
195
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
196
+ """
197
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
198
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
199
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
200
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ max_original_id: int,
206
+ num_additional_embeddings: int = 0,
207
+ _weight: torch.Tensor = None,
208
+ num_original_embeddings: int = None,
209
+ embedding_dim: int = None,
210
+ partially_freeze=True,
211
+ device=None,
212
+ dtype=None,
213
+ pad_token_id=None,
214
+ ) -> None:
215
+ """
216
+ Args:
217
+ max_original_id (`int`):
218
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
219
+ This is usually len(tokenizer) - 1 before additional tokens are added.
220
+ Note that this may not equal self.weight.shape[0]
221
+ num_additional_embeddings (`int`):
222
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
223
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
224
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
225
+ num_original_embeddings (`int`):
226
+ self.weight.shape[0]
227
+ embedding_dim (`int`):
228
+ The size of each embedding vector
229
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
230
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
231
+ pad_token_id (`int`, *optional*):
232
+ The padding index (needs to be less than num_embeddings)
233
+
234
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
235
+ `max_norm` or `norm_type`. We are not supporting these.
236
+ """
237
+ # validate args
238
+ if pad_token_id is not None and pad_token_id > max_original_id:
239
+ raise ValueError(
240
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
241
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
242
+ )
243
+ if _weight is not None:
244
+ assert (num_original_embeddings is None) or (
245
+ _weight.shape[0] == num_original_embeddings
246
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
247
+ assert (embedding_dim is None) or (
248
+ _weight.shape[1] == embedding_dim
249
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
250
+ num_original_embeddings = _weight.shape[0]
251
+ embedding_dim = _weight.shape[1]
252
+ else:
253
+ assert (
254
+ num_original_embeddings is not None
255
+ ), "num_original_embeddings must be provided if _weight is not provided"
256
+ assert (
257
+ embedding_dim is not None
258
+ ), "embedding_dim must be provided if _weight is not provided"
259
+
260
+ super().__init__(
261
+ num_embeddings=num_original_embeddings,
262
+ embedding_dim=embedding_dim,
263
+ device=device,
264
+ dtype=dtype,
265
+ padding_idx=pad_token_id,
266
+ _weight=_weight,
267
+ )
268
+ self.max_original_id = max_original_id
269
+ self.padding_idx = pad_token_id
270
+ self.num_additional_embeddings = num_additional_embeddings
271
+ if self.num_additional_embeddings > 0:
272
+ self.additional_embedding = nn.Embedding(
273
+ num_embeddings=self.num_additional_embeddings,
274
+ embedding_dim=embedding_dim,
275
+ device=device,
276
+ dtype=dtype,
277
+ )
278
+ self.set_requires_grad(
279
+ require_regular_grad=not partially_freeze, require_additional_grad=True
280
+ )
281
+
282
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
283
+ """
284
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
285
+ """
286
+ self.weight.requires_grad_(require_regular_grad)
287
+ self.additional_embedding.requires_grad_(require_additional_grad)
288
+
289
+ def forward(self, input_ids):
290
+ """
291
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
292
+ self.additional_embedding.weight that is being trained.
293
+
294
+ in order to make a lookup of the input ids, we:
295
+ 1. find out the indices of the entries belonging to the 2nd embedding
296
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
297
+ embedding starts from 0 and not num_embeddings
298
+ 3. perform the 2nd embedding lookup
299
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
300
+ 5. perform the 1st embedding lookup
301
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
302
+
303
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
304
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
305
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
306
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
307
+ measure.
308
+
309
+ """
310
+ if self.num_additional_embeddings == 0:
311
+ return F.embedding(input_ids, self.weight)
312
+
313
+ # Clone so that we don't modify the original input_ids later on
314
+ input_ids = input_ids.clone()
315
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
316
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
317
+ additional_embeddings = self.additional_embedding(
318
+ input_ids_additional_vocab - self.max_original_id - 1
319
+ )
320
+
321
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
322
+ input_ids[additional_vocab_indices] = 0
323
+ full_vector = F.embedding(input_ids, self.weight)
324
+
325
+ # overwrite the records with high indices
326
+ full_vector[additional_vocab_indices] = additional_embeddings
327
+
328
+ return full_vector
329
+
330
+ def extra_repr(self) -> str:
331
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
332
+ self.max_original_id + 1,
333
+ self.num_additional_embeddings,
334
+ self.embedding_dim,
335
+ (not self.weight.requires_grad),
336
+ )
337
+
338
+
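To make the routing above concrete: ids up to `max_original_id` hit the (optionally frozen) pretrained table, anything larger hits the small trainable table. A minimal sketch with toy sizes (all numbers are illustrative):

```python
import torch

emb = DecoupledEmbedding(
    max_original_id=9,              # pretrained vocab covers ids 0..9
    num_additional_embeddings=2,    # two trainable special-token embeddings (ids 10, 11)
    num_original_embeddings=10,
    embedding_dim=8,
    partially_freeze=True,
)

ids = torch.tensor([[1, 9, 10, 11]])   # 10 and 11 are routed to the additional table
vectors = emb(ids)                     # shape (1, 4, 8)

assert emb.weight.requires_grad is False
assert emb.additional_embedding.weight.requires_grad is True
```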
339
+ class DecoupledLinear(nn.Linear):
340
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
341
+ """
342
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
343
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
344
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
345
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ max_original_id: int,
351
+ additional_out_features: int = 0,
352
+ _weight: torch.Tensor = None,
353
+ _bias: torch.Tensor = None,
354
+ in_features: int = None,
355
+ original_out_features: int = None,
356
+ bias: bool = True,
357
+ partially_freeze: bool = True,
358
+ device=None,
359
+ dtype=None,
360
+ ) -> None:
361
+ """
362
+ Args:
363
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
364
+ This is usually len(tokenizer) - 1 before additional tokens are added.
365
+ Note that this may not equal original_out_features - 1
366
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
367
+ If provided, this sets the `in_features` and `original_out_features` parameters.
368
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
369
+ in_features: int. Input hidden size.
370
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
371
+ additional_out_features: int. Number of additional trainable dimensions.
372
+ bias: bool. Whether to include a bias term.
373
+ partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` will be frozen.
374
+ """
375
+ # argument validation
376
+ if _weight is not None:
377
+ assert (_weight.shape[0] == original_out_features) or (
378
+ original_out_features is None
379
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
380
+ assert (_weight.shape[1] == in_features) or (
381
+ in_features is None
382
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
383
+ in_features = _weight.shape[1]
384
+ original_out_features = _weight.shape[0]
385
+ else:
386
+ assert (
387
+ in_features is not None
388
+ ), "in_features must be provided if _weight is not provided"
389
+ assert (
390
+ original_out_features is not None
391
+ ), "original_out_features must be provided if _weight is not provided"
392
+
393
+ if _bias is not None:
394
+ assert bias is True, "bias must be True if _bias is provided"
395
+
396
+ # initialize original linear
397
+ super().__init__(
398
+ in_features,
399
+ original_out_features,
400
+ bias,
401
+ device,
402
+ dtype)
403
+
404
+ # set weight and bias manually
405
+ if _weight is not None:
406
+ self.weight = nn.Parameter(_weight)
407
+ if _bias is not None:
408
+ self.bias = nn.Parameter(_bias)
409
+
410
+ self.in_features = in_features
411
+ self.original_out_features = original_out_features
412
+ self.max_original_id = max_original_id
413
+
414
+ # initialize additional linear
415
+ self.additional_out_features = additional_out_features
416
+ self.has_bias = bias
417
+ if additional_out_features > 0:
418
+ self.additional_fc = nn.Linear(
419
+ in_features=in_features,
420
+ out_features=additional_out_features,
421
+ bias=self.has_bias,
422
+ device=device,
423
+ dtype=dtype,
424
+ )
425
+ self.set_requires_grad(
426
+ require_regular_grad=not partially_freeze, require_additional_grad=True
427
+ )
428
+
429
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
430
+ """
431
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
432
+ """
433
+ self.weight.requires_grad_(require_regular_grad)
434
+ if self.has_bias:
435
+ self.bias.requires_grad_(require_regular_grad)
436
+ self.additional_fc.requires_grad_(require_additional_grad)
437
+
438
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
439
+ output = F.linear(input, self.weight, self.bias)
440
+ output = output[..., : self.max_original_id + 1]
441
+
442
+ if self.additional_out_features > 0:
443
+ additional_features = F.linear(
444
+ input, self.additional_fc.weight, self.additional_fc.bias
445
+ )
446
+ output = torch.cat((output, additional_features), -1)
447
+ return output
448
+
449
+ def extra_repr(self) -> str:
450
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
451
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
452
+ self.in_features,
453
+ self.max_original_id + 1,
454
+ self.additional_out_features,
455
+ self.bias is not None,
456
+ (not self.weight.requires_grad or not self.bias.requires_grad),
457
+ )
458
+
459
+ class VLM(nn.Module):
460
+ """
461
+ Generic vision-language model (VLM) class.
462
+ A VLM consists of four components:
463
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
464
+ input: (B, T_img, F, C, H, W)
465
+ output: (B, T_img, F, v, d)
466
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
467
+ input: (B, T_img, F, v, d)
468
+ output: (B, T_img, n, d)
469
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
470
+ 4. A language model
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ vision_encoder: nn.Module,
476
+ vision_tokenizer: nn.Module,
477
+ lang_model: nn.Module,
478
+ initial_tokenizer_len: int,
479
+ pad_token_id: int,
480
+ gradient_checkpointing: bool = False,
481
+ ):
482
+ """
483
+ Args:
484
+ vision_encoder (nn.Module): e.g. CLIP
485
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
486
+ lang_model (nn.Module): e.g. MPT
487
+ initial_tokenizer_len (int): size of the original tokenizer vocab
488
+ pad_token_id (int): id of the pad token
489
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
490
+ """
491
+ super().__init__()
492
+
493
+ # save dimension information
494
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
495
+ if hasattr(lang_model.config, "d_model"):
496
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
497
+ else:
498
+ self.lang_hidden_dim = lang_model.config.hidden_size
499
+ self.vis_embedding_dim = vision_tokenizer.dim_media
500
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
501
+
502
+ # core components
503
+ self.vision_encoder = vision_encoder
504
+ self.vision_tokenizer = vision_tokenizer
505
+ self.lang_model = lang_model
506
+
507
+ # lm embeddings
508
+ self.pad_token_id = pad_token_id
509
+ self.initial_tokenizer_len = initial_tokenizer_len
510
+ input_embeds = DecoupledEmbedding(
511
+ max_original_id=initial_tokenizer_len - 1,
512
+ num_additional_embeddings=len(self.special_tokens),
513
+ _weight=self.lang_model.get_input_embeddings().weight,
514
+ pad_token_id=self.pad_token_id,
515
+ )
516
+ if hasattr(input_embeds, "additional_embedding"):
517
+ input_embeds.additional_embedding.weight.data.normal_(
518
+ mean=0.0,
519
+ std=self.lang_model.config.initializer_range
520
+ if hasattr(self.lang_model.config, "initializer_range")
521
+ else 0.02,
522
+ )
523
+ self.lang_model.set_input_embeddings(input_embeds)
524
+
525
+ out_embeds = DecoupledLinear(
526
+ max_original_id=initial_tokenizer_len - 1,
527
+ additional_out_features=len(self.special_tokens),
528
+ _weight=self.lang_model.get_output_embeddings().weight,
529
+ _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
530
+ )
531
+ if hasattr(out_embeds, "additional_fc"):
532
+ out_embeds.additional_fc.weight.data.normal_(
533
+ mean=0.0,
534
+ std=self.lang_model.config.initializer_range
535
+ if hasattr(self.lang_model.config, "initializer_range")
536
+ else 0.02,
537
+ )
538
+ self.lang_model.set_output_embeddings(out_embeds)
539
+
540
+ # gradient checkpointing
541
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
542
+
543
+ def forward(
544
+ self,
545
+ vision_x: Optional[torch.Tensor],
546
+ lang_x: torch.Tensor,
547
+ attention_mask: Optional[torch.Tensor] = None,
548
+ labels: Optional[torch.Tensor] = None,
549
+ past_key_values: Optional[
550
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
551
+ ] = None,
552
+ past_media_locations: Optional[torch.Tensor] = None,
553
+ past_vision_tokens: Optional[torch.Tensor] = None,
554
+ use_cache: Optional[bool] = False,
555
+ **kwargs,
556
+ ):
557
+ """
558
+ Args:
559
+ vision_x: Vision input
560
+ shape (B, T_img, F, C, H, W) with F=1
561
+ only F = 1 is supported (single-frame videos)
562
+ if T_img is greater than the number of media tokens in the corresponding input_ids (lang_x),
563
+ only as many of the vision inputs as there are media tokens in lang_x are used
564
+ lang_x: Language input ids, with media tokens denoting where
565
+ visual media should be inserted.
566
+ shape (B, T_txt)
567
+ attention_mask: Attention mask. Defaults to None.
568
+ labels: Labels. Defaults to None.
569
+ shape (B, T_txt)
570
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
571
+ list of length = number of decoder layers in the LM
572
+ exact implementation depends on LM, see Hugging Face docs
573
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
574
+ shape (B, T_txt)
575
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
576
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
577
+ If True, includes key_values, media_locations, and vision_tokens in the output.
578
+ """
579
+ assert not (past_vision_tokens is None) ^ (
580
+ past_media_locations is None
581
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
582
+
583
+ # convert pixels to vision tokens
584
+ if vision_x is not None:
585
+ vision_features = self._encode_vision_x(vision_x=vision_x)
586
+ vision_tokens = self.vision_tokenizer(vision_features)
587
+ else:
588
+ vision_tokens = None
589
+
590
+ # fuse the vision and language tokens
591
+ new_inputs = self._prepare_inputs_for_forward(
592
+ vision_tokens=vision_tokens,
593
+ lang_x=lang_x,
594
+ attention_mask=attention_mask,
595
+ labels=labels,
596
+ past_key_values=past_key_values,
597
+ past_media_locations=past_media_locations,
598
+ padding_side="right",
599
+ past_vision_tokens=past_vision_tokens,
600
+ )
601
+ output = self.lang_model(
602
+ **new_inputs,
603
+ use_cache=use_cache,
604
+ past_key_values=past_key_values,
605
+ **kwargs,
606
+ )
607
+
608
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
609
+ # or to add the past_vision_tokens and past_media_locations to the output
610
+ output = self._postprocess_outputs_from_forward(
611
+ output=output,
612
+ lang_x=lang_x,
613
+ vision_tokens=vision_tokens,
614
+ use_cache=use_cache,
615
+ past_vision_tokens=past_vision_tokens,
616
+ past_media_locations=past_media_locations,
617
+ )
618
+
619
+ # postforward hooks
620
+ self._post_forward_hook()
621
+ return output
622
+
623
+ def _encode_vision_x_anyres(self, samples, device):
624
+ assert self.anyres_grids is not None
625
+ image_raw = samples["image"] # list of patch tensors, each of shape [1, N_patch, C, H, W]
626
+ image_sizes = samples["image_size"]
627
+
628
+ # Image_raw can be a list of list of patches, when a `samples` has multiple images.
629
+ if isinstance(image_raw[0], list):
630
+ images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
631
+ image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
632
+ else:
633
+ # assert isinstance(image_raw[0], torch.Tensor), f"Unkown image type: {image_raw[0]}"
634
+ # concate list of patches into one big patch for any res encoding.
635
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
636
+ image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
637
+ image = image.to(device)
638
+
639
+ with torch.no_grad():
640
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
641
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
642
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
643
+ image_embeds = self.vision_encoder(image).last_hidden_state
644
+ else:
645
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
646
+
647
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
648
+ base_img_size = self.vision_encoder.config.image_size
649
+ else:
650
+ base_img_size = self.vision_encoder.image_size[0]
651
+
652
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
653
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
654
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
655
+ grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
656
+ grid_size = (grid_size_base, grid_size_base)
657
+ else:
658
+ grid_size = self.vision_encoder.grid_size
659
+ height, width = grid_size
660
+
661
+ if not image_embeds.shape[1] == height * width:
662
+ assert image_embeds.shape[1] == height * width + 1 # For vision encoders that has [CLS] token.
663
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
664
+ n_vis_token_per_patch = image_embeds.shape[1]
665
+
666
+ # Split encoded patches and merge patch features
667
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
668
+ split_sizes = [image.shape[0] for image in images]
669
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
670
+ # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
671
+ new_image_embeds = []
672
+ patch_attn_masks = []
673
+ max_n_img_token = -1
674
+ for idx, patch_embeds in enumerate(image_embeds):
675
+ if patch_embeds.shape[0] > 1:
676
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
677
+ base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
678
+ patch_embeds = patch_embeds[1:]
679
+
680
+ assert height * width == base_patch_embeds.shape[0]
681
+
682
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
683
+ self.anyres_grids,
684
+ base_img_size) # Hardcoded grid_pinpoints.
685
+ patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
686
+
687
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
688
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
689
+ patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
690
+ if hasattr(self, 'image_newline'):
691
+ patch_embeds = torch.cat((
692
+ patch_embeds,
693
+ self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
694
+ ), dim=-1)
695
+ if self.anyres_patch_sampling:
696
+ patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
697
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
698
+ assert patch_attn_mask is not None
699
+ patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
700
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
701
+ patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
702
+ patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
703
+ else:
704
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
705
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
706
+ else:
707
+ patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
708
+ patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
709
+ if hasattr(self, 'image_newline'):
710
+ patch_embeds = torch.cat((
711
+ patch_embeds,
712
+ self.image_newline[None]
713
+ ), dim=0)
714
+ if not self.anyres_patch_sampling:
715
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
716
+
717
+ new_image_embeds.append(patch_embeds)
718
+ patch_attn_masks.append(patch_attn_mask)
719
+
720
+ if self.anyres_patch_sampling:
721
+ # Return individual patches for independent token downsampling.
722
+ return new_image_embeds, patch_attn_masks
723
+
724
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
725
+ image_embeds = []
726
+ image_atts = []
727
+ for image_embed in new_image_embeds:
728
+ n_img_token = image_embed.shape[0]
729
+ img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
730
+ if n_img_token < max_n_img_token:
731
+ padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
732
+ padded_embed[:n_img_token, :] = image_embed
733
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
734
+ else:
735
+ padded_embed = image_embed
736
+ image_embeds.append(padded_embed)
737
+ image_atts.append(img_attn)
738
+ image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
739
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
740
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
741
+ image_embeds = image_embeds[:, None, None, :, :]
742
+ # image_atts = image_atts[:, None, None, :, :]
743
+
744
+ return image_embeds, image_atts
745
+
746
+ def _encode_vision_x(self, vision_x: torch.Tensor):
747
+ """
748
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
749
+ Args:
750
+ vision_x: Vision input
751
+ shape (B, T_img, F, C, H, W)
752
+ Images in the same chunk are collated along T_img, and frames are collated along F
753
+ Currently only F=1 is supported (single-frame videos)
754
+
755
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
756
+ """
757
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
758
+ b, T, F = vision_x.shape[:3]
759
+
760
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
761
+ with torch.no_grad():
762
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
763
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
764
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
765
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
766
+ else:
767
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
768
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
769
+ return vision_x
770
+
771
+ def _concat_vision_cache(
772
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
773
+ ):
774
+ """
775
+ Helper function to include the past vision tokens and past media locations in the output.
776
+ """
777
+ if use_cache:
778
+ if past_media_locations is not None and past_vision_tokens is not None:
779
+ if vision_tokens is not None:
780
+ updated_vision_tokens = torch.cat(
781
+ [
782
+ past_vision_tokens,
783
+ vision_tokens,
784
+ ],
785
+ dim=1,
786
+ )
787
+ else:
788
+ updated_vision_tokens = past_vision_tokens
789
+ updated_media_locations = torch.cat(
790
+ [
791
+ past_media_locations,
792
+ lang_x == self.media_token_id,
793
+ ],
794
+ dim=1,
795
+ )
796
+ else:
797
+ updated_vision_tokens = vision_tokens
798
+ updated_media_locations = lang_x == self.media_token_id
799
+
800
+ else:
801
+ updated_vision_tokens = None
802
+ updated_media_locations = None
803
+
804
+ return updated_vision_tokens, updated_media_locations
805
+
806
+ def generate(
807
+ self,
808
+ vision_x: torch.Tensor,
809
+ lang_x: torch.Tensor,
810
+ attention_mask: torch.Tensor = None,
811
+ past_key_values: Optional[
812
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
813
+ ] = None,
814
+ past_media_locations: Optional[torch.Tensor] = None,
815
+ past_vision_tokens: Optional[torch.Tensor] = None,
816
+ **kwargs,
817
+ ):
818
+ """
819
+ Generate text conditioned on vision and language inputs.
820
+ Args:
821
+ vision_x (torch.Tensor): Vision input
822
+ shape (B, T_img, F, C, H, W)
823
+ see documentation for forward
824
+ lang_x (torch.Tensor): Language input
825
+ shape (B, T_txt)
826
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
827
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
828
+ Returns:
829
+ torch.Tensor: lang_x with generated tokens appended to it
830
+ """
831
+ num_beams = kwargs.pop("num_beams", 1)
832
+
833
+ # convert pixels to vision tokens
834
+ if vision_x is not None:
835
+ vision_features = self._encode_vision_x(vision_x=vision_x)
836
+ vision_tokens = self.vision_tokenizer(vision_features)
837
+ else:
838
+ vision_tokens = None
839
+
840
+ # fuse the vision and language tokens
841
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
842
+ # the total batch size is B * num_beams
843
+ new_inputs = self._prepare_inputs_for_forward(
844
+ vision_tokens=vision_tokens,
845
+ lang_x=lang_x,
846
+ attention_mask=attention_mask,
847
+ past_key_values=past_key_values,
848
+ past_media_locations=past_media_locations,
849
+ past_vision_tokens=past_vision_tokens,
850
+ padding_side="left",
851
+ num_beams=num_beams,
852
+ )
853
+ output = self.lang_model.generate(
854
+ **new_inputs,
855
+ past_key_values=past_key_values,
856
+ num_beams=num_beams,
857
+ use_cache=True,
858
+ **kwargs,
859
+ )
860
+ self._post_forward_hook()
861
+ return output
862
+
863
+ @property
864
+ def num_trainable_params(self):
865
+ """Print the number of trainable parameters"""
866
+ return num_params(self, filter_to_trainable=True)
867
+
868
+ def set_trainable(self):
869
+ """
870
+ Freeze appropriate parameters in the model.
871
+ """
872
+ raise NotImplementedError
873
+
874
+ def group_params_by_weight_decay(self):
875
+ """
876
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
877
+ """
878
+ params_with_wd, params_without_wd = [], []
879
+ for n, p in self.named_parameters():
880
+ if p.requires_grad:
881
+ if self._should_apply_weight_decay(n):
882
+ params_with_wd.append(p)
883
+ else:
884
+ params_without_wd.append(p)
885
+ return params_with_wd, params_without_wd
886
+
887
+ def _should_apply_weight_decay(self, parameter_name):
888
+ """
889
+ Return whether weight decay should be applied to a parameter.
890
+ """
891
+ raise NotImplementedError
892
+
893
+ @property
894
+ def special_tokens(self):
895
+ """
896
+ Returns a dict mapping from the attribute name of a special token to its string format,
897
+ e.g. "media_token": "<image>"
898
+ """
899
+ assert (
900
+ "media_token" in self._special_tokens
901
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
902
+ return self._special_tokens
903
+
904
+ @property
905
+ def special_token_ids(self):
906
+ """
907
+ Returns a list of the special token ids
908
+ """
909
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
910
+
911
+ def set_special_token_ids(self, string_to_ids):
912
+ """
913
+ Args:
914
+ string_to_ids (dict): mapping from token string to id
915
+ """
916
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
917
+ for att_name, token_str in self.special_tokens.items():
918
+ token_id = string_to_ids[token_str]
919
+ setattr(self, f"{att_name}_id", token_id)
920
+ setattr(self.lang_model, f"{att_name}_id", token_id)
921
+
922
+ def init_gradient_checkpointing(self):
923
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
924
+ checkpoint_wrapper,
925
+ CheckpointWrapper,
926
+ CheckpointImpl,
927
+ apply_activation_checkpointing,
928
+ )
929
+ from functools import partial
930
+
931
+ non_reentrant_wrapper = partial(
932
+ checkpoint_wrapper,
933
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
934
+ )
935
+ apply_activation_checkpointing(
936
+ self,
937
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
938
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
939
+ and not isinstance(m, CheckpointWrapper),
940
+ )
941
+
942
+ @dataclass
943
+ class VLMOutputWithPast(CausalLMOutputWithPast):
944
+ """
945
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
946
+ past_media_locations: Optional[torch.Tensor] = None,
947
+ past_vision_tokens: Optional[torch.Tensor] = None,
948
+ """
949
+
950
+ past_media_locations: Optional[torch.Tensor] = None
951
+ past_vision_tokens: Optional[torch.Tensor] = None
952
+
953
+
954
+ def exists(val):
955
+ return val is not None
956
+
957
+
958
+ def FeedForward(dim, mult=4):
959
+ inner_dim = int(dim * mult)
960
+ return nn.Sequential(
961
+ nn.LayerNorm(dim),
962
+ nn.Linear(dim, inner_dim, bias=False),
963
+ nn.GELU(),
964
+ nn.Linear(inner_dim, dim, bias=False),
965
+ )
+
+ class VLMWithLanguageStream(VLM):
+     """
+     VLM that fuses modalities by inserting vision tokens directly into the language stream.
+     """
+
+     def __init__(
+         self,
+         vision_encoder: nn.Module,
+         vision_tokenizer: nn.Module,
+         lang_model: nn.Module,
+         initial_tokenizer_len: int,
+         pad_token_id: int,
+         decoder_layers_attr_name: str = None,
+         gradient_checkpointing: bool = False,
+     ):
+         super().__init__(
+             vision_encoder=vision_encoder,
+             vision_tokenizer=vision_tokenizer,
+             lang_model=lang_model,
+             initial_tokenizer_len=initial_tokenizer_len,
+             pad_token_id=pad_token_id,
+             gradient_checkpointing=gradient_checkpointing,
+         )
+         self.decoder_layers_attr_name = decoder_layers_attr_name
+         if decoder_layers_attr_name is not None:
+             for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
+                 block._use_gradient_checkpointing = gradient_checkpointing
+
+     def _prepare_inputs_for_forward(
+         self,
+         vision_tokens: torch.Tensor,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         past_key_values=None,
+         vision_attention_mask: Optional[torch.Tensor] = None,
+         past_media_locations: torch.Tensor = None,
+         past_vision_tokens: torch.Tensor = None,
+         padding_side: str = "left",
+         num_beams: int = 1,
+     ):
+         """
+         Insert the vision tokens directly into the language stream.
+         This requires us to modify the input_ids, attention_mask, and labels.
+         """
+         if past_key_values is not None:
+             past_len = past_key_values[0][0].shape[2]
+             assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
+                 "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
+                 + "Check that you've expanded the attention mask to account for past image tokens."
+             )
+
+         if vision_tokens is None:
+             return {
+                 "input_ids": lang_x,
+                 "attention_mask": attention_mask,
+                 "labels": labels,
+             }
+
+         # get the language embeddings
+         lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
+
+         # build up the multimodal embeddings
+         B = lang_x.shape[0]
+         has_labels = labels is not None
+         multimodal_embeds = []
+         multimodal_attention_mask = []
+         multimodal_labels = [] if has_labels else None
+         for i in range(B):
+             # get index of <image> tokens in lang_x[i]
+             image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
+
+             if len(image_token_idxs) == 0:
+                 multimodal_embeds.append(lang_embeds[i].clone())
+                 multimodal_attention_mask.append(attention_mask[i].clone())
+                 if has_labels:
+                     multimodal_labels.append(labels[i].clone())
+                 continue
+
+             # loop through the image_token_idxs and insert the vision tokens
+             new_embed = lang_embeds[i].clone()
+             new_attention_mask = (
+                 attention_mask[i].clone() if attention_mask is not None else None
+             )
+             if has_labels:
+                 new_label = labels[i].clone()
+             for img_num, img_idx in enumerate(image_token_idxs):
+                 new_embed = torch.cat(
+                     (
+                         new_embed[:img_idx],
+                         vision_tokens[i][img_num],
+                         new_embed[img_idx + self.num_tokens_per_vis :],
+                     ),
+                     dim=0,
+                 )
+                 new_attention_mask = torch.cat(
+                     (
+                         new_attention_mask[:img_idx],
+                         torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                             attention_mask.device
+                         ),
+                         new_attention_mask[img_idx + self.num_tokens_per_vis :],
+                     ),
+                     dim=0,
+                 )
+                 if has_labels:
+                     new_label = torch.cat(
+                         (
+                             new_label[:img_idx],
+                             torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                                 labels.device
+                             )
+                             * -100,
+                             new_label[img_idx + self.num_tokens_per_vis :],
+                         ),
+                         dim=0,
+                     )
+             multimodal_embeds.append(new_embed)
+             multimodal_attention_mask.append(new_attention_mask)
+             if has_labels:
+                 multimodal_labels.append(new_label)
+
+         # stack
+         multimodal_embeds = stack_with_padding(
+             multimodal_embeds,
+             padding_value=self.pad_token_id,
+             padding_side=padding_side,
+         )
+         multimodal_attention_mask = stack_with_padding(
+             multimodal_attention_mask,
+             padding_value=0,
+             padding_side=padding_side,
+         )
+         if has_labels:
+             multimodal_labels = stack_with_padding(
+                 multimodal_labels,
+                 padding_value=-100,
+                 padding_side=padding_side,
+             )
+
+         return {
+             "inputs_embeds": multimodal_embeds,
+             "attention_mask": multimodal_attention_mask,
+             "labels": multimodal_labels,
+         }
+
+     def _postprocess_outputs_from_forward(
+         self,
+         output: CausalLMOutputWithPast,
+         lang_x: torch.Tensor,
+         vision_tokens: torch.Tensor,
+         past_vision_tokens: torch.Tensor,
+         past_media_locations: torch.Tensor,
+         use_cache: bool = False,
+     ):
+         # Include the past vision tokens and past media locations in the output
+         updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
+             lang_x=lang_x,
+             vision_tokens=vision_tokens,
+             past_vision_tokens=past_vision_tokens,
+             past_media_locations=past_media_locations,
+             use_cache=use_cache,
+         )
+
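+         # Re-align the logits with the original text positions: for every <image> token
+         # in lang_x, keep the logit at its first vision position and skip the remaining
+         # num_tokens_per_vis - 1 positions, so downstream loss/decoding code can index
+         # the logits with the original (un-expanded) input_ids.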
+         # return logits that are the same shape as the original input_ids
+         logits = output.logits
+         batch_logits = []
+         B, T_txt = lang_x.shape
+         for i in range(B):
+             sequence_logits = []
+             logits_j = 0
+             for j in range(T_txt):
+                 if lang_x[i, j] != self.media_token_id:
+                     sequence_logits.append(logits[i, logits_j])
+                     logits_j += 1
+                 else:
+                     # append the logit for the first image token, then skip over the rest
+                     # note: the model actually learns to predict <im_patch>, not <image>
+                     sequence_logits.append(logits[i, logits_j])
+                     logits_j += self.num_tokens_per_vis
+             sequence_logits = torch.stack(sequence_logits, dim=0)  # (T_txt, vocab_size)
+             batch_logits.append(sequence_logits)
+
+         batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
+         # The final logits shape should be the same as the original input_ids shape
+         assert batch_logits.shape[:2] == (B, T_txt)
+
+         # assemble the output
+         output = VLMOutputWithPast(
+             loss=output.loss,
+             logits=batch_logits,
+             past_key_values=output.past_key_values,
+             hidden_states=output.hidden_states,
+             attentions=output.attentions,
+             past_media_locations=updated_media_locations,
+             past_vision_tokens=updated_vision_tokens,
+         )
+
+         return output
+
+     def _post_forward_hook(self):
+         pass
+
+
+     @property
+     def num_params_per_module(self):
+         """Print the number of parameters per module in the model"""
+         return "\n".join(
+             [
+                 f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
+                 f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
+                 f"Language model: {num_params(self.lang_model):,} parameters",
+             ]
+         )
+
+     @property
+     def num_trainable_params_per_module(self):
+         """Print the number of trainable parameters per module in the model"""
+         return "\n".join(
+             [
+                 f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
+                 f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
+                 f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
+             ]
+         )
+
+
+ class XGenMMPerceiver(VLMWithLanguageStream):
+     def __init__(
+         self,
+         vision_encoder: nn.Module,
+         vision_tokenizer: nn.Module,
+         lang_model: nn.Module,
+         initial_tokenizer_len: int,
+         pad_token_id: int,
+         decoder_layers_attr_name: str = None,
+         gradient_checkpointing: bool = False,
+         image_aspect_ratio: str = 'none',
+     ):
+         """
+         Args:
+             vision_encoder (nn.Module): HF CLIPModel
+             vision_tokenizer (nn.Module): module that maps vision encoder features to a fixed number of vision tokens
+             lang_model (nn.Module): HF causal language model
+             initial_tokenizer_len (int): size of the tokenizer vocab
+             pad_token_id (int): id of the padding token. None if no padding token; then a padding token
+                 will be inserted into self.special_tokens, which factory.py fills after creating new tokens
+             decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+             gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+             image_aspect_ratio (str, optional): how image aspect ratios are handled during preprocessing. Defaults to 'none'.
+         """
+         self._special_tokens = {
+             "media_token": "<image>",
+             "image_placeholder_token": "<image placeholder>",
+             "end_of_trunk_token": "<|endofchunk|>",
+         }
+         lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+         super().__init__(
+             vision_encoder=vision_encoder,
+             vision_tokenizer=vision_tokenizer,
+             lang_model=lang_model,
+             initial_tokenizer_len=initial_tokenizer_len,
+             gradient_checkpointing=gradient_checkpointing,
+             decoder_layers_attr_name=decoder_layers_attr_name,
+             pad_token_id=pad_token_id,
+         )
+         self.image_aspect_ratio = image_aspect_ratio
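+         # Note: the token strings in self._special_tokens must be added to the tokenizer
+         # and their ids registered via set_special_token_ids() before the model can
+         # locate <image> spans in the input ids (see the special_tokens property above).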
+
+     def set_trainable(self):
+         """
+         Unfreeze everything except the vision_encoder
+         """
+         self.requires_grad_(True)
+         self.vision_encoder.requires_grad_(False)
+
+     def _should_apply_weight_decay(self, parameter_name):
+         """
+         Kosmos applies 0.01 weight decay to everything
+         """
+         return True
+
+     def generate(
+         self,
+         vision_x: torch.Tensor,
+         lang_x: torch.Tensor,
+         image_size: Optional[Tuple] = None,
+         attention_mask: torch.Tensor = None,
+         past_key_values: Optional[
+             List[Union[torch.Tensor, Tuple[torch.Tensor]]]
+         ] = None,
+         past_media_locations: Optional[torch.Tensor] = None,
+         past_vision_tokens: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         """
+         Generate text conditioned on vision and language inputs.
+         Args:
+             vision_x (torch.Tensor): Vision input
+                 shape (B, T_img, F, C, H, W)
+                 see documentation for forward
+             lang_x (torch.Tensor): Language input
+                 shape (B, T_txt)
+             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+             **kwargs: see generate documentation in Hugging Face CausalLM models.
+         Returns:
+             torch.Tensor: lang_x with generated tokens appended to it
+         """
+         num_beams = kwargs.pop("num_beams", 1)
+
+         # convert pixels to vision tokens
+         vision_attention_mask = None
+         if vision_x is not None:
+             vision_features = self._encode_vision_x(vision_x=vision_x)
+             vision_tokens = self.vision_tokenizer(vision_features)
+         else:
+             vision_tokens = None
+
+         # fuse the vision and language tokens
+         # for xattn, vision_x and media_location are repeat_interleaved s.t.
+         # the total batch size is B * num_beams
+         new_inputs = self._prepare_inputs_for_forward(
+             vision_tokens=vision_tokens,
+             lang_x=lang_x,
+             attention_mask=attention_mask,
+             vision_attention_mask=vision_attention_mask,
+             past_key_values=past_key_values,
+             past_media_locations=past_media_locations,
+             past_vision_tokens=past_vision_tokens,
+             padding_side="left",
+             num_beams=num_beams,
+         )
+         if past_key_values is not None:
+             output = self.lang_model.generate(
+                 **new_inputs,
+                 past_key_values=past_key_values,
+                 num_beams=num_beams,
+                 use_cache=True,
+                 **kwargs,
+             )
+         else:
+             output = self.lang_model.generate(
+                 **new_inputs,
+                 num_beams=num_beams,
+                 use_cache=True,
+                 **kwargs,
+             )
+         self._post_forward_hook()
+         return output