UncleFish committed on
Commit
bef178b
1 Parent(s): 7a3362f

base model release

.gitignore ADDED
@@ -0,0 +1 @@
1
+ **/__pycache__/**
README.md CHANGED
@@ -7,26 +7,185 @@ pipeline_tag: image-text-to-text
7
 
8
 
9
  # Model description
 
10
 
11
- BLIP-3 consists of 3 models: a CLIP-like image encoder, a VL connector, and a large language model.
 
12
 
13
- # Direct Use and Downstream Use
 
 
14
15
 
16
- # Bias, Risks, Limitations, and Ethical Considerations
17
 
18
  # How to use
19
 
20
- > We require use the development version (`"4.41.0.dev0"`) of the `transformers` library. To get it, as of 05/07/2024, one can use `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.`
21
 
22
 
23
  # License
24
 
25
- Our code and weights are released under the Creative Commons Attribution Non Commercial 4.0 [LICENSE](LICENSE.txt).
26
 
27
  # Troubleshoot
28
 
29
- 1. If you missing any packages, please consider the followings
30
 
31
  ```
32
  pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
 
7
 
8
 
9
  # Model description
10
+ We are excited to announce the continuation and rebranding of our **BLIP series** as **XGen-MM**, to better align with Salesforce's unified XGen initiative for large foundation models! This rebranding marks a significant step in our ongoing development of cutting-edge multimodal technologies.
11
 
12
+ `XGen-MM` is a series of the latest foundational Large Multimodal Models (LMMs) developed by Salesforce AI Research. This series builds on the successful designs of the `BLIP` series, incorporating fundamental enhancements that provide a more robust and superior foundation. \
13
+ These models have been trained at scale on high-quality image caption datasets and interleaved image-text data. XGen-MM highlights the following features:
14
 
15
+ * The **pretrained** foundation model, `xgen-mm-phi3-mini-base-r-v1`, achieves state-of-the-art performance among models under 5B parameters and demonstrates strong in-context learning capabilities.
16
+ * The **instruct** fine-tuned model, `xgen-mm-phi3-mini-instruct-r-v1`, achieves state-of-the-art performance among open-source and closed-source VLMs under 5B parameters.
17
+ * `xgen-mm-phi3-mini-instruct-r-v1` supports flexible high-resolution image encoding with efficient visual token sampling.
18
 
19
+ More technical details will be shared in an upcoming technical report.
20
+
21
+
22
+ # Datasets
23
+
24
+ | Dataset Type | Dataset(s) Used |
25
+ |--------|------------------------------------------|
26
+ | Pretrain | Caption data (datacomp, cc12m, cc3m, SBU, vg) and interleaved data (obelics) |
27
+ | Instruction Tuning | LLaVA-Instruct-150K, ShareGPT4V captions, a mixture of academic VQA data including OCR/Document/Chart-focused tasks, publicly available text-only instruction data |
28
+
29
+ # Results
30
+
31
+ ### Pretrain (base model without instruction tuning)
32
+ | Model | Shot | COCO (val) | NoCaps (val) | TextCaps (val) | OKVQA (val) | TextVQA (val) | VizWiz (testdev) | VQAv2 (testdev) |
33
+ |-------------|------|------------|--------------|----------------|--------------|---------------|------------------|-----------------|
34
+ | Flamingo-3B | 4 | 85.0 | - | - | 43.3 | 32.7 | 34 | 53.2 |
35
+ | | 8 | 90.6 | - | - | 44.6 | 32.4 | 38.4 | 55.4 |
36
+ | MM1-3B | 0 | 73.5 | 55.6 | 63.3 | 26.1 | 29.4 | 15.6 | 46.2 |
37
+ | | 4 | 112.3 | 99.7 | 84.1 | 48.6 | 45.3 | 38.0 | 57.9 |
38
+ | | 8 | 114.6 | 104.7 | 88.8 | 48.4 | 44.6 | 46.4 | 63.6 |
39
+ | **xgen-mm-phi3-mini-base-r-v1 (Ours)**| 0 | **81.7** | **80.2** | 60.7 | **26.5** | **36.0** | **21.2** | **48.1** |
40
+ | | 4 | 110.5 | **101.7** | **84.6** | **49.2** | **46.1** | **38.4** | **63.9** |
41
+ | | 8 | 112.1 | 104.4 | 87.7 | **49.1** | **46.4** | 44.3 | **63.8** |
42
+
43
+ ### Instruct (after instruction tuning)
44
+ | Model | SEED-IMG | MMBench(dev) | MME-total | MME-P | MME-C | MMStar | MMMU (val) | MMVet | MathVista (mini) | ScienceQA (test) | POPE | AI2D |
45
+ |----------------------------|----------|--------------|-----------|----------|---------|----------|------------|----------|------------------|------------------|----------|----------|
46
+ | MM1-3B-Chat | 68.8 | 67.8 | 1761 | **1482** | 279 | - | 33.9 | 43.7 | - | - | **87.4** | - |
47
+ | openbmb/MiniCPM-V-2 | 67.1 | 69.6 | 1808 | - | - | - | 38.2 | - | 38.7 | - | - | - |
48
+ | VILA1.5-3B | 67.9 | 63.4 | - | 1442 | - | - | 33.3 | 35.4 | - | 69.0 | 85.9 | - |
49
+ | xtuner/llava-phi-3-mini-hf | 70.0 | 69.2 | 1790 | 1477 | 313 | 43.7 | **41.4** | - | - | 73.7 | 87.3 | 69.3 |
50
+ | **xgen-mm-phi3-mini-instruct-r-v1 (Ours)** | **72.1** | **74.1** | **1827** | 1467 | **360** | **44.6** | 39.8 | **45.1** | **39.3** | **74.2** | 87.2 | **75.8** |
51
 
 
52
 
53
  # How to use
54
 
55
+ > We require the use of the development version (`"4.41.0.dev0"`) of the `transformers` library. To get it, as of 05/07/2024, one can use `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`.
56
+
57
+ ```python
58
+ from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
59
+ import json
60
+ import PIL
61
+ import IPython.display as display
62
+ import torch
63
+ model = AutoModelForVision2Seq.from_pretrained("./", trust_remote_code=True)
64
+ tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True, use_fast=True, legacy=False)
65
+ image_processor = AutoImageProcessor.from_pretrained("./", trust_remote_code=True)
66
+ tokenizer = model.update_special_tokens(tokenizer)
67
+
68
+ model = model.to('cuda')
69
+ tokenizer.padding_side = "left"
70
+
71
+ def apply_prompt_template(prompt, num_images=1, num_tokens_per_vis = 128, in_context=False, output=None):
72
+ """
73
+ num_tokens_per_vis: model.vlm.num_tokens_per_vis
74
+ """
75
+ placeholder_image_tokens = "<image placeholder>" * (num_tokens_per_vis - 1)
76
+ if in_context:
77
+ formatted_prompt = f"<image>{placeholder_image_tokens}" + f"{prompt}" + f"{output}" + "<|endofchunk|>"
78
+ else:
79
+ formatted_prompt = f"<image>{placeholder_image_tokens}"*num_images + f"{prompt}"
80
+ return formatted_prompt
81
+
82
+ ############ Zero-shot inference ##########
83
+ with open('./test_samples/zero_shot.json') as f:
84
+ sample = json.load(f)
85
+ instruction = sample['instruction']
86
+ img = PIL.Image.open(sample['image_path'])
87
+ print("==> Instruction: ", instruction)
88
+ print("==> Image: ")
89
+ display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
90
+ inputs = image_processor([img], return_tensors="pt")
91
+ prompt = apply_prompt_template(instruction)
92
+ language_inputs = tokenizer([prompt], return_tensors="pt")
93
+ inputs.update(language_inputs)
94
+ inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
95
+
96
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
97
+ generated_text = model.generate(**inputs,
98
+ pad_token_id=tokenizer.pad_token_id,
99
+ do_sample=False, max_new_tokens=256, top_p=None, num_beams=1,
100
+ length_penalty=1.0, repetition_penalty=2.0)
101
+ prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True)
102
+ print("==> prediciton: ", prediction)
103
+ print("-"*120)
104
+ # ==> prediction: A man sits on a bench in front of the Red Corner Cafe.
105
+
106
+ ############ Few-shot inference ##########
107
+ # prepare in-context examples
108
+ with open('./test_samples/few_shots.json') as f:
109
+ incontext_data = json.load(f)
110
+ print(f'In-context learning with {len(incontext_data)} examples.')
111
+ context_images, context_text = [], ""
112
+ for example in incontext_data:
113
+ print("-"*40 + f" {example} " + "-"*40)
114
+ img = PIL.Image.open(incontext_data[example]['image_path'])
115
+ instruction = incontext_data[example]['instruction']
116
+ example_text = apply_prompt_template(prompt=instruction, in_context=True, output=incontext_data[example]['output'])
117
+ context_images.append(img)
118
+ context_text += (example_text)
119
+ print("==> Instruction: ", instruction)
120
+ print("==> Image: ")
121
+ display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
122
+ print("==> Output: ", incontext_data[example]['output'])
123
+ # prepare test example
124
+ with open('./test_samples/zero_shot.json') as f:
125
+ sample = json.load(f)
126
+ instruction = "A short description of this image in one sentence:"
127
+ print("-"*40 + " Prediction " + "-"*40)
128
+ img = PIL.Image.open(sample['image_path'])
129
+ print("==> Instruction: ", instruction)
130
+ print("==> Image: ")
131
+ display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
132
+ prompt = apply_prompt_template(instruction)
133
+ batch_images = context_images + [img]
134
+ batch_text = context_text + prompt
135
+ # prepare inputs
136
+ inputs = image_processor(batch_images, return_tensors="pt")
137
+ language_inputs = tokenizer([batch_text], return_tensors="pt")
138
+ inputs.update(language_inputs)
139
+ inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
140
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
141
+ generated_text = model.generate(**inputs,
142
+ pad_token_id=tokenizer.pad_token_id,
143
+ do_sample=False, max_new_tokens=256, top_p=None, num_beams=1,
144
+ length_penalty=1.0)
145
+ prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True)
146
+ print("==> prediciton: ", prediction)
147
+ print("-"*120)
148
+ ```
149
+
150
+ More comprehensive examples can be found in the [notebook](demo.ipynb).
151
+
152
+ # Reproducibility
153
+
154
+ Our SFT evaluation is based on VLMEvalKit, in which we fixed some inconsistencies with the official benchmarks (e.g., the LLM judge API). During development, we noticed that the raw resolution of the input image can noticeably affect the model output in some cases.
155
+
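As a minimal illustration of this resolution sensitivity, one could cap the longest edge of the raw image before handing it to the image processor so that runs stay comparable. This is a hypothetical sketch, not part of the released pipeline; the `cap_longest_edge` helper and the 768-pixel cap are arbitrary choices for the example.

```python
from PIL import Image

def cap_longest_edge(img: Image.Image, max_edge: int = 768) -> Image.Image:
    # Hypothetical helper: downscale so the longest side is at most `max_edge` pixels.
    scale = max_edge / max(img.size)
    if scale >= 1.0:
        return img  # already within the cap, leave untouched
    new_size = (round(img.width * scale), round(img.height * scale))
    return img.resize(new_size)

# e.g., img = cap_longest_edge(PIL.Image.open(sample['image_path'])) before calling image_processor([img], ...)
```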
156
+
157
+ # Bias, Risks, Limitations, and Ethical Considerations
158
+ The main data sources are from the internet, including webpages,
159
+ image stock sites, and curated datasets released by the research community. We have excluded certain data, such as LAION, due to known CSAM concerns.
160
+ The model may be subject to bias from the original data source, as well as bias from LLMs and commercial APIs.
161
+ We strongly recommend that users assess safety and fairness before applying the model to downstream applications.
162
 
163
 
164
  # License
165
 
166
+ Our code and weights are released under the Creative Commons Attribution-NonCommercial 4.0 [LICENSE](LICENSE.txt). Please fill out [this form](https://forms.gle/ffPc9oZC2ZGeJ1N68) to inquire about commercial use of the model weights.
167
+
168
+ # Code acknowledgement
169
+
170
+ [LAVIS](https://github.com/salesforce/LAVIS) \
171
+ [openflamingo](https://github.com/mlfoundations/open_flamingo) \
172
+ [VLMEvalKit](https://github.com/open-compass/VLMEvalKit/tree/main)
173
+
174
+
175
+ # Citation
176
+ ```
177
+ @misc{xgen_mm_phi3_mini,
178
+ title={xgen-mm-phi3-mini-base Model Card},
179
+ url={https://huggingface.co/Salesforce/xgen-mm-phi3-mini-base-r-v1},
180
+ author={Salesforce AI Research},
181
+ month={May},
182
+ year={2024}
183
+ }
184
+ ```
185
 
186
  # Troubleshoot
187
 
188
+ 1. If you are missing any packages, please consider installing the following:
189
 
190
  ```
191
  pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
added_tokens.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "<pad>": 32011,
3
+ "<|assistant|>": 32001,
4
+ "<|endoftext|>": 32000,
5
+ "<|end|>": 32007,
6
+ "<|placeholder1|>": 32002,
7
+ "<|placeholder2|>": 32003,
8
+ "<|placeholder3|>": 32004,
9
+ "<|placeholder4|>": 32005,
10
+ "<|placeholder5|>": 32008,
11
+ "<|placeholder6|>": 32009,
12
+ "<|system|>": 32006,
13
+ "<|user|>": 32010
14
+ }
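A hedged sketch of checking that these added tokens are picked up when loading the tokenizer the same way the README does (assuming the repository files are in the current directory):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True, use_fast=True, legacy=False)
print(tokenizer.convert_tokens_to_ids("<pad>"))     # expected 32011, per added_tokens.json
print(tokenizer.convert_tokens_to_ids("<|user|>"))  # expected 32010
```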
config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "architectures": [
3
+ "XGenMMModelForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_xgenmm.XGenMMConfig",
7
+ "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
8
+ },
9
+ "model_type": "xgenmm",
10
+ "text_config": {
11
+ "initial_tokenizer_len": 32012,
12
+ "model_type": "phi3",
13
+ "sliding_window": 2047,
14
+ "torch_dtype": "bfloat16"
15
+ },
16
+ "torch_dtype": "float32",
17
+ "transformers_version": "4.41.0.dev0",
18
+ "vision_encoder_config": {
19
+ "anyres_patch_sampling": false,
20
+ "image_aspect_ratio": "pad",
21
+ "model_type": "xgenmm_vision_encoder"
22
+ },
23
+ "vision_tokenizer_config": {
24
+ "model_type": "xgenmm_vision_tokenizer"
25
+ }
26
+ }
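A hedged sketch of how the `auto_map` entries above are meant to be consumed, mirroring the README's loading pattern (assuming the repository files are in the current directory):

```python
from transformers import AutoConfig

# trust_remote_code lets AutoConfig resolve configuration_xgenmm.XGenMMConfig via the auto_map above.
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
print(config.model_type)              # "xgenmm"
print(config.text_config.model_type)  # "phi3"
```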
configuration_xgenmm.py ADDED
@@ -0,0 +1,155 @@
1
+ from transformers import PretrainedConfig
2
+ from transformers import logging
3
+ from transformers import CONFIG_MAPPING
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+ class XGenMMVisionEncoderConfig(PretrainedConfig):
8
+ model_type = "xgenmm_vision_encoder"
9
+
10
+ def __init__(self,
11
+ model_name: str = 'ViT-H-14-378-quickgelu',
12
+ force_image_size: int = 378,
13
+ **kwargs):
14
+ self.model_name = model_name
15
+ self.force_image_size = force_image_size
16
+ super().__init__(**kwargs)
17
+
18
+
19
+ class XGenMMVisionTokenizerConfig(PretrainedConfig):
20
+ model_type = "xgenmm_vision_tokenizer"
21
+
22
+ def __init__(self,
23
+ vis_feature_dim: int = 1280,
24
+ lang_embedding_dim: int = 3072,
25
+ **kwargs):
26
+ self.vis_feature_dim = vis_feature_dim
27
+ self.lang_embedding_dim = lang_embedding_dim
28
+ super().__init__(**kwargs)
29
+
30
+
31
+ class XGenMMConfig(PretrainedConfig):
32
+ model_type = "xgenmm"
33
+
34
+ def __init__(self,
35
+ vision_encoder_config: dict = None,
36
+ vision_tokenizer_config: dict = None,
37
+ text_config: dict = None,
38
+ **kwargs):
39
+
40
+ if vision_encoder_config is None:
41
+ vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
42
+ logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
43
+
44
+ if vision_tokenizer_config is None:
45
+ vision_tokenizer_config = {}
46
+ logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
47
+
48
+ if text_config is None:
49
+ text_config = {
50
+ 'initial_tokenizer_len':32012,
51
+ 'pad_token_id':32011,
52
+ 'bos_token_id':1,
53
+ 'eos_token_id':32000,
54
+ 'vocab_size': 32064,
55
+ 'hidden_size': 3072,
56
+ 'intermediate_size': 8192,
57
+ 'num_hidden_layers': 32,
58
+ 'num_attention_heads': 32,
59
+ 'num_key_value_heads': 32,
60
+ 'resid_pdrop': 0.0,
61
+ 'embd_pdrop': 0.0,
62
+ 'attention_dropout': 0.0,
63
+ 'hidden_act': 'silu',
64
+ 'max_position_embeddings': 4096,
65
+ 'original_max_position_embeddings': 4096,
66
+ 'initializer_range': 0.02,
67
+ 'rms_norm_eps': 1e-05,
68
+ 'use_cache': True,
69
+ 'rope_theta': 10000.0,
70
+ 'rope_scaling': None,
71
+ 'sliding_window': 2047,
72
+ 'return_dict': True,
73
+ 'output_hidden_states': False,
74
+ 'output_attentions': False,
75
+ 'torchscript': False,
76
+ 'torch_dtype': 'bfloat16',
77
+ 'use_bfloat16': False,
78
+ 'tf_legacy_loss': False,
79
+ 'pruned_heads': {},
80
+ 'tie_word_embeddings': False,
81
+ 'chunk_size_feed_forward': 0,
82
+ 'is_encoder_decoder': False,
83
+ 'is_decoder': False,
84
+ 'cross_attention_hidden_size': None,
85
+ 'add_cross_attention': False,
86
+ 'tie_encoder_decoder': False,
87
+ 'max_length': 20,
88
+ 'min_length': 0,
89
+ 'do_sample': False,
90
+ 'early_stopping': False,
91
+ 'num_beams': 1,
92
+ 'num_beam_groups': 1,
93
+ 'diversity_penalty': 0.0,
94
+ 'temperature': 1.0,
95
+ 'top_k': 50,
96
+ 'top_p': 1.0,
97
+ 'typical_p': 1.0,
98
+ 'repetition_penalty': 1.0,
99
+ 'length_penalty': 1.0,
100
+ 'no_repeat_ngram_size': 0,
101
+ 'encoder_no_repeat_ngram_size': 0,
102
+ 'bad_words_ids': None,
103
+ 'num_return_sequences': 1,
104
+ 'output_scores': False,
105
+ 'return_dict_in_generate': False,
106
+ 'forced_bos_token_id': None,
107
+ 'forced_eos_token_id': None,
108
+ 'remove_invalid_values': False,
109
+ 'exponential_decay_length_penalty': None,
110
+ 'suppress_tokens': None,
111
+ 'begin_suppress_tokens': None,
112
+ 'finetuning_task': None,
113
+ 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
114
+ 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
115
+ 'tokenizer_class': None,
116
+ 'prefix': None,
117
+ 'bos_token_id': 1,
118
+ 'pad_token_id': 32000,
119
+ 'eos_token_id': 32000,
120
+ 'sep_token_id': None,
121
+ 'decoder_start_token_id': None,
122
+ 'task_specific_params': None,
123
+ 'problem_type': None,
124
+ 'model_type': 'phi3'
125
+ }
126
+ logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
127
+
128
+ self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
129
+
130
+ self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
131
+
132
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
133
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
134
+
135
+ for key in ['initial_tokenizer_len', 'pad_token_id']:
136
+ if key not in self.text_config.to_dict():
137
+ raise ValueError(f"The key `{key}` is missing in the text_config.")
138
+
139
+ super().__init__(**kwargs)
140
+
141
+ @classmethod
142
+ def from_vision_encoder_vision_tokenizer_text_configs(
143
+ cls,
144
+ vision_encoder_config: XGenMMVisionEncoderConfig,
145
+ vision_tokenizer_config: XGenMMVisionTokenizerConfig,
146
+ text_config: PretrainedConfig,
147
+ **kwargs):
148
+
149
+ return cls(
150
+ vision_encoder_config=vision_encoder_config.to_dict(),
151
+ vision_tokenizer_config=vision_tokenizer_config.to_dict(),
152
+ text_config=text_config.to_dict(),
153
+ **kwargs,
154
+ )
155
+
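A minimal sketch of instantiating the configuration class defined above directly (assuming `configuration_xgenmm.py` is importable from the working directory):

```python
from configuration_xgenmm import XGenMMConfig

# With no arguments, every sub-config falls back to the defaults defined above.
config = XGenMMConfig()
print(config.vision_encoder_config.model_name)         # 'ViT-H-14-378-quickgelu'
print(config.vision_tokenizer_config.vis_feature_dim)  # 1280
print(config.text_config.model_type)                   # 'phi3'
```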
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 32000,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.41.0.dev0"
7
+ }
image_processing_xgenmm.py ADDED
@@ -0,0 +1,409 @@
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
+ import torchvision.transforms.functional as F
4
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
5
+ CenterCrop, ColorJitter, Grayscale
6
+ import numbers
7
+ import torch
8
+ import ast
9
+ import math
10
+ from PIL import Image
11
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
12
+ from transformers.image_utils import ImageInput
13
+ from transformers.utils import TensorType
14
+
15
+
16
+ class XGenMMImageProcessor(BaseImageProcessor):
17
+
18
+ def __init__(
19
+ self,
20
+ do_resize: bool = True,
21
+ resize_mode: str = "squash",
22
+ interpolation_mode: str = "bicubic",
23
+ size: Union[Tuple[int, int], List[int]] = None,
24
+ image_mean: Optional[Union[float, List[float]]] = None,
25
+ image_std: Optional[Union[float, List[float]]] = None,
26
+ **kwargs,
27
+ ) -> None:
28
+ super().__init__(**kwargs)
29
+ self.do_resize = do_resize
30
+ self.resize_mode = resize_mode
31
+ self.interpolation_mode = interpolation_mode
32
+ self.size = size if size is not None else (378, 378)
33
+ self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
34
+ self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
35
+
36
+
37
+ @classmethod
38
+ def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
39
+ interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
40
+ if resize_mode == 'longest':
41
+ transforms = [
42
+ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
43
+ CenterCropOrPad(image_size, fill=fill_color)
44
+ ]
45
+ elif resize_mode == 'squash':
46
+ if isinstance(image_size, int):
47
+ image_size = (image_size, image_size)
48
+ transforms = [
49
+ Resize(image_size, interpolation=interpolation_mode),
50
+ ]
51
+ else:
52
+ assert resize_mode == 'shortest'
53
+ if not isinstance(image_size, (tuple, list)):
54
+ image_size = (image_size, image_size)
55
+ if image_size[0] == image_size[1]:
56
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
57
+ transforms = [
58
+ Resize(image_size[0], interpolation=interpolation_mode)
59
+ ]
60
+ else:
61
+ # resize shortest edge to matching target dim for non-square target
62
+ transforms = [ResizeKeepRatio(image_size)]
63
+ transforms += [CenterCrop(image_size)]
64
+ return transforms
65
+
66
+ @classmethod
67
+ def convert_rgb(cls, image):
68
+ return image.convert("RGB")
69
+
70
+
71
+ def _preprocess(self,
72
+ images: ImageInput
73
+ ) -> torch.Tensor:
74
+ transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
75
+ transforms.extend([
76
+ self.convert_rgb,
77
+ ToTensor(),
78
+ Normalize(mean=self.image_mean, std=self.image_std)
79
+ ])
80
+ composed_transforms = Compose(transforms)
81
+ images_tensor = composed_transforms(images)
82
+ return images_tensor
83
+
84
+ def preprocess(self,
85
+ images: ImageInput,
86
+ return_tensors: Optional[Union[str, TensorType]] = None,
87
+ **kwargs) -> BatchFeature:
88
+ if 'image_aspect_ratio' in kwargs:
89
+ image_aspect_ratio = kwargs['image_aspect_ratio']
90
+ else:
91
+ image_aspect_ratio = 'pad'
92
+ new_images = []
93
+ if image_aspect_ratio == 'pad':
94
+ for image in images:
95
+ image = self._preprocess(image)
96
+ new_images.append(image)
97
+ else:
98
+ if isinstance(self.size, (tuple, list)):
99
+ base_img_size = self.size[0]
100
+ else:
101
+ raise ValueError("size should be list or tuple")
102
+ for image in images:
103
+ image = process_anyres_image(image, self._preprocess, self.size,
104
+ [
105
+ [base_img_size,base_img_size*2],
106
+ [base_img_size*2,base_img_size],
107
+ [base_img_size*2,base_img_size*2],
108
+ [base_img_size*3,base_img_size],
109
+ [base_img_size,base_img_size*3]
110
+ ])
111
+ new_images.append(image)
112
+
113
+ if all(x.shape == new_images[0].shape for x in new_images):
114
+ new_images = torch.stack(new_images, dim=0)
115
+ if image_aspect_ratio == 'pad':
116
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
117
+ else:
118
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0)}, tensor_type=return_tensors)
119
+ return new_images
120
+ # def preprocess(self,
121
+ # images: ImageInput,
122
+ # return_tensors: Optional[Union[str, TensorType]] = None,
123
+ # **kwargs) -> BatchFeature:
124
+ # transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
125
+ # transforms.extend([
126
+ # self.convert_rgb,
127
+ # ToTensor(),
128
+ # Normalize(mean=self.image_mean, std=self.image_std)
129
+ # ])
130
+ # composed_transforms = Compose(transforms)
131
+ # images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
132
+ # encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
133
+ # return encoded_outputs
134
+
135
+
136
+ class ResizeKeepRatio:
137
+ """ Resize and Keep Ratio
138
+
139
+ Copy & paste from `timm`
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ size,
145
+ longest=0.,
146
+ interpolation=InterpolationMode.BICUBIC,
147
+ random_scale_prob=0.,
148
+ random_scale_range=(0.85, 1.05),
149
+ random_aspect_prob=0.,
150
+ random_aspect_range=(0.9, 1.11)
151
+ ):
152
+ if isinstance(size, (list, tuple)):
153
+ self.size = tuple(size)
154
+ else:
155
+ self.size = (size, size)
156
+ self.interpolation = interpolation
157
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
158
+ self.random_scale_prob = random_scale_prob
159
+ self.random_scale_range = random_scale_range
160
+ self.random_aspect_prob = random_aspect_prob
161
+ self.random_aspect_range = random_aspect_range
162
+
163
+ @staticmethod
164
+ def get_params(
165
+ img,
166
+ target_size,
167
+ longest,
168
+ random_scale_prob=0.,
169
+ random_scale_range=(0.85, 1.05),
170
+ random_aspect_prob=0.,
171
+ random_aspect_range=(0.9, 1.11)
172
+ ):
173
+ """Get parameters
174
+ """
175
+ source_size = img.size[::-1] # h, w
176
+ h, w = source_size
177
+ target_h, target_w = target_size
178
+ ratio_h = h / target_h
179
+ ratio_w = w / target_w
180
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
181
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
182
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
183
+ ratio_factor = (ratio_factor, ratio_factor)
184
+ else:
185
+ ratio_factor = (1., 1.)
186
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
187
+ aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
188
+ ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
189
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
190
+ return size
191
+
192
+ def __call__(self, img):
193
+ """
194
+ Args:
195
+ img (PIL Image): Image to be cropped and resized.
196
+
197
+ Returns:
198
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
199
+ """
200
+ size = self.get_params(
201
+ img, self.size, self.longest,
202
+ self.random_scale_prob, self.random_scale_range,
203
+ self.random_aspect_prob, self.random_aspect_range
204
+ )
205
+ img = F.resize(img, size, self.interpolation)
206
+ return img
207
+
208
+ def __repr__(self):
209
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
210
+ format_string += f', interpolation={self.interpolation})'
211
+ format_string += f', longest={self.longest:.3f})'
212
+ return format_string
213
+
214
+ def _setup_size(size, error_msg):
215
+ if isinstance(size, numbers.Number):
216
+ return int(size), int(size)
217
+
218
+ if isinstance(size, Sequence) and len(size) == 1:
219
+ return size[0], size[0]
220
+
221
+ if len(size) != 2:
222
+ raise ValueError(error_msg)
223
+
224
+ return size
225
+
226
+ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
227
+ """Center crops and/or pads the given image.
228
+ If the image is torch Tensor, it is expected
229
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
230
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
231
+
232
+ Args:
233
+ img (PIL Image or Tensor): Image to be cropped.
234
+ output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
235
+ it is used for both directions.
236
+ fill (int, Tuple[int]): Padding color
237
+
238
+ Returns:
239
+ PIL Image or Tensor: Cropped image.
240
+ """
241
+ if isinstance(output_size, numbers.Number):
242
+ output_size = (int(output_size), int(output_size))
243
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
244
+ output_size = (output_size[0], output_size[0])
245
+
246
+ _, image_height, image_width = F.get_dimensions(img)
247
+ crop_height, crop_width = output_size
248
+
249
+ if crop_width > image_width or crop_height > image_height:
250
+ padding_ltrb = [
251
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
252
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
253
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
254
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
255
+ ]
256
+ img = F.pad(img, padding_ltrb, fill=fill)
257
+ _, image_height, image_width = F.get_dimensions(img)
258
+ if crop_width == image_width and crop_height == image_height:
259
+ return img
260
+
261
+ crop_top = int(round((image_height - crop_height) / 2.0))
262
+ crop_left = int(round((image_width - crop_width) / 2.0))
263
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
264
+
265
+ class CenterCropOrPad(torch.nn.Module):
266
+ """Crops the given image at the center.
267
+ If the image is torch Tensor, it is expected
268
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
269
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
270
+
271
+ Args:
272
+ size (sequence or int): Desired output size of the crop. If size is an
273
+ int instead of sequence like (h, w), a square crop (size, size) is
274
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
275
+ """
276
+
277
+ def __init__(self, size, fill=0):
278
+ super().__init__()
279
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
280
+ self.fill = fill
281
+
282
+ def forward(self, img):
283
+ """
284
+ Args:
285
+ img (PIL Image or Tensor): Image to be cropped.
286
+
287
+ Returns:
288
+ PIL Image or Tensor: Cropped image.
289
+ """
290
+ return center_crop_or_pad(img, self.size, fill=self.fill)
291
+
292
+ def __repr__(self) -> str:
293
+ return f"{self.__class__.__name__}(size={self.size})"
294
+
295
+ def process_anyres_image(image, processor, processor_size, grid_pinpoints):
296
+ """
297
+ Process an image with variable resolutions.
298
+
299
+ Args:
300
+ image (PIL.Image.Image): The input image to be processed.
301
+ processor: The image processor object.
302
+ processor_size (tuple, list): The size of the image processor.
303
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
304
+
305
+ Returns:
306
+ torch.Tensor: A tensor containing the processed image patches.
307
+ """
308
+ # FIXME: determine grid_pinpoints from image sizes.
309
+ if type(grid_pinpoints) is list:
310
+ possible_resolutions = grid_pinpoints
311
+ else:
312
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
313
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
314
+ image_padded = resize_and_pad_image(image, best_resolution)
315
+
316
+ # processor_size = processor.transforms[0].size
317
+ patches = divide_to_patches(image_padded, processor_size[0])
318
+
319
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
320
+
321
+ image_patches = [image_original_resize] + patches
322
+ image_patches = [processor(image_patch)
323
+ for image_patch in image_patches]
324
+ return torch.stack(image_patches, dim=0)
325
+
326
+
327
+ def select_best_resolution(original_size, possible_resolutions):
328
+ """
329
+ Selects the best resolution from a list of possible resolutions based on the original size.
330
+
331
+ Args:
332
+ original_size (tuple): The original size of the image in the format (width, height).
333
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
334
+
335
+ Returns:
336
+ tuple: The best fit resolution in the format (width, height).
337
+ """
338
+ original_width, original_height = original_size
339
+ best_fit = None
340
+ max_effective_resolution = 0
341
+ min_wasted_resolution = float('inf')
342
+
343
+ for width, height in possible_resolutions:
344
+ scale = min(width / original_width, height / original_height)
345
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
346
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
347
+ wasted_resolution = (width * height) - effective_resolution
348
+
349
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
350
+ max_effective_resolution = effective_resolution
351
+ min_wasted_resolution = wasted_resolution
352
+ best_fit = (width, height)
353
+
354
+ return best_fit
355
+
356
+ def resize_and_pad_image(image, target_resolution):
357
+ """
358
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
359
+
360
+ Args:
361
+ image (PIL.Image.Image): The input image.
362
+ target_resolution (tuple): The target resolution (width, height) of the image.
363
+
364
+ Returns:
365
+ PIL.Image.Image: The resized and padded image.
366
+ """
367
+ original_width, original_height = image.size
368
+ target_width, target_height = target_resolution
369
+
370
+ scale_w = target_width / original_width
371
+ scale_h = target_height / original_height
372
+
373
+ if scale_w < scale_h:
374
+ new_width = target_width
375
+ new_height = min(math.ceil(original_height * scale_w), target_height)
376
+ else:
377
+ new_height = target_height
378
+ new_width = min(math.ceil(original_width * scale_h), target_width)
379
+
380
+ # Resize the image
381
+ resized_image = image.resize((new_width, new_height))
382
+
383
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
384
+ paste_x = (target_width - new_width) // 2
385
+ paste_y = (target_height - new_height) // 2
386
+ new_image.paste(resized_image, (paste_x, paste_y))
387
+
388
+ return new_image
389
+
390
+ def divide_to_patches(image, patch_size):
391
+ """
392
+ Divides an image into patches of a specified size.
393
+
394
+ Args:
395
+ image (PIL.Image.Image): The input image.
396
+ patch_size (int): The size of each patch.
397
+
398
+ Returns:
399
+ list: A list of PIL.Image.Image objects representing the patches.
400
+ """
401
+ patches = []
402
+ width, height = image.size
403
+ for i in range(0, height, patch_size):
404
+ for j in range(0, width, patch_size):
405
+ box = (j, i, j + patch_size, i + patch_size)
406
+ patch = image.crop(box)
407
+ patches.append(patch)
408
+
409
+ return patches
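A minimal usage sketch for the processor defined above (assuming `image_processing_xgenmm.py` is importable; the image path is hypothetical):

```python
from PIL import Image
from image_processing_xgenmm import XGenMMImageProcessor

processor = XGenMMImageProcessor()            # defaults: 378x378 'squash' resize, CLIP mean/std
img = Image.open("test_samples/example.jpg")  # hypothetical local image
batch = processor.preprocess([img], return_tensors="pt")
# On the default 'pad' path this should yield a 6-D tensor, e.g. (1, 1, 1, 3, 378, 378).
print(batch["pixel_values"].shape)
```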
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca8bbede81e4a6265ab0d3fb2af2ed7b3bf64af352d2af5a81c8117723103569
3
+ size 4954761920
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:531b227276744050439141b2c6a6429c2b2d72b83717472dfce35194d63667d5
3
+ size 4983112128
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0067a3e761a18e08eab4a48142845652329a76c3b2a0fa6780cc10aabe1b89ec
3
+ size 4983112168
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d34abbbba4bfd0fca5c99b17d7ff5d26d3e1a9d8fbb3cce8127e9e521a10dd4
3
+ size 3414256548
model.safetensors.index.json ADDED
@@ -0,0 +1,669 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18335156492
4
+ },
5
+ "weight_map": {
6
+ "vlm.lang_model.lm_head.additional_fc.bias": "model-00004-of-00004.safetensors",
7
+ "vlm.lang_model.lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
8
+ "vlm.lang_model.lm_head.bias": "model-00004-of-00004.safetensors",
9
+ "vlm.lang_model.lm_head.weight": "model-00004-of-00004.safetensors",
10
+ "vlm.lang_model.model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
11
+ "vlm.lang_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
12
+ "vlm.lang_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "vlm.lang_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
14
+ "vlm.lang_model.model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
15
+ "vlm.lang_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
16
+ "vlm.lang_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "vlm.lang_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
18
+ "vlm.lang_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "vlm.lang_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "vlm.lang_model.model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
21
+ "vlm.lang_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
22
+ "vlm.lang_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
23
+ "vlm.lang_model.model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
24
+ "vlm.lang_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
25
+ "vlm.lang_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
26
+ "vlm.lang_model.model.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
27
+ "vlm.lang_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "vlm.lang_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
29
+ "vlm.lang_model.model.layers.10.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
30
+ "vlm.lang_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
31
+ "vlm.lang_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
32
+ "vlm.lang_model.model.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
33
+ "vlm.lang_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
34
+ "vlm.lang_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
35
+ "vlm.lang_model.model.layers.11.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
36
+ "vlm.lang_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "vlm.lang_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "vlm.lang_model.model.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
39
+ "vlm.lang_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
40
+ "vlm.lang_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
41
+ "vlm.lang_model.model.layers.12.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
42
+ "vlm.lang_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
43
+ "vlm.lang_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
44
+ "vlm.lang_model.model.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
45
+ "vlm.lang_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "vlm.lang_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
47
+ "vlm.lang_model.model.layers.13.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
48
+ "vlm.lang_model.model.layers.14.input_layernorm.weight": "model-00003-of-00004.safetensors",
49
+ "vlm.lang_model.model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
50
+ "vlm.lang_model.model.layers.14.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
51
+ "vlm.lang_model.model.layers.14.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
52
+ "vlm.lang_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "vlm.lang_model.model.layers.14.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
54
+ "vlm.lang_model.model.layers.15.input_layernorm.weight": "model-00003-of-00004.safetensors",
55
+ "vlm.lang_model.model.layers.15.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
56
+ "vlm.lang_model.model.layers.15.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
57
+ "vlm.lang_model.model.layers.15.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
58
+ "vlm.lang_model.model.layers.15.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
59
+ "vlm.lang_model.model.layers.15.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
60
+ "vlm.lang_model.model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
61
+ "vlm.lang_model.model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
62
+ "vlm.lang_model.model.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
63
+ "vlm.lang_model.model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
64
+ "vlm.lang_model.model.layers.16.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
65
+ "vlm.lang_model.model.layers.16.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
66
+ "vlm.lang_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
67
+ "vlm.lang_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
68
+ "vlm.lang_model.model.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
69
+ "vlm.lang_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
70
+ "vlm.lang_model.model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
71
+ "vlm.lang_model.model.layers.17.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
72
+ "vlm.lang_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
73
+ "vlm.lang_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
74
+ "vlm.lang_model.model.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
75
+ "vlm.lang_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
76
+ "vlm.lang_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
77
+ "vlm.lang_model.model.layers.18.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
78
+ "vlm.lang_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
79
+ "vlm.lang_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
80
+ "vlm.lang_model.model.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
81
+ "vlm.lang_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
82
+ "vlm.lang_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
83
+ "vlm.lang_model.model.layers.19.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
84
+ "vlm.lang_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
85
+ "vlm.lang_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
86
+ "vlm.lang_model.model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
87
+ "vlm.lang_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
88
+ "vlm.lang_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
89
+ "vlm.lang_model.model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
90
+ "vlm.lang_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
91
+ "vlm.lang_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
92
+ "vlm.lang_model.model.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
93
+ "vlm.lang_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
94
+ "vlm.lang_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
95
+ "vlm.lang_model.model.layers.20.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
96
+ "vlm.lang_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
97
+ "vlm.lang_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
98
+ "vlm.lang_model.model.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
99
+ "vlm.lang_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
100
+ "vlm.lang_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
101
+ "vlm.lang_model.model.layers.21.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
102
+ "vlm.lang_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
103
+ "vlm.lang_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
104
+ "vlm.lang_model.model.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
105
+ "vlm.lang_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
106
+ "vlm.lang_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
107
+ "vlm.lang_model.model.layers.22.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
108
+ "vlm.lang_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
109
+ "vlm.lang_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
110
+ "vlm.lang_model.model.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
111
+ "vlm.lang_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
112
+ "vlm.lang_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
113
+ "vlm.lang_model.model.layers.23.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
114
+ "vlm.lang_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
115
+ "vlm.lang_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
116
+ "vlm.lang_model.model.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
117
+ "vlm.lang_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
118
+ "vlm.lang_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
119
+ "vlm.lang_model.model.layers.24.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
120
+ "vlm.lang_model.model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
121
+ "vlm.lang_model.model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
122
+ "vlm.lang_model.model.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
123
+ "vlm.lang_model.model.layers.25.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
124
+ "vlm.lang_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
125
+ "vlm.lang_model.model.layers.25.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
126
+ "vlm.lang_model.model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
127
+ "vlm.lang_model.model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
128
+ "vlm.lang_model.model.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
129
+ "vlm.lang_model.model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
130
+ "vlm.lang_model.model.layers.26.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
131
+ "vlm.lang_model.model.layers.26.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
132
+ "vlm.lang_model.model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
133
+ "vlm.lang_model.model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
134
+ "vlm.lang_model.model.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
135
+ "vlm.lang_model.model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
136
+ "vlm.lang_model.model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
137
+ "vlm.lang_model.model.layers.27.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
138
+ "vlm.lang_model.model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
139
+ "vlm.lang_model.model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
140
+ "vlm.lang_model.model.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
141
+ "vlm.lang_model.model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
142
+ "vlm.lang_model.model.layers.28.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
143
+ "vlm.lang_model.model.layers.28.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
144
+ "vlm.lang_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
145
+ "vlm.lang_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
146
+ "vlm.lang_model.model.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
147
+ "vlm.lang_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
148
+ "vlm.lang_model.model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
149
+ "vlm.lang_model.model.layers.29.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
150
+ "vlm.lang_model.model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
151
+ "vlm.lang_model.model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
152
+ "vlm.lang_model.model.layers.3.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
153
+ "vlm.lang_model.model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
154
+ "vlm.lang_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
155
+ "vlm.lang_model.model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
156
+ "vlm.lang_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
157
+ "vlm.lang_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
158
+ "vlm.lang_model.model.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
159
+ "vlm.lang_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
160
+ "vlm.lang_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
161
+ "vlm.lang_model.model.layers.30.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
162
+ "vlm.lang_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
163
+ "vlm.lang_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
164
+ "vlm.lang_model.model.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
165
+ "vlm.lang_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
166
+ "vlm.lang_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
167
+ "vlm.lang_model.model.layers.31.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
168
+ "vlm.lang_model.model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
169
+ "vlm.lang_model.model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
170
+ "vlm.lang_model.model.layers.4.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
171
+ "vlm.lang_model.model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
172
+ "vlm.lang_model.model.layers.4.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
173
+ "vlm.lang_model.model.layers.4.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
174
+ "vlm.lang_model.model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
175
+ "vlm.lang_model.model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
176
+ "vlm.lang_model.model.layers.5.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
177
+ "vlm.lang_model.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
178
+ "vlm.lang_model.model.layers.5.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
179
+ "vlm.lang_model.model.layers.5.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
180
+ "vlm.lang_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
181
+ "vlm.lang_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
182
+ "vlm.lang_model.model.layers.6.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
183
+ "vlm.lang_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
184
+ "vlm.lang_model.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
185
+ "vlm.lang_model.model.layers.6.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
186
+ "vlm.lang_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
187
+ "vlm.lang_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
188
+ "vlm.lang_model.model.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
189
+ "vlm.lang_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
190
+ "vlm.lang_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
191
+ "vlm.lang_model.model.layers.7.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
192
+ "vlm.lang_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
193
+ "vlm.lang_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
194
+ "vlm.lang_model.model.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
195
+ "vlm.lang_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
196
+ "vlm.lang_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
197
+ "vlm.lang_model.model.layers.8.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
198
+ "vlm.lang_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
199
+ "vlm.lang_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
200
+ "vlm.lang_model.model.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
201
+ "vlm.lang_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
202
+ "vlm.lang_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
203
+ "vlm.lang_model.model.layers.9.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
204
+ "vlm.lang_model.model.norm.weight": "model-00004-of-00004.safetensors",
205
+ "vlm.vision_encoder.class_embedding": "model-00001-of-00004.safetensors",
206
+ "vlm.vision_encoder.conv1.weight": "model-00001-of-00004.safetensors",
207
+ "vlm.vision_encoder.ln_post.bias": "model-00001-of-00004.safetensors",
208
+ "vlm.vision_encoder.ln_post.weight": "model-00001-of-00004.safetensors",
209
+ "vlm.vision_encoder.ln_pre.bias": "model-00001-of-00004.safetensors",
210
+ "vlm.vision_encoder.ln_pre.weight": "model-00001-of-00004.safetensors",
211
+ "vlm.vision_encoder.positional_embedding": "model-00001-of-00004.safetensors",
212
+ "vlm.vision_encoder.proj": "model-00001-of-00004.safetensors",
213
+ "vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_bias": "model-00001-of-00004.safetensors",
214
+ "vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_weight": "model-00001-of-00004.safetensors",
215
+ "vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.bias": "model-00001-of-00004.safetensors",
216
+ "vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.weight": "model-00001-of-00004.safetensors",
217
+ "vlm.vision_encoder.transformer.resblocks.0.ln_1.bias": "model-00001-of-00004.safetensors",
218
+ "vlm.vision_encoder.transformer.resblocks.0.ln_1.weight": "model-00001-of-00004.safetensors",
219
+ "vlm.vision_encoder.transformer.resblocks.0.ln_2.bias": "model-00001-of-00004.safetensors",
220
+ "vlm.vision_encoder.transformer.resblocks.0.ln_2.weight": "model-00001-of-00004.safetensors",
221
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
222
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
223
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
224
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
225
+ "vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_bias": "model-00001-of-00004.safetensors",
226
+ "vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_weight": "model-00001-of-00004.safetensors",
227
+ "vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.bias": "model-00001-of-00004.safetensors",
228
+ "vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.weight": "model-00001-of-00004.safetensors",
229
+ "vlm.vision_encoder.transformer.resblocks.1.ln_1.bias": "model-00001-of-00004.safetensors",
230
+ "vlm.vision_encoder.transformer.resblocks.1.ln_1.weight": "model-00001-of-00004.safetensors",
231
+ "vlm.vision_encoder.transformer.resblocks.1.ln_2.bias": "model-00001-of-00004.safetensors",
232
+ "vlm.vision_encoder.transformer.resblocks.1.ln_2.weight": "model-00001-of-00004.safetensors",
233
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
234
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
235
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
236
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
237
+ "vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_bias": "model-00001-of-00004.safetensors",
238
+ "vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_weight": "model-00001-of-00004.safetensors",
239
+ "vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.bias": "model-00001-of-00004.safetensors",
240
+ "vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.weight": "model-00001-of-00004.safetensors",
241
+ "vlm.vision_encoder.transformer.resblocks.10.ln_1.bias": "model-00001-of-00004.safetensors",
242
+ "vlm.vision_encoder.transformer.resblocks.10.ln_1.weight": "model-00001-of-00004.safetensors",
243
+ "vlm.vision_encoder.transformer.resblocks.10.ln_2.bias": "model-00001-of-00004.safetensors",
244
+ "vlm.vision_encoder.transformer.resblocks.10.ln_2.weight": "model-00001-of-00004.safetensors",
245
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
246
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
247
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
248
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
249
+ "vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_bias": "model-00001-of-00004.safetensors",
250
+ "vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_weight": "model-00001-of-00004.safetensors",
251
+ "vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.bias": "model-00001-of-00004.safetensors",
252
+ "vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.weight": "model-00001-of-00004.safetensors",
253
+ "vlm.vision_encoder.transformer.resblocks.11.ln_1.bias": "model-00001-of-00004.safetensors",
254
+ "vlm.vision_encoder.transformer.resblocks.11.ln_1.weight": "model-00001-of-00004.safetensors",
255
+ "vlm.vision_encoder.transformer.resblocks.11.ln_2.bias": "model-00001-of-00004.safetensors",
256
+ "vlm.vision_encoder.transformer.resblocks.11.ln_2.weight": "model-00001-of-00004.safetensors",
257
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
258
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
259
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
260
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
261
+ "vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_bias": "model-00001-of-00004.safetensors",
262
+ "vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_weight": "model-00001-of-00004.safetensors",
263
+ "vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.bias": "model-00001-of-00004.safetensors",
264
+ "vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.weight": "model-00001-of-00004.safetensors",
265
+ "vlm.vision_encoder.transformer.resblocks.12.ln_1.bias": "model-00001-of-00004.safetensors",
266
+ "vlm.vision_encoder.transformer.resblocks.12.ln_1.weight": "model-00001-of-00004.safetensors",
267
+ "vlm.vision_encoder.transformer.resblocks.12.ln_2.bias": "model-00001-of-00004.safetensors",
268
+ "vlm.vision_encoder.transformer.resblocks.12.ln_2.weight": "model-00001-of-00004.safetensors",
269
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
270
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
271
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
272
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
273
+ "vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_bias": "model-00001-of-00004.safetensors",
274
+ "vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_weight": "model-00001-of-00004.safetensors",
275
+ "vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.bias": "model-00001-of-00004.safetensors",
276
+ "vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.weight": "model-00001-of-00004.safetensors",
277
+ "vlm.vision_encoder.transformer.resblocks.13.ln_1.bias": "model-00001-of-00004.safetensors",
278
+ "vlm.vision_encoder.transformer.resblocks.13.ln_1.weight": "model-00001-of-00004.safetensors",
279
+ "vlm.vision_encoder.transformer.resblocks.13.ln_2.bias": "model-00001-of-00004.safetensors",
280
+ "vlm.vision_encoder.transformer.resblocks.13.ln_2.weight": "model-00001-of-00004.safetensors",
281
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
282
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
283
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
284
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
285
+ "vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_bias": "model-00001-of-00004.safetensors",
286
+ "vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_weight": "model-00001-of-00004.safetensors",
287
+ "vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.bias": "model-00001-of-00004.safetensors",
288
+ "vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.weight": "model-00001-of-00004.safetensors",
289
+ "vlm.vision_encoder.transformer.resblocks.14.ln_1.bias": "model-00001-of-00004.safetensors",
290
+ "vlm.vision_encoder.transformer.resblocks.14.ln_1.weight": "model-00001-of-00004.safetensors",
291
+ "vlm.vision_encoder.transformer.resblocks.14.ln_2.bias": "model-00001-of-00004.safetensors",
292
+ "vlm.vision_encoder.transformer.resblocks.14.ln_2.weight": "model-00001-of-00004.safetensors",
293
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
294
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
295
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
296
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
297
+ "vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_bias": "model-00001-of-00004.safetensors",
298
+ "vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_weight": "model-00001-of-00004.safetensors",
299
+ "vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.bias": "model-00001-of-00004.safetensors",
300
+ "vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.weight": "model-00001-of-00004.safetensors",
301
+ "vlm.vision_encoder.transformer.resblocks.15.ln_1.bias": "model-00001-of-00004.safetensors",
302
+ "vlm.vision_encoder.transformer.resblocks.15.ln_1.weight": "model-00001-of-00004.safetensors",
303
+ "vlm.vision_encoder.transformer.resblocks.15.ln_2.bias": "model-00001-of-00004.safetensors",
304
+ "vlm.vision_encoder.transformer.resblocks.15.ln_2.weight": "model-00001-of-00004.safetensors",
305
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
306
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
307
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
308
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
309
+ "vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_bias": "model-00001-of-00004.safetensors",
310
+ "vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_weight": "model-00001-of-00004.safetensors",
311
+ "vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.bias": "model-00001-of-00004.safetensors",
312
+ "vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.weight": "model-00001-of-00004.safetensors",
313
+ "vlm.vision_encoder.transformer.resblocks.16.ln_1.bias": "model-00001-of-00004.safetensors",
314
+ "vlm.vision_encoder.transformer.resblocks.16.ln_1.weight": "model-00001-of-00004.safetensors",
315
+ "vlm.vision_encoder.transformer.resblocks.16.ln_2.bias": "model-00001-of-00004.safetensors",
316
+ "vlm.vision_encoder.transformer.resblocks.16.ln_2.weight": "model-00001-of-00004.safetensors",
317
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
318
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
319
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
320
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
321
+ "vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_bias": "model-00001-of-00004.safetensors",
322
+ "vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_weight": "model-00001-of-00004.safetensors",
323
+ "vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.bias": "model-00001-of-00004.safetensors",
324
+ "vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.weight": "model-00001-of-00004.safetensors",
325
+ "vlm.vision_encoder.transformer.resblocks.17.ln_1.bias": "model-00001-of-00004.safetensors",
326
+ "vlm.vision_encoder.transformer.resblocks.17.ln_1.weight": "model-00001-of-00004.safetensors",
327
+ "vlm.vision_encoder.transformer.resblocks.17.ln_2.bias": "model-00001-of-00004.safetensors",
328
+ "vlm.vision_encoder.transformer.resblocks.17.ln_2.weight": "model-00001-of-00004.safetensors",
329
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
330
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
331
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
332
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
333
+ "vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_bias": "model-00001-of-00004.safetensors",
334
+ "vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_weight": "model-00001-of-00004.safetensors",
335
+ "vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.bias": "model-00001-of-00004.safetensors",
336
+ "vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.weight": "model-00001-of-00004.safetensors",
337
+ "vlm.vision_encoder.transformer.resblocks.18.ln_1.bias": "model-00001-of-00004.safetensors",
338
+ "vlm.vision_encoder.transformer.resblocks.18.ln_1.weight": "model-00001-of-00004.safetensors",
339
+ "vlm.vision_encoder.transformer.resblocks.18.ln_2.bias": "model-00001-of-00004.safetensors",
340
+ "vlm.vision_encoder.transformer.resblocks.18.ln_2.weight": "model-00001-of-00004.safetensors",
341
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
342
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
343
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
344
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
345
+ "vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_bias": "model-00001-of-00004.safetensors",
346
+ "vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_weight": "model-00001-of-00004.safetensors",
347
+ "vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.bias": "model-00001-of-00004.safetensors",
348
+ "vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.weight": "model-00001-of-00004.safetensors",
349
+ "vlm.vision_encoder.transformer.resblocks.19.ln_1.bias": "model-00001-of-00004.safetensors",
350
+ "vlm.vision_encoder.transformer.resblocks.19.ln_1.weight": "model-00001-of-00004.safetensors",
351
+ "vlm.vision_encoder.transformer.resblocks.19.ln_2.bias": "model-00001-of-00004.safetensors",
352
+ "vlm.vision_encoder.transformer.resblocks.19.ln_2.weight": "model-00001-of-00004.safetensors",
353
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
354
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
355
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
356
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
357
+ "vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_bias": "model-00001-of-00004.safetensors",
358
+ "vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_weight": "model-00001-of-00004.safetensors",
359
+ "vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.bias": "model-00001-of-00004.safetensors",
360
+ "vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.weight": "model-00001-of-00004.safetensors",
361
+ "vlm.vision_encoder.transformer.resblocks.2.ln_1.bias": "model-00001-of-00004.safetensors",
362
+ "vlm.vision_encoder.transformer.resblocks.2.ln_1.weight": "model-00001-of-00004.safetensors",
363
+ "vlm.vision_encoder.transformer.resblocks.2.ln_2.bias": "model-00001-of-00004.safetensors",
364
+ "vlm.vision_encoder.transformer.resblocks.2.ln_2.weight": "model-00001-of-00004.safetensors",
365
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
366
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
367
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
368
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
369
+ "vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_bias": "model-00001-of-00004.safetensors",
370
+ "vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_weight": "model-00001-of-00004.safetensors",
371
+ "vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.bias": "model-00001-of-00004.safetensors",
372
+ "vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.weight": "model-00001-of-00004.safetensors",
373
+ "vlm.vision_encoder.transformer.resblocks.20.ln_1.bias": "model-00001-of-00004.safetensors",
374
+ "vlm.vision_encoder.transformer.resblocks.20.ln_1.weight": "model-00001-of-00004.safetensors",
375
+ "vlm.vision_encoder.transformer.resblocks.20.ln_2.bias": "model-00001-of-00004.safetensors",
376
+ "vlm.vision_encoder.transformer.resblocks.20.ln_2.weight": "model-00001-of-00004.safetensors",
377
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
378
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
379
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
380
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
381
+ "vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_bias": "model-00001-of-00004.safetensors",
382
+ "vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_weight": "model-00001-of-00004.safetensors",
383
+ "vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.bias": "model-00001-of-00004.safetensors",
384
+ "vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.weight": "model-00001-of-00004.safetensors",
385
+ "vlm.vision_encoder.transformer.resblocks.21.ln_1.bias": "model-00001-of-00004.safetensors",
386
+ "vlm.vision_encoder.transformer.resblocks.21.ln_1.weight": "model-00001-of-00004.safetensors",
387
+ "vlm.vision_encoder.transformer.resblocks.21.ln_2.bias": "model-00001-of-00004.safetensors",
388
+ "vlm.vision_encoder.transformer.resblocks.21.ln_2.weight": "model-00001-of-00004.safetensors",
389
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
390
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
391
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
392
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
393
+ "vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_bias": "model-00001-of-00004.safetensors",
394
+ "vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_weight": "model-00001-of-00004.safetensors",
395
+ "vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.bias": "model-00001-of-00004.safetensors",
396
+ "vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.weight": "model-00001-of-00004.safetensors",
397
+ "vlm.vision_encoder.transformer.resblocks.22.ln_1.bias": "model-00001-of-00004.safetensors",
398
+ "vlm.vision_encoder.transformer.resblocks.22.ln_1.weight": "model-00001-of-00004.safetensors",
399
+ "vlm.vision_encoder.transformer.resblocks.22.ln_2.bias": "model-00001-of-00004.safetensors",
400
+ "vlm.vision_encoder.transformer.resblocks.22.ln_2.weight": "model-00001-of-00004.safetensors",
401
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
402
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
403
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
404
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
405
+ "vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_bias": "model-00001-of-00004.safetensors",
406
+ "vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_weight": "model-00001-of-00004.safetensors",
407
+ "vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.bias": "model-00001-of-00004.safetensors",
408
+ "vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.weight": "model-00001-of-00004.safetensors",
409
+ "vlm.vision_encoder.transformer.resblocks.23.ln_1.bias": "model-00001-of-00004.safetensors",
410
+ "vlm.vision_encoder.transformer.resblocks.23.ln_1.weight": "model-00001-of-00004.safetensors",
411
+ "vlm.vision_encoder.transformer.resblocks.23.ln_2.bias": "model-00001-of-00004.safetensors",
412
+ "vlm.vision_encoder.transformer.resblocks.23.ln_2.weight": "model-00001-of-00004.safetensors",
413
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
414
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
415
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
416
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
417
+ "vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_bias": "model-00001-of-00004.safetensors",
418
+ "vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_weight": "model-00001-of-00004.safetensors",
419
+ "vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.bias": "model-00001-of-00004.safetensors",
420
+ "vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.weight": "model-00001-of-00004.safetensors",
421
+ "vlm.vision_encoder.transformer.resblocks.24.ln_1.bias": "model-00001-of-00004.safetensors",
422
+ "vlm.vision_encoder.transformer.resblocks.24.ln_1.weight": "model-00001-of-00004.safetensors",
423
+ "vlm.vision_encoder.transformer.resblocks.24.ln_2.bias": "model-00001-of-00004.safetensors",
424
+ "vlm.vision_encoder.transformer.resblocks.24.ln_2.weight": "model-00001-of-00004.safetensors",
425
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
426
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
427
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
428
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
429
+ "vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_bias": "model-00001-of-00004.safetensors",
430
+ "vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_weight": "model-00001-of-00004.safetensors",
431
+ "vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.bias": "model-00001-of-00004.safetensors",
432
+ "vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.weight": "model-00001-of-00004.safetensors",
433
+ "vlm.vision_encoder.transformer.resblocks.25.ln_1.bias": "model-00001-of-00004.safetensors",
434
+ "vlm.vision_encoder.transformer.resblocks.25.ln_1.weight": "model-00001-of-00004.safetensors",
435
+ "vlm.vision_encoder.transformer.resblocks.25.ln_2.bias": "model-00001-of-00004.safetensors",
436
+ "vlm.vision_encoder.transformer.resblocks.25.ln_2.weight": "model-00001-of-00004.safetensors",
437
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
438
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
439
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
440
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
441
+ "vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_bias": "model-00001-of-00004.safetensors",
442
+ "vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_weight": "model-00001-of-00004.safetensors",
443
+ "vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.bias": "model-00001-of-00004.safetensors",
444
+ "vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.weight": "model-00001-of-00004.safetensors",
445
+ "vlm.vision_encoder.transformer.resblocks.26.ln_1.bias": "model-00001-of-00004.safetensors",
446
+ "vlm.vision_encoder.transformer.resblocks.26.ln_1.weight": "model-00001-of-00004.safetensors",
447
+ "vlm.vision_encoder.transformer.resblocks.26.ln_2.bias": "model-00001-of-00004.safetensors",
448
+ "vlm.vision_encoder.transformer.resblocks.26.ln_2.weight": "model-00001-of-00004.safetensors",
449
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
450
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
451
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
452
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
453
+ "vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_bias": "model-00001-of-00004.safetensors",
454
+ "vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_weight": "model-00001-of-00004.safetensors",
455
+ "vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.bias": "model-00001-of-00004.safetensors",
456
+ "vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.weight": "model-00001-of-00004.safetensors",
457
+ "vlm.vision_encoder.transformer.resblocks.27.ln_1.bias": "model-00001-of-00004.safetensors",
458
+ "vlm.vision_encoder.transformer.resblocks.27.ln_1.weight": "model-00001-of-00004.safetensors",
459
+ "vlm.vision_encoder.transformer.resblocks.27.ln_2.bias": "model-00001-of-00004.safetensors",
460
+ "vlm.vision_encoder.transformer.resblocks.27.ln_2.weight": "model-00001-of-00004.safetensors",
461
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
462
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
463
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
464
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
465
+ "vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_bias": "model-00001-of-00004.safetensors",
466
+ "vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_weight": "model-00001-of-00004.safetensors",
467
+ "vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.bias": "model-00001-of-00004.safetensors",
468
+ "vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.weight": "model-00001-of-00004.safetensors",
469
+ "vlm.vision_encoder.transformer.resblocks.28.ln_1.bias": "model-00001-of-00004.safetensors",
470
+ "vlm.vision_encoder.transformer.resblocks.28.ln_1.weight": "model-00001-of-00004.safetensors",
471
+ "vlm.vision_encoder.transformer.resblocks.28.ln_2.bias": "model-00001-of-00004.safetensors",
472
+ "vlm.vision_encoder.transformer.resblocks.28.ln_2.weight": "model-00001-of-00004.safetensors",
473
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
474
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
475
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
476
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
477
+ "vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_bias": "model-00001-of-00004.safetensors",
478
+ "vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_weight": "model-00001-of-00004.safetensors",
479
+ "vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.bias": "model-00001-of-00004.safetensors",
480
+ "vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.weight": "model-00001-of-00004.safetensors",
481
+ "vlm.vision_encoder.transformer.resblocks.29.ln_1.bias": "model-00001-of-00004.safetensors",
482
+ "vlm.vision_encoder.transformer.resblocks.29.ln_1.weight": "model-00001-of-00004.safetensors",
483
+ "vlm.vision_encoder.transformer.resblocks.29.ln_2.bias": "model-00001-of-00004.safetensors",
484
+ "vlm.vision_encoder.transformer.resblocks.29.ln_2.weight": "model-00001-of-00004.safetensors",
485
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
486
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
487
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
488
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
489
+ "vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_bias": "model-00001-of-00004.safetensors",
490
+ "vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_weight": "model-00001-of-00004.safetensors",
491
+ "vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.bias": "model-00001-of-00004.safetensors",
492
+ "vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.weight": "model-00001-of-00004.safetensors",
493
+ "vlm.vision_encoder.transformer.resblocks.3.ln_1.bias": "model-00001-of-00004.safetensors",
494
+ "vlm.vision_encoder.transformer.resblocks.3.ln_1.weight": "model-00001-of-00004.safetensors",
495
+ "vlm.vision_encoder.transformer.resblocks.3.ln_2.bias": "model-00001-of-00004.safetensors",
496
+ "vlm.vision_encoder.transformer.resblocks.3.ln_2.weight": "model-00001-of-00004.safetensors",
497
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
498
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
499
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
500
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
501
+ "vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_bias": "model-00001-of-00004.safetensors",
502
+ "vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_weight": "model-00001-of-00004.safetensors",
503
+ "vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.bias": "model-00001-of-00004.safetensors",
504
+ "vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.weight": "model-00001-of-00004.safetensors",
505
+ "vlm.vision_encoder.transformer.resblocks.30.ln_1.bias": "model-00001-of-00004.safetensors",
506
+ "vlm.vision_encoder.transformer.resblocks.30.ln_1.weight": "model-00001-of-00004.safetensors",
507
+ "vlm.vision_encoder.transformer.resblocks.30.ln_2.bias": "model-00001-of-00004.safetensors",
508
+ "vlm.vision_encoder.transformer.resblocks.30.ln_2.weight": "model-00001-of-00004.safetensors",
509
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
510
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
511
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
512
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
513
+ "vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_bias": "model-00001-of-00004.safetensors",
514
+ "vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_weight": "model-00001-of-00004.safetensors",
515
+ "vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.bias": "model-00001-of-00004.safetensors",
516
+ "vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.weight": "model-00001-of-00004.safetensors",
517
+ "vlm.vision_encoder.transformer.resblocks.31.ln_1.bias": "model-00001-of-00004.safetensors",
518
+ "vlm.vision_encoder.transformer.resblocks.31.ln_1.weight": "model-00001-of-00004.safetensors",
519
+ "vlm.vision_encoder.transformer.resblocks.31.ln_2.bias": "model-00001-of-00004.safetensors",
520
+ "vlm.vision_encoder.transformer.resblocks.31.ln_2.weight": "model-00001-of-00004.safetensors",
521
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
522
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
523
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
524
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
525
+ "vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_bias": "model-00001-of-00004.safetensors",
526
+ "vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_weight": "model-00001-of-00004.safetensors",
527
+ "vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.bias": "model-00001-of-00004.safetensors",
528
+ "vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.weight": "model-00001-of-00004.safetensors",
529
+ "vlm.vision_encoder.transformer.resblocks.4.ln_1.bias": "model-00001-of-00004.safetensors",
530
+ "vlm.vision_encoder.transformer.resblocks.4.ln_1.weight": "model-00001-of-00004.safetensors",
531
+ "vlm.vision_encoder.transformer.resblocks.4.ln_2.bias": "model-00001-of-00004.safetensors",
532
+ "vlm.vision_encoder.transformer.resblocks.4.ln_2.weight": "model-00001-of-00004.safetensors",
533
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
534
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
535
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
536
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
537
+ "vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_bias": "model-00001-of-00004.safetensors",
538
+ "vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_weight": "model-00001-of-00004.safetensors",
539
+ "vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.bias": "model-00001-of-00004.safetensors",
540
+ "vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.weight": "model-00001-of-00004.safetensors",
541
+ "vlm.vision_encoder.transformer.resblocks.5.ln_1.bias": "model-00001-of-00004.safetensors",
542
+ "vlm.vision_encoder.transformer.resblocks.5.ln_1.weight": "model-00001-of-00004.safetensors",
543
+ "vlm.vision_encoder.transformer.resblocks.5.ln_2.bias": "model-00001-of-00004.safetensors",
544
+ "vlm.vision_encoder.transformer.resblocks.5.ln_2.weight": "model-00001-of-00004.safetensors",
545
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
546
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
547
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
548
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
549
+ "vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_bias": "model-00001-of-00004.safetensors",
550
+ "vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_weight": "model-00001-of-00004.safetensors",
551
+ "vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.bias": "model-00001-of-00004.safetensors",
552
+ "vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.weight": "model-00001-of-00004.safetensors",
553
+ "vlm.vision_encoder.transformer.resblocks.6.ln_1.bias": "model-00001-of-00004.safetensors",
554
+ "vlm.vision_encoder.transformer.resblocks.6.ln_1.weight": "model-00001-of-00004.safetensors",
555
+ "vlm.vision_encoder.transformer.resblocks.6.ln_2.bias": "model-00001-of-00004.safetensors",
556
+ "vlm.vision_encoder.transformer.resblocks.6.ln_2.weight": "model-00001-of-00004.safetensors",
557
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
558
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
559
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
560
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
561
+ "vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_bias": "model-00001-of-00004.safetensors",
562
+ "vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_weight": "model-00001-of-00004.safetensors",
563
+ "vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.bias": "model-00001-of-00004.safetensors",
564
+ "vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.weight": "model-00001-of-00004.safetensors",
565
+ "vlm.vision_encoder.transformer.resblocks.7.ln_1.bias": "model-00001-of-00004.safetensors",
566
+ "vlm.vision_encoder.transformer.resblocks.7.ln_1.weight": "model-00001-of-00004.safetensors",
567
+ "vlm.vision_encoder.transformer.resblocks.7.ln_2.bias": "model-00001-of-00004.safetensors",
568
+ "vlm.vision_encoder.transformer.resblocks.7.ln_2.weight": "model-00001-of-00004.safetensors",
569
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
570
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
571
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
572
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
573
+ "vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_bias": "model-00001-of-00004.safetensors",
574
+ "vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_weight": "model-00001-of-00004.safetensors",
575
+ "vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.bias": "model-00001-of-00004.safetensors",
576
+ "vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.weight": "model-00001-of-00004.safetensors",
577
+ "vlm.vision_encoder.transformer.resblocks.8.ln_1.bias": "model-00001-of-00004.safetensors",
578
+ "vlm.vision_encoder.transformer.resblocks.8.ln_1.weight": "model-00001-of-00004.safetensors",
579
+ "vlm.vision_encoder.transformer.resblocks.8.ln_2.bias": "model-00001-of-00004.safetensors",
580
+ "vlm.vision_encoder.transformer.resblocks.8.ln_2.weight": "model-00001-of-00004.safetensors",
581
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
582
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
583
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
584
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
585
+ "vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_bias": "model-00001-of-00004.safetensors",
586
+ "vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_weight": "model-00001-of-00004.safetensors",
587
+ "vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.bias": "model-00001-of-00004.safetensors",
588
+ "vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.weight": "model-00001-of-00004.safetensors",
589
+ "vlm.vision_encoder.transformer.resblocks.9.ln_1.bias": "model-00001-of-00004.safetensors",
590
+ "vlm.vision_encoder.transformer.resblocks.9.ln_1.weight": "model-00001-of-00004.safetensors",
591
+ "vlm.vision_encoder.transformer.resblocks.9.ln_2.bias": "model-00001-of-00004.safetensors",
592
+ "vlm.vision_encoder.transformer.resblocks.9.ln_2.weight": "model-00001-of-00004.safetensors",
593
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
594
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
595
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
596
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
597
+ "vlm.vision_tokenizer.latents": "model-00001-of-00004.safetensors",
598
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.bias": "model-00001-of-00004.safetensors",
599
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.weight": "model-00001-of-00004.safetensors",
600
+ "vlm.vision_tokenizer.layers.0.0.norm_media.bias": "model-00001-of-00004.safetensors",
601
+ "vlm.vision_tokenizer.layers.0.0.norm_media.weight": "model-00001-of-00004.safetensors",
602
+ "vlm.vision_tokenizer.layers.0.0.to_kv.weight": "model-00001-of-00004.safetensors",
603
+ "vlm.vision_tokenizer.layers.0.0.to_out.weight": "model-00001-of-00004.safetensors",
604
+ "vlm.vision_tokenizer.layers.0.0.to_q.weight": "model-00001-of-00004.safetensors",
605
+ "vlm.vision_tokenizer.layers.0.1.0.bias": "model-00001-of-00004.safetensors",
606
+ "vlm.vision_tokenizer.layers.0.1.0.weight": "model-00001-of-00004.safetensors",
607
+ "vlm.vision_tokenizer.layers.0.1.1.weight": "model-00001-of-00004.safetensors",
608
+ "vlm.vision_tokenizer.layers.0.1.3.weight": "model-00001-of-00004.safetensors",
609
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.bias": "model-00001-of-00004.safetensors",
610
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.weight": "model-00001-of-00004.safetensors",
611
+ "vlm.vision_tokenizer.layers.1.0.norm_media.bias": "model-00001-of-00004.safetensors",
612
+ "vlm.vision_tokenizer.layers.1.0.norm_media.weight": "model-00001-of-00004.safetensors",
613
+ "vlm.vision_tokenizer.layers.1.0.to_kv.weight": "model-00001-of-00004.safetensors",
614
+ "vlm.vision_tokenizer.layers.1.0.to_out.weight": "model-00001-of-00004.safetensors",
615
+ "vlm.vision_tokenizer.layers.1.0.to_q.weight": "model-00001-of-00004.safetensors",
616
+ "vlm.vision_tokenizer.layers.1.1.0.bias": "model-00001-of-00004.safetensors",
617
+ "vlm.vision_tokenizer.layers.1.1.0.weight": "model-00001-of-00004.safetensors",
618
+ "vlm.vision_tokenizer.layers.1.1.1.weight": "model-00001-of-00004.safetensors",
619
+ "vlm.vision_tokenizer.layers.1.1.3.weight": "model-00001-of-00004.safetensors",
620
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.bias": "model-00001-of-00004.safetensors",
621
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.weight": "model-00001-of-00004.safetensors",
622
+ "vlm.vision_tokenizer.layers.2.0.norm_media.bias": "model-00001-of-00004.safetensors",
623
+ "vlm.vision_tokenizer.layers.2.0.norm_media.weight": "model-00001-of-00004.safetensors",
624
+ "vlm.vision_tokenizer.layers.2.0.to_kv.weight": "model-00001-of-00004.safetensors",
625
+ "vlm.vision_tokenizer.layers.2.0.to_out.weight": "model-00001-of-00004.safetensors",
626
+ "vlm.vision_tokenizer.layers.2.0.to_q.weight": "model-00001-of-00004.safetensors",
627
+ "vlm.vision_tokenizer.layers.2.1.0.bias": "model-00001-of-00004.safetensors",
628
+ "vlm.vision_tokenizer.layers.2.1.0.weight": "model-00001-of-00004.safetensors",
629
+ "vlm.vision_tokenizer.layers.2.1.1.weight": "model-00001-of-00004.safetensors",
630
+ "vlm.vision_tokenizer.layers.2.1.3.weight": "model-00001-of-00004.safetensors",
631
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.bias": "model-00001-of-00004.safetensors",
632
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.weight": "model-00001-of-00004.safetensors",
633
+ "vlm.vision_tokenizer.layers.3.0.norm_media.bias": "model-00001-of-00004.safetensors",
634
+ "vlm.vision_tokenizer.layers.3.0.norm_media.weight": "model-00001-of-00004.safetensors",
635
+ "vlm.vision_tokenizer.layers.3.0.to_kv.weight": "model-00001-of-00004.safetensors",
636
+ "vlm.vision_tokenizer.layers.3.0.to_out.weight": "model-00001-of-00004.safetensors",
637
+ "vlm.vision_tokenizer.layers.3.0.to_q.weight": "model-00001-of-00004.safetensors",
638
+ "vlm.vision_tokenizer.layers.3.1.0.bias": "model-00001-of-00004.safetensors",
639
+ "vlm.vision_tokenizer.layers.3.1.0.weight": "model-00001-of-00004.safetensors",
640
+ "vlm.vision_tokenizer.layers.3.1.1.weight": "model-00001-of-00004.safetensors",
641
+ "vlm.vision_tokenizer.layers.3.1.3.weight": "model-00001-of-00004.safetensors",
642
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.bias": "model-00001-of-00004.safetensors",
643
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.weight": "model-00001-of-00004.safetensors",
644
+ "vlm.vision_tokenizer.layers.4.0.norm_media.bias": "model-00001-of-00004.safetensors",
645
+ "vlm.vision_tokenizer.layers.4.0.norm_media.weight": "model-00001-of-00004.safetensors",
646
+ "vlm.vision_tokenizer.layers.4.0.to_kv.weight": "model-00001-of-00004.safetensors",
647
+ "vlm.vision_tokenizer.layers.4.0.to_out.weight": "model-00001-of-00004.safetensors",
648
+ "vlm.vision_tokenizer.layers.4.0.to_q.weight": "model-00001-of-00004.safetensors",
649
+ "vlm.vision_tokenizer.layers.4.1.0.bias": "model-00001-of-00004.safetensors",
650
+ "vlm.vision_tokenizer.layers.4.1.0.weight": "model-00001-of-00004.safetensors",
651
+ "vlm.vision_tokenizer.layers.4.1.1.weight": "model-00001-of-00004.safetensors",
652
+ "vlm.vision_tokenizer.layers.4.1.3.weight": "model-00001-of-00004.safetensors",
653
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.bias": "model-00001-of-00004.safetensors",
654
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.weight": "model-00001-of-00004.safetensors",
655
+ "vlm.vision_tokenizer.layers.5.0.norm_media.bias": "model-00001-of-00004.safetensors",
656
+ "vlm.vision_tokenizer.layers.5.0.norm_media.weight": "model-00001-of-00004.safetensors",
657
+ "vlm.vision_tokenizer.layers.5.0.to_kv.weight": "model-00001-of-00004.safetensors",
658
+ "vlm.vision_tokenizer.layers.5.0.to_out.weight": "model-00001-of-00004.safetensors",
659
+ "vlm.vision_tokenizer.layers.5.0.to_q.weight": "model-00001-of-00004.safetensors",
660
+ "vlm.vision_tokenizer.layers.5.1.0.bias": "model-00001-of-00004.safetensors",
661
+ "vlm.vision_tokenizer.layers.5.1.0.weight": "model-00001-of-00004.safetensors",
662
+ "vlm.vision_tokenizer.layers.5.1.1.weight": "model-00001-of-00004.safetensors",
663
+ "vlm.vision_tokenizer.layers.5.1.3.weight": "model-00001-of-00004.safetensors",
664
+ "vlm.vision_tokenizer.norm.bias": "model-00001-of-00004.safetensors",
665
+ "vlm.vision_tokenizer.norm.weight": "model-00001-of-00004.safetensors",
666
+ "vlm.vision_tokenizer.projection.bias": "model-00001-of-00004.safetensors",
667
+ "vlm.vision_tokenizer.projection.weight": "model-00001-of-00004.safetensors"
668
+ }
669
+ }
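
The `weight_map` above records, for every parameter, which of the four safetensors shards stores it. As a quick integrity check after downloading, one can resolve a tensor through the index and load it directly with `safetensors`; a minimal sketch, assuming the shards and the index sit in the current directory:

```
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

name = "vlm.vision_tokenizer.latents"   # any key listed in the index
shard = weight_map[name]                # e.g. "model-00001-of-00004.safetensors"

with safe_open(shard, framework="pt") as f:
    tensor = f.get_tensor(name)
print(name, tuple(tensor.shape), tensor.dtype)
```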
modeling_xgenmm.py ADDED
@@ -0,0 +1,105 @@
+ from transformers import PreTrainedModel, AutoModelForCausalLM
+ import torch
+ import open_clip
+ from typing import List, Optional, Tuple, Union
+ from utils import check_embedding_fns
+ from vlm import PerceiverResampler, Kosmos
+ from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
+ 
+ class XGenMMVisionEncoder(PreTrainedModel):
+     main_input_name = "pixel_values"
+     config_class = XGenMMVisionEncoderConfig
+ 
+     def __init__(self, config: XGenMMVisionEncoderConfig):
+         super().__init__(config)
+         if config.model_name != 'ViT-H-14-378-quickgelu':
+             raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
+         self.model, _, _ = open_clip.create_model_and_transforms(
+             model_name = config.model_name,
+             force_image_size=config.force_image_size
+         )
+ 
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
+         return self.model.encode_image(pixel_values)
+ 
+ 
+ # vision tokenizer
+ class XGenMMVisionTokenizer(PreTrainedModel):
+     config_class = XGenMMVisionTokenizerConfig
+     def __init__(self, config: XGenMMVisionTokenizerConfig):
+         super().__init__(config)
+         self.model = PerceiverResampler(
+             dim=config.vis_feature_dim,
+             dim_inner=config.lang_embedding_dim,
+         )
+ 
+     def forward(self,
+                 vision_features: torch.Tensor,
+                 vision_attn_masks: torch.Tensor):
+         return self.model(vision_features, vision_attn_masks)
+ 
+ # XGenMM model
+ class XGenMMModelForConditionalGeneration(PreTrainedModel):
+     config_class = XGenMMConfig
+ 
+     def __init__(self, config: XGenMMConfig):
+         super().__init__(config)
+ 
+         # vision encoder initialization
+         vision_encoder = XGenMMVisionEncoder(config.vision_encoder_config).model
+         vision_encoder.visual.output_tokens = True
+         vision_encoder = vision_encoder.visual
+ 
+         # language model initialization
+         language_model = AutoModelForCausalLM.from_config(config.text_config)
+         check_embedding_fns(language_model)
+         # Update _tied_weights_keys using the base model used.
+         if language_model._tied_weights_keys is not None:
+             self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+ 
+         # vision tokenizer initialization
+         if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
+             overwrite = language_model.get_input_embeddings().weight.shape[1]
+             config.vision_tokenizer_config.lang_embedding_dim = overwrite
+             print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
+ 
+         vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
+ 
+         self.vlm = Kosmos(
+             vision_encoder=vision_encoder,
+             vision_tokenizer=vision_tokenizer,
+             lang_model=language_model,
+             initial_tokenizer_len = config.text_config.initial_tokenizer_len,
+             pad_token_id = config.text_config.pad_token_id,
+         )
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @torch.no_grad()
+     def generate(
+         self,
+         pixel_values: torch.FloatTensor,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         **generate_kwargs,
+     ) -> torch.LongTensor:
+         self.vlm = self.vlm.eval()
+         return self.vlm.generate(
+             vision_x = pixel_values,
+             lang_x = input_ids,
+             attention_mask = attention_mask,
+             **generate_kwargs)
+ 
+     def update_special_tokens(self, tokenizer):
+         tokenizer.add_special_tokens(
+             {"additional_special_tokens": list(self.vlm.special_tokens.values())}
+         )
+         self.vlm.lang_model.config.vocab_size = len(tokenizer)
+         self.vlm.set_special_token_ids(
+             {
+                 v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
+             }
+         )
+         return tokenizer
+ 
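
Taken together, `modeling_xgenmm.py` wires the open_clip `ViT-H-14-378-quickgelu` encoder, a perceiver-resampler vision tokenizer, and a Phi-3-mini language model into a single Kosmos-style VLM, and exposes a thin `generate` wrapper plus `update_special_tokens`. The sketch below shows how these pieces are typically driven end to end; the repo id, the `AutoModelForVision2Seq` entry point, the `<image>` placeholder string, and the `(batch, images, frames, C, H, W)` vision tensor layout are assumptions borrowed from the Flamingo/Kosmos convention, so check them against the repository's own example before relying on them.

```
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor

model_id = "Salesforce/xgen-mm-phi3-mini-base-r-v1"  # assumed repo id, for illustration only

model = AutoModelForVision2Seq.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
image_processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
tokenizer = model.update_special_tokens(tokenizer)  # registers the VLM's special tokens

image = Image.open("./test_samples/images/000adfe5b817011c.jpg").convert("RGB")
pixel_values = image_processor([image], return_tensors="pt")["pixel_values"]
vision_x = pixel_values.unsqueeze(1).unsqueeze(1)  # assumed (B, n_images, n_frames, C, H, W) layout

prompt = "<image> A short description of this image in one sentence:"  # "<image>" is an assumed placeholder
inputs = tokenizer([prompt], return_tensors="pt")

with torch.no_grad():
    out = model.generate(
        pixel_values=vision_x,
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=32,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```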
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_xgenmm.XGenMMImageProcessor"
+   },
+   "do_resize": true,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "XGenMMImageProcessor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "interpolation_mode": "bicubic",
+   "resize_mode": "squash",
+   "size": [
+     378,
+     378
+   ]
+ }
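
In plain terms, this is CLIP-style preprocessing: a "squash" resize (aspect ratio not preserved) to 378x378 with bicubic interpolation, followed by normalization with the OpenAI CLIP mean/std. For intuition, a roughly equivalent torchvision pipeline is sketched below; it approximates the custom `XGenMMImageProcessor` and is not a drop-in replacement.

```
from PIL import Image
import torchvision.transforms as T

preprocess = T.Compose([
    T.Resize((378, 378), interpolation=T.InterpolationMode.BICUBIC),  # "squash": stretch to the target size
    T.ToTensor(),
    T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711)),
])

pixel_values = preprocess(Image.open("./test_samples/images/000adfe5b817011c.jpg").convert("RGB"))
print(pixel_values.shape)  # torch.Size([3, 378, 378])
```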
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
test_samples/few_shots.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "example_1": {
+     "image_path": "./test_samples/images/COCO_val2014_000000486568.jpg",
+     "instruction": "A short description of this image in one sentence:",
+     "output": "A man in a suit holding something in his office."
+   },
+   "example_2": {
+     "image_path": "./test_samples/images/COCO_val2014_000000176466.jpg",
+     "instruction": "A short description of this image in one sentence:",
+     "output": "The young girl is standing by the fire hydrant in curlers."
+   },
+   "example_3": {
+     "image_path": "./test_samples/images/COCO_val2014_000000392640.jpg",
+     "instruction": "A short description of this image in one sentence:",
+     "output": "A man with a skateboard that is jumping in the air."
+   },
+   "example_4": {
+     "image_path": "./test_samples/images/COCO_val2014_000000267408.jpg",
+     "instruction": "A short description of this image in one sentence:",
+     "output": "A few people looking at a television that's next to a laptop."
+   }
+ }
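
These entries are meant for few-shot, in-context captioning with the base model: each shot contributes an image, the shared instruction, and a reference caption, and the query image is appended with the instruction but no answer. A rough prompt-assembly sketch is below; the `<image>` placeholder string is an assumption (the actual special-token names live in `model.vlm.special_tokens`).

```
import json
from PIL import Image

with open("./test_samples/few_shots.json") as f:
    shots = json.load(f)
with open("./test_samples/zero_shot.json") as f:
    query = json.load(f)

images, prompt = [], ""
for shot in shots.values():
    images.append(Image.open(shot["image_path"]).convert("RGB"))
    prompt += f"<image> {shot['instruction']} {shot['output']}\n"  # "<image>" is an assumed placeholder

# The query image gets the instruction but no answer; the model completes the caption.
images.append(Image.open(query["image_path"]).convert("RGB"))
prompt += f"<image> {query['instruction']}"
print(prompt)
```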
test_samples/images/000adfe5b817011c.jpg ADDED
test_samples/images/COCO_val2014_000000176466.jpg ADDED
test_samples/images/COCO_val2014_000000267408.jpg ADDED
test_samples/images/COCO_val2014_000000392640.jpg ADDED
test_samples/images/COCO_val2014_000000486568.jpg ADDED
test_samples/zero_shot.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "image_path": "./test_samples/images/000adfe5b817011c.jpg",
+   "instruction": "Please provide a short description of this image:"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,137 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": true,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "32000": {
30
+ "content": "<|endoftext|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<|assistant|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": true,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "32002": {
46
+ "content": "<|placeholder1|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": true,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "32003": {
54
+ "content": "<|placeholder2|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": true,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "32004": {
62
+ "content": "<|placeholder3|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": true,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "32005": {
70
+ "content": "<|placeholder4|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": true,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "32006": {
78
+ "content": "<|system|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": true,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "32007": {
86
+ "content": "<|end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": true,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "32008": {
94
+ "content": "<|placeholder5|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": true,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "32009": {
102
+ "content": "<|placeholder6|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": true,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "32010": {
110
+ "content": "<|user|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": true,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "32011": {
118
+ "content": "<pad>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ }
125
+ },
126
+ "bos_token": "<s>",
127
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
128
+ "clean_up_tokenization_spaces": false,
129
+ "eos_token": "<|endoftext|>",
130
+ "model_max_length": 4096,
131
+ "pad_token": "<pad>",
132
+ "padding_side": "left",
133
+ "sp_model_kwargs": {},
134
+ "tokenizer_class": "LlamaTokenizer",
135
+ "unk_token": "<unk>",
136
+ "use_default_system_prompt": false
137
+ }
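Note: the `chat_template` above encodes the Phi-3-style turn format (`<|user|>` / `<|assistant|>` turns, each closed by `<|end|>`). A minimal sketch of how it could be rendered with `transformers` (the repo id is a placeholder and the prompt text is illustrative only):

```
from transformers import AutoTokenizer

# Placeholder repo id; substitute the actual model repository.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/xgen-mm-phi3-mini-base-r-v1")

messages = [{"role": "user", "content": "Describe this image in one sentence."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# Expected rendering: "<s><|user|>\nDescribe this image in one sentence.<|end|>\n<|assistant|>\n"
print(prompt)
```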
utils.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import ast
3
+ import math
4
+ from PIL import Image
5
+
6
+
7
+ def has_fn(model, fn_name):
8
+ """Check if model has a function fn_name"""
9
+ return callable(getattr(model, fn_name, None))
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+ def num_params(module, filter_to_trainable=False):
15
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
+ if filter_to_trainable:
17
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
+ else:
19
+ return sum(p.numel() for p in module.parameters())
20
+
21
+ def hasattr_recursive(obj, att):
22
+ """
23
+ Check if obj has nested attribute
24
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
+ """
26
+ if att == "":
27
+ return True
28
+ i = att.find(".")
29
+ if i < 0:
30
+ return hasattr(obj, att)
31
+ else:
32
+ try:
33
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
+ except:
35
+ return False
36
+
37
+ def getattr_recursive(obj, att):
38
+ """
39
+ Return nested attribute of obj
40
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
+ """
42
+ if att == "":
43
+ return obj
44
+ i = att.find(".")
45
+ if i < 0:
46
+ return getattr(obj, att)
47
+ else:
48
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
+
50
+
51
+ def setattr_recursive(obj, att, val):
52
+ """
53
+ Set nested attribute of obj
54
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
+ """
56
+ if "." in att:
57
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
+ setattr(obj, att.split(".")[-1], val)
59
+
60
+
61
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
+ """
63
+ Stack a list of tensors with padding on one side
64
+ Args:
65
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
+ padding_value (int, optional): Value to pad with. Defaults to 0.
67
+ padding_side (str, optional): Side to pad on. Defaults to "right".
68
+ Returns:
69
+ torch.Tensor: Stacked tensors
70
+ """
71
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
+ padded_tensors = []
73
+ for tensor in list_of_tensors:
74
+ num_tokens = tensor.size(0)
75
+ if len(tensor.size()) == 1:
76
+ padding = torch.full(
77
+ (max_tokens - num_tokens,),
78
+ padding_value,
79
+ dtype=tensor.dtype,
80
+ device=tensor.device,
81
+ )
82
+ else:
83
+ padding = torch.full(
84
+ (max_tokens - num_tokens, tensor.size(1)),
85
+ padding_value,
86
+ dtype=tensor.dtype,
87
+ device=tensor.device,
88
+ )
89
+ padded_tensor = (
90
+ torch.cat((tensor, padding), dim=0)
91
+ if padding_side == "right"
92
+ else torch.cat((padding, tensor), dim=0)
93
+ )
94
+ padded_tensors.append(padded_tensor)
95
+ return torch.stack(padded_tensors)
96
+
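+ # Illustrative check (not exercised anywhere in this file): stacking a length-3 and a
+ # length-5 1D tensor with padding_side="right" yields a (2, 5) tensor whose first row
+ # ends in two padding_value entries, e.g.
+ # stack_with_padding([torch.zeros(3), torch.zeros(5)], padding_value=-1).shape == (2, 5).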
97
+
98
+ def check_embedding_fns(lang_model):
99
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
+ if not has_fn(lang_model, "get_input_embeddings"):
101
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
+ lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
105
+ else:
106
+ raise ValueError(
107
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
+ )
109
+
110
+ if not has_fn(lang_model, "set_input_embeddings"):
111
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
+ lang_model, "transformer.wte", x
114
+ )
115
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
+ lang_model, "model.decoder.embed_tokens", x
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
+ )
123
+
124
+ if not has_fn(lang_model, "get_output_embeddings"):
125
+ if hasattr_recursive(lang_model, "lm_head"):
126
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
+ else:
128
+ raise ValueError(
129
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
+ )
131
+
132
+ if not has_fn(lang_model, "set_output_embeddings"):
133
+ if hasattr_recursive(lang_model, "lm_head"):
134
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
+ lang_model, "lm_head", x
136
+ )
137
+ else:
138
+ raise ValueError(
139
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
+ )
141
+
142
+
143
+ def has_fn(model, fn_name):
144
+ """Check if model has a function fn_name"""
145
+ return callable(getattr(model, fn_name, None))
146
+
147
+
148
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
+ #
150
+ # Licensed under the Apache License, Version 2.0 (the "License");
151
+ # you may not use this file except in compliance with the License.
152
+ # You may obtain a copy of the License at
153
+ #
154
+ # http://www.apache.org/licenses/LICENSE-2.0
155
+ #
156
+ # Unless required by applicable law or agreed to in writing, software
157
+ # distributed under the License is distributed on an "AS IS" BASIS,
158
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
+ # See the License for the specific language governing permissions and
160
+ # limitations under the License.
161
+
162
+ def unpad_image(tensor, original_size, keep_original_shape=False):
163
+ """
164
+ Unpads a PyTorch tensor of a padded and resized image.
165
+
166
+ Args:
167
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
+ original_size (tuple): The original size of the image (height, width).
169
+
170
+ Returns:
171
+ torch.Tensor: The unpadded image tensor.
172
+ """
173
+ original_width, original_height = original_size
174
+ current_height, current_width = tensor.shape[1:]
175
+
176
+ original_aspect_ratio = original_width / original_height
177
+ current_aspect_ratio = current_width / current_height
178
+
179
+ if original_aspect_ratio > current_aspect_ratio:
180
+ scale_factor = current_width / original_width
181
+ new_height = int(original_height * scale_factor)
182
+ padding = (current_height - new_height) // 2
183
+ if keep_original_shape:
184
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
+ attention_mask[:padding, :] = 0
186
+ attention_mask[current_height - padding:, :] = 0
187
+ return tensor, attention_mask
188
+ else:
189
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
+ return unpadded_tensor, None
191
+ else:
192
+ scale_factor = current_height / original_height
193
+ new_width = int(original_width * scale_factor)
194
+ padding = (current_width - new_width) // 2
195
+ if keep_original_shape:
196
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
+ attention_mask[:, :padding] = 0
198
+ attention_mask[:, current_width - padding:] = 0
199
+ return tensor, attention_mask
200
+ else:
201
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
+ return unpadded_tensor, None
203
+
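+ # Worked example (illustrative): for a feature map of shape (D, 24, 24) computed from an
+ # original 800x400 image, original_aspect_ratio (2.0) > current_aspect_ratio (1.0), so
+ # scale_factor = 24 / 800, new_height = 12, padding = 6, and with keep_original_shape=False
+ # the function returns tensor[:, 6:18, :] of shape (D, 12, 24).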
204
+
205
+ def select_best_resolution(original_size, possible_resolutions):
206
+ """
207
+ Selects the best resolution from a list of possible resolutions based on the original size.
208
+
209
+ Args:
210
+ original_size (tuple): The original size of the image in the format (width, height).
211
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
+
213
+ Returns:
214
+ tuple: The best fit resolution in the format (width, height).
215
+ """
216
+ original_width, original_height = original_size
217
+ best_fit = None
218
+ max_effective_resolution = 0
219
+ min_wasted_resolution = float('inf')
220
+
221
+ for width, height in possible_resolutions:
222
+ scale = min(width / original_width, height / original_height)
223
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
+ wasted_resolution = (width * height) - effective_resolution
226
+
227
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
+ max_effective_resolution = effective_resolution
229
+ min_wasted_resolution = wasted_resolution
230
+ best_fit = (width, height)
231
+
232
+ return best_fit
233
+
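+ # Worked example (illustrative): select_best_resolution((800, 600), [(672, 672), (1344, 672)])
+ # returns (1344, 672): its effective resolution min(896 * 672, 800 * 600) = 480000 beats the
+ # 338688 achievable with (672, 672), so it wins despite wasting more padding area.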
234
+
235
+ def resize_and_pad_image(image, target_resolution):
236
+ """
237
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
238
+
239
+ Args:
240
+ image (PIL.Image.Image): The input image.
241
+ target_resolution (tuple): The target resolution (width, height) of the image.
242
+
243
+ Returns:
244
+ PIL.Image.Image: The resized and padded image.
245
+ """
246
+ original_width, original_height = image.size
247
+ target_width, target_height = target_resolution
248
+
249
+ scale_w = target_width / original_width
250
+ scale_h = target_height / original_height
251
+
252
+ if scale_w < scale_h:
253
+ new_width = target_width
254
+ new_height = min(math.ceil(original_height * scale_w), target_height)
255
+ else:
256
+ new_height = target_height
257
+ new_width = min(math.ceil(original_width * scale_h), target_width)
258
+
259
+ # Resize the image
260
+ resized_image = image.resize((new_width, new_height))
261
+
262
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
+ paste_x = (target_width - new_width) // 2
264
+ paste_y = (target_height - new_height) // 2
265
+ new_image.paste(resized_image, (paste_x, paste_y))
266
+
267
+ return new_image
268
+
269
+
270
+ def divide_to_patches(image, patch_size):
271
+ """
272
+ Divides an image into patches of a specified size.
273
+
274
+ Args:
275
+ image (PIL.Image.Image): The input image.
276
+ patch_size (int): The size of each patch.
277
+
278
+ Returns:
279
+ list: A list of PIL.Image.Image objects representing the patches.
280
+ """
281
+ patches = []
282
+ width, height = image.size
283
+ for i in range(0, height, patch_size):
284
+ for j in range(0, width, patch_size):
285
+ box = (j, i, j + patch_size, i + patch_size)
286
+ patch = image.crop(box)
287
+ patches.append(patch)
288
+
289
+ return patches
290
+
291
+
292
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
+ """
294
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
+
296
+ Args:
297
+ image_size (tuple): The size of the input image in the format (width, height).
298
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
299
+ patch_size (int): The size of each image patch.
300
+
301
+ Returns:
302
+ tuple: The shape of the image patch grid in the format (width, height).
303
+ """
304
+ if type(grid_pinpoints) is list:
305
+ possible_resolutions = grid_pinpoints
306
+ else:
307
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
308
+ width, height = select_best_resolution(image_size, possible_resolutions)
309
+ return width // patch_size, height // patch_size
310
+
311
+
312
+ def process_anyres_image(image, processor, grid_pinpoints):
313
+ """
314
+ Process an image with variable resolutions.
315
+
316
+ Args:
317
+ image (PIL.Image.Image): The input image to be processed.
318
+ processor: The image processor object.
319
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
320
+
321
+ Returns:
322
+ torch.Tensor: A tensor containing the processed image patches.
323
+ """
324
+ # FIXME: determine grid_pinpoints from image sizes.
325
+ if type(grid_pinpoints) is list:
326
+ possible_resolutions = grid_pinpoints
327
+ else:
328
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
329
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
330
+ image_padded = resize_and_pad_image(image, best_resolution)
331
+
332
+ processor_size = processor.transforms[0].size
333
+ patches = divide_to_patches(image_padded, processor_size[0])
334
+
335
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
+
337
+ image_patches = [image_original_resize] + patches
338
+ image_patches = [processor(image_patch)
339
+ for image_patch in image_patches]
340
+ return torch.stack(image_patches, dim=0)
341
+
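+ # Note (assumption): `processor` is expected to behave like a torchvision Compose whose first
+ # transform exposes `.size` (e.g. Resize), since `processor.transforms[0].size` is read above.
+ # The returned tensor has shape [1 + n_patches, C, size, size]: the resized full image followed
+ # by its square crops at the selected target resolution.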
342
+
343
+ def expand2square(pil_img, background_color):
344
+ width, height = pil_img.size
345
+ if width == height:
346
+ return pil_img
347
+ elif width > height:
348
+ result = Image.new(pil_img.mode, (width, width), background_color)
349
+ result.paste(pil_img, (0, (width - height) // 2))
350
+ return result
351
+ else:
352
+ result = Image.new(pil_img.mode, (height, height), background_color)
353
+ result.paste(pil_img, ((height - width) // 2, 0))
354
+ return result
355
+
356
+
357
+ def process_images(images, image_processor, model_cfg):
358
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
+ new_images = []
360
+ if image_aspect_ratio == 'pad':
361
+ for image in images:
362
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
+ image = image_processor(image)
364
+ new_images.append(image)
365
+ elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
+ base_img_size = image_processor.transforms[0].size[0]
367
+ for image in images:
368
+ image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
+ [base_img_size*2,base_img_size],
370
+ [base_img_size*2,base_img_size*2],
371
+ [base_img_size*3,base_img_size],
372
+ [base_img_size,base_img_size*3]])
373
+
374
+ # Debug any res inference by only using 672x672.
375
+ # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
+ new_images.append(image)
377
+ else:
378
+ return image_processor(images)
379
+ if all(x.shape == new_images[0].shape for x in new_images):
380
+ new_images = torch.stack(new_images, dim=0)
381
+ return new_images
382
+
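+ # Illustrative usage (assumed config): with model_cfg.image_aspect_ratio == "anyres", each PIL
+ # image becomes a [1 + n_patches, C, H, W] tensor; the results are stacked into a single batch
+ # tensor only when every image produced the same number of patches, otherwise a list is returned.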
383
+
vlm.py ADDED
@@ -0,0 +1,1308 @@
1
+
2
+ import torch
3
+ from torch import einsum, nn
4
+ from einops import rearrange, repeat
5
+ from einops_exts import rearrange_many
6
+ from einops import rearrange
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch.nn.functional as F
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from dataclasses import dataclass
11
+ from transformers import CLIPVisionModel
12
+ import transformers
13
+
14
+ from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
15
+
16
+
17
+ class VisionTokenizer(nn.Module):
18
+ def __init__(self, dim_media, num_tokens_per_media):
19
+ super().__init__()
20
+ self.dim_media = dim_media
21
+ self.num_tokens_per_media = num_tokens_per_media
22
+
23
+ class PerceiverAttention(nn.Module):
24
+ def __init__(self, *, dim, dim_head=64, heads=8):
25
+ super().__init__()
26
+ self.scale = dim_head**-0.5
27
+ self.heads = heads
28
+ inner_dim = dim_head * heads
29
+
30
+ self.norm_media = nn.LayerNorm(dim)
31
+ self.norm_latents = nn.LayerNorm(dim)
32
+
33
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
34
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
35
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
36
+
37
+ def forward(self, x, latents, vision_attn_masks=None):
38
+ """
39
+ Args:
40
+ x (torch.Tensor): image features
41
+ shape (b, T, n1, D)
42
+ latents (torch.Tensor): latent features
43
+ shape (b, T, n2, D)
44
+ """
45
+ x = self.norm_media(x)
46
+ latents = self.norm_latents(latents)
47
+
48
+ h = self.heads
49
+
50
+ q = self.to_q(latents)
51
+ kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
52
+ if vision_attn_masks is not None:
53
+ vision_attn_masks = torch.cat((vision_attn_masks,
54
+ torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
55
+ dim=-1)
56
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
57
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
58
+ q = q * self.scale
59
+
60
+ # attention
61
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
62
+ # Apply vision attention mask here.
63
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
64
+ if vision_attn_masks is not None:
65
+ attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
66
+ vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
67
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
68
+ sim += attn_bias
69
+
70
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
71
+ attn = sim.softmax(dim=-1)
72
+
73
+
74
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
75
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
76
+ return self.to_out(out)
77
+
78
+
79
+ def FeedForward(dim, mult=4):
80
+ inner_dim = int(dim * mult)
81
+ return nn.Sequential(
82
+ nn.LayerNorm(dim),
83
+ nn.Linear(dim, inner_dim, bias=False),
84
+ nn.GELU(),
85
+ nn.Linear(inner_dim, dim, bias=False),
86
+ )
87
+
88
+
89
+ class PerceiverResampler(VisionTokenizer):
90
+ def __init__(
91
+ self,
92
+ *,
93
+ dim,
94
+ dim_inner=None,
95
+ depth=6,
96
+ dim_head=96,
97
+ heads=16,
98
+ num_latents=128,
99
+ max_num_media=None,
100
+ max_num_frames=None,
101
+ ff_mult=4,
102
+ ):
103
+ """
104
+ Perceiver module which takes in image features and outputs image tokens.
105
+ Args:
106
+ dim (int): dimension of the incoming image features
107
+ dim_inner (int, optional): final dimension to project the incoming image features to;
108
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
109
+ depth (int, optional): number of layers. Defaults to 6.
110
+ dim_head (int, optional): dimension of each head. Defaults to 96.
111
+ heads (int, optional): number of heads. Defaults to 16.
112
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
113
+ also corresponds to number of tokens per sequence to output. Defaults to 128.
114
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
115
+ and keep positional embeddings for. If None, no positional embeddings are used.
116
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
117
+ and keep positional embeddings for. If None, no positional embeddings are used.
118
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
119
+ """
120
+ if dim_inner is not None:
121
+ projection = nn.Linear(dim, dim_inner)
122
+ else:
123
+ projection = None
124
+ dim_inner = dim
125
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
126
+ self.projection = projection
127
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
128
+ # positional embeddings
129
+ self.frame_embs = (
130
+ nn.Parameter(torch.randn(max_num_frames, dim))
131
+ if exists(max_num_frames)
132
+ else None
133
+ )
134
+ self.media_time_embs = (
135
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
136
+ if exists(max_num_media)
137
+ else None
138
+ )
139
+
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ self.layers.append(
143
+ nn.ModuleList(
144
+ [
145
+ PerceiverAttention(
146
+ dim=dim, dim_head=dim_head, heads=heads
147
+ ),
148
+ FeedForward(dim=dim, mult=ff_mult),
149
+ ]
150
+ )
151
+ )
152
+
153
+ self.norm = nn.LayerNorm(dim)
154
+
155
+ def forward(self, x):
156
+ """
157
+ Args:
158
+ x (torch.Tensor): image features
159
+ shape (b, T, F, v, D)
160
+ Returns:
161
+ shape (b, T, n, D) where n is self.num_latents
162
+ """
163
+ b, T, F, v = x.shape[:4]
164
+
165
+ # frame and media time embeddings
166
+ if exists(self.frame_embs):
167
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
168
+ x = x + frame_embs
169
+ x = rearrange(
170
+ x, "b T F v d -> b T (F v) d"
171
+ ) # flatten the frame and spatial dimensions
172
+ if exists(self.media_time_embs):
173
+ x = x + self.media_time_embs[:T]
174
+
175
+ # blocks
176
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
177
+ for attn, ff in self.layers:
178
+ latents = attn(x, latents) + latents
179
+ latents = ff(latents) + latents
180
+
181
+ if exists(self.projection):
182
+ return self.projection(self.norm(latents))
183
+ else:
184
+ return self.norm(latents)
185
+
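+ # Shape sketch (illustrative): with the defaults above, image features of shape
+ # (b, T_img, F, v, dim) are pooled into (b, T_img, 128, dim) latent tokens,
+ # or (b, T_img, 128, dim_inner) when a projection is configured.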
186
+ class DecoupledEmbedding(nn.Embedding):
187
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
188
+ """
189
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
190
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
191
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
192
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ max_original_id: int,
198
+ num_additional_embeddings: int = 0,
199
+ _weight: torch.Tensor = None,
200
+ num_original_embeddings: int = None,
201
+ embedding_dim: int = None,
202
+ partially_freeze=True,
203
+ device=None,
204
+ dtype=None,
205
+ pad_token_id=None,
206
+ ) -> None:
207
+ """
208
+ Args:
209
+ max_original_id (`int`):
210
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
211
+ This is usually len(tokenizer) - 1 before additional tokens are added.
212
+ Note that this may not equal self.weight.shape[0]
213
+ num_additional_embeddings (`int`):
214
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
215
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
216
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
217
+ num_original_embeddings (`int`):
218
+ self.weight.shape[0]
219
+ embedding_dim (`int`):
220
+ The size of each embedding vector
221
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
222
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
223
+ padding_idx (`int`, *optional*):
224
+ The padding index (needs to be less than num_embeddings)
225
+
226
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
227
+ `max_norm` or `norm_type`. We are not supporting these.
228
+ """
229
+ # validate args
230
+ if pad_token_id is not None and pad_token_id > max_original_id:
231
+ raise ValueError(
232
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
233
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
234
+ )
235
+ if _weight is not None:
236
+ assert (num_original_embeddings is None) or (
237
+ _weight.shape[0] == num_original_embeddings
238
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
239
+ assert (embedding_dim is None) or (
240
+ _weight.shape[1] == embedding_dim
241
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
242
+ num_original_embeddings = _weight.shape[0]
243
+ embedding_dim = _weight.shape[1]
244
+ else:
245
+ assert (
246
+ num_original_embeddings is not None
247
+ ), "num_original_embeddings must be provided if _weight is not provided"
248
+ assert (
249
+ embedding_dim is not None
250
+ ), "embedding_dim must be provided if _weight is not provided"
251
+
252
+ super().__init__(
253
+ num_embeddings=num_original_embeddings,
254
+ embedding_dim=embedding_dim,
255
+ device=device,
256
+ dtype=dtype,
257
+ padding_idx=pad_token_id,
258
+ _weight=_weight,
259
+ )
260
+ self.max_original_id = max_original_id
261
+ self.padding_idx = pad_token_id
262
+ self.num_additional_embeddings = num_additional_embeddings
263
+ if self.num_additional_embeddings > 0:
264
+ self.additional_embedding = nn.Embedding(
265
+ num_embeddings=self.num_additional_embeddings,
266
+ embedding_dim=embedding_dim,
267
+ device=device,
268
+ dtype=dtype,
269
+ )
270
+ self.set_requires_grad(
271
+ require_regular_grad=not partially_freeze, require_additional_grad=True
272
+ )
273
+
274
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
275
+ """
276
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
277
+ """
278
+ self.weight.requires_grad_(require_regular_grad)
279
+ self.additional_embedding.requires_grad_(require_additional_grad)
280
+
281
+ def forward(self, input_ids):
282
+ """
283
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
284
+ self.additional_embedding.weight that is being trained.
285
+
286
+ in order to make a lookup of the input ids, we:
287
+ 1. find out the indices of the entries belonging to the 2nd embedding
288
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
289
+ embedding starts from 0 and not num_embeddings
290
+ 3. perform the 2nd embedding lookup
291
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
292
+ 5. perform the 1st embedding lookup
293
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
294
+
295
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
296
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
297
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
298
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
299
+ measure.
300
+
301
+ """
302
+ if self.num_additional_embeddings == 0:
303
+ return F.embedding(input_ids, self.weight)
304
+
305
+ # Clone so that we don't modify the original input_ids later on
306
+ input_ids = input_ids.clone()
307
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
308
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
309
+ additional_embeddings = self.additional_embedding(
310
+ input_ids_additional_vocab - self.max_original_id - 1
311
+ )
312
+
313
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
314
+ input_ids[additional_vocab_indices] = 0
315
+ full_vector = F.embedding(input_ids, self.weight)
316
+
317
+ # overwrite the records with high indices
318
+ full_vector[additional_vocab_indices] = additional_embeddings
319
+
320
+ return full_vector
321
+
322
+ def extra_repr(self) -> str:
323
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
324
+ self.max_original_id + 1,
325
+ self.num_additional_embeddings,
326
+ self.embedding_dim,
327
+ (not self.weight.requires_grad),
328
+ )
329
+
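+ # Worked example (illustrative): with max_original_id=31999 and num_additional_embeddings=3,
+ # ids 0..31999 are looked up in the (optionally frozen) base `weight`, while ids 32000..32002
+ # map to rows 0..2 of the always-trainable `additional_embedding` (i.e. id - max_original_id - 1).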
330
+
331
+ class DecoupledLinear(nn.Linear):
332
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
333
+ """
334
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
335
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
336
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
337
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
338
+ """
339
+
340
+ def __init__(
341
+ self,
342
+ max_original_id: int,
343
+ additional_out_features: int = 0,
344
+ _weight: torch.Tensor = None,
345
+ _bias: torch.Tensor = None,
346
+ in_features: int = None,
347
+ original_out_features: int = None,
348
+ bias: bool = True,
349
+ partially_freeze: bool = True,
350
+ device=None,
351
+ dtype=None,
352
+ ) -> None:
353
+ """
354
+ Args:
355
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
356
+ This is usually len(tokenizer) - 1 before additional tokens are added.
357
+ Note that this may not equal original_out_features - 1
358
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
359
+ If provided, this sets the `in_features` and `original_out_features` parameters.
360
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
361
+ in_features: int. Input hidden size.
362
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
363
+ additional_out_features: int. Number of additional trainable dimensions.
364
+ bias: bool. Whether to include a bias term.
365
+ partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
366
+ """
367
+ # argument validation
368
+ if _weight is not None:
369
+ assert (_weight.shape[0] == original_out_features) or (
370
+ original_out_features is None
371
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
372
+ assert (_weight.shape[1] == in_features) or (
373
+ in_features is None
374
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
375
+ in_features = _weight.shape[1]
376
+ original_out_features = _weight.shape[0]
377
+ else:
378
+ assert (
379
+ in_features is not None
380
+ ), "in_features must be provided if _weight is not provided"
381
+ assert (
382
+ original_out_features is not None
383
+ ), "original_out_features must be provided if _weight is not provided"
384
+
385
+ if _bias is not None:
386
+ assert bias is True, "bias must be True if _bias is provided"
387
+
388
+ # initialize original linear
389
+ super().__init__(
390
+ in_features,
391
+ original_out_features,
392
+ bias,
393
+ device,
394
+ dtype)
395
+
396
+ # set weight and bias manually
397
+ if _weight is not None:
398
+ self.weight = nn.Parameter(_weight)
399
+ if _bias is not None:
400
+ self.bias = nn.Parameter(_bias)
401
+
402
+ self.in_features = in_features
403
+ self.original_out_features = original_out_features
404
+ self.max_original_id = max_original_id
405
+
406
+ # initialize additional linear
407
+ self.additional_out_features = additional_out_features
408
+ self.has_bias = bias
409
+ if additional_out_features > 0:
410
+ self.additional_fc = nn.Linear(
411
+ in_features=in_features,
412
+ out_features=additional_out_features,
413
+ bias=self.has_bias,
414
+ device=device,
415
+ dtype=dtype,
416
+ )
417
+ self.set_requires_grad(
418
+ require_regular_grad=not partially_freeze, require_additional_grad=True
419
+ )
420
+
421
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
422
+ """
423
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
424
+ """
425
+ self.weight.requires_grad_(require_regular_grad)
426
+ if self.has_bias:
427
+ self.bias.requires_grad_(require_regular_grad)
428
+ self.additional_fc.requires_grad_(require_additional_grad)
429
+
430
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
431
+ output = F.linear(input, self.weight, self.bias)
432
+ output = output[..., : self.max_original_id + 1]
433
+
434
+ if self.additional_out_features > 0:
435
+ additional_features = F.linear(
436
+ input, self.additional_fc.weight, self.additional_fc.bias
437
+ )
438
+ output = torch.cat((output, additional_features), -1)
439
+ return output
440
+
441
+ def extra_repr(self) -> str:
442
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
443
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
444
+ self.in_features,
445
+ self.max_original_id + 1,
446
+ self.additional_out_features,
447
+ self.bias is not None,
448
+ (not self.weight.requires_grad or not self.bias.requires_grad),
449
+ )
450
+
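+ # Note (illustrative): forward() returns the first max_original_id + 1 logits from the
+ # (optionally frozen) base head concatenated with additional_out_features always-trainable
+ # logits, so the output width is max_original_id + 1 + additional_out_features.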
451
+ class VLM(nn.Module):
452
+ """
453
+ Generic vision-language model (VLM) class.
454
+ A VLM consists of four components:
455
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
456
+ input: (B, T_img, F, C, H, W)
457
+ output: (B, T_img, F, v, d)
458
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
459
+ input: (B, T_img, F, v, d)
460
+ output: (B, T_img, n, d)
461
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
462
+ 4. A language model
463
+ """
464
+
465
+ def __init__(
466
+ self,
467
+ vision_encoder: nn.Module,
468
+ vision_tokenizer: nn.Module,
469
+ lang_model: nn.Module,
470
+ initial_tokenizer_len: int,
471
+ pad_token_id: int,
472
+ gradient_checkpointing: bool = False,
473
+ ):
474
+ """
475
+ Args:
476
+ vision_encoder (nn.Module): e.g. CLIP
477
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
478
+ lang_model (nn.Module): e.g. MPT
479
+ initial_tokenizer_len (int): size of the original tokenizer vocab
480
+ pad_token_id (int): id of the pad token
481
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
482
+ """
483
+ super().__init__()
484
+
485
+ # save dimension information
486
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
487
+ if hasattr(lang_model.config, "d_model"):
488
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
489
+ else:
490
+ self.lang_hidden_dim = lang_model.config.hidden_size
491
+ self.vis_embedding_dim = vision_tokenizer.dim_media
492
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
493
+
494
+ # core components
495
+ self.vision_encoder = vision_encoder
496
+ self.vision_tokenizer = vision_tokenizer
497
+ self.lang_model = lang_model
498
+
499
+ # lm embeddings
500
+ self.pad_token_id = pad_token_id
501
+ self.initial_tokenizer_len = initial_tokenizer_len
502
+ input_embeds = DecoupledEmbedding(
503
+ max_original_id=initial_tokenizer_len - 1,
504
+ num_additional_embeddings=len(self.special_tokens),
505
+ _weight=self.lang_model.get_input_embeddings().weight,
506
+ pad_token_id=self.pad_token_id,
507
+ )
508
+ if hasattr(input_embeds, "additional_embedding"):
509
+ input_embeds.additional_embedding.weight.data.normal_(
510
+ mean=0.0,
511
+ std=self.lang_model.config.initializer_range
512
+ if hasattr(self.lang_model.config, "initializer_range")
513
+ else 0.02,
514
+ )
515
+ self.lang_model.set_input_embeddings(input_embeds)
516
+
517
+ out_embeds = DecoupledLinear(
518
+ max_original_id=initial_tokenizer_len - 1,
519
+ additional_out_features=len(self.special_tokens),
520
+ _weight=self.lang_model.get_output_embeddings().weight,
521
+ _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
522
+ )
523
+ if hasattr(out_embeds, "additional_fc"):
524
+ out_embeds.additional_fc.weight.data.normal_(
525
+ mean=0.0,
526
+ std=self.lang_model.config.initializer_range
527
+ if hasattr(self.lang_model.config, "initializer_range")
528
+ else 0.02,
529
+ )
530
+ self.lang_model.set_output_embeddings(out_embeds)
531
+
532
+ # gradient checkpointing
533
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
534
+
535
+ def forward(
536
+ self,
537
+ vision_x: Optional[torch.Tensor],
538
+ lang_x: torch.Tensor,
539
+ attention_mask: Optional[torch.Tensor] = None,
540
+ labels: Optional[torch.Tensor] = None,
541
+ past_key_values: Optional[
542
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
543
+ ] = None,
544
+ past_media_locations: Optional[torch.Tensor] = None,
545
+ past_vision_tokens: Optional[torch.Tensor] = None,
546
+ use_cache: Optional[bool] = False,
547
+ **kwargs,
548
+ ):
549
+ """
550
+ Args:
551
+ vision_x: Vision input
552
+ shape (B, T_img, F, C, H, W) with F=1
553
+ only F = 1 is supported (single-frame videos)
554
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
555
+ only the first number of media tokens in lang_x are used
556
+ lang_x: Language input ids, with media tokens denoting where
557
+ visual media should be inserted.
558
+ shape (B, T_txt)
559
+ attention_mask: Attention mask. Defaults to None.
560
+ labels: Labels. Defaults to None.
561
+ shape (B, T_txt)
562
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
563
+ list of length = number of decoder layers in the LM
564
+ exact implementation depends on LM, see Hugging Face docs
565
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
566
+ shape (B, T_txt)
567
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
568
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
569
+ If True, includes key_values, media_locations, and vision_tokens in the output.
570
+ """
571
+ assert not (past_vision_tokens is None) ^ (
572
+ past_media_locations is None
573
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
574
+
575
+ # convert pixels to vision tokens
576
+ if vision_x is not None:
577
+ vision_features = self._encode_vision_x(vision_x=vision_x)
578
+ vision_tokens = self.vision_tokenizer(vision_features)
579
+ else:
580
+ vision_tokens = None
581
+
582
+ # fuse the vision and language tokens
583
+ new_inputs = self._prepare_inputs_for_forward(
584
+ vision_tokens=vision_tokens,
585
+ lang_x=lang_x,
586
+ attention_mask=attention_mask,
587
+ labels=labels,
588
+ past_key_values=past_key_values,
589
+ past_media_locations=past_media_locations,
590
+ padding_side="right",
591
+ past_vision_tokens=past_vision_tokens,
592
+ )
593
+ output = self.lang_model(
594
+ **new_inputs,
595
+ use_cache=use_cache,
596
+ past_key_values=past_key_values,
597
+ **kwargs,
598
+ )
599
+
600
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
601
+ # or to add the past_vision_tokens and past_media_locations to the output
602
+ output = self._postprocess_outputs_from_forward(
603
+ output=output,
604
+ lang_x=lang_x,
605
+ vision_tokens=vision_tokens,
606
+ use_cache=use_cache,
607
+ past_vision_tokens=past_vision_tokens,
608
+ past_media_locations=past_media_locations,
609
+ )
610
+
611
+ # postforward hooks
612
+ self._post_forward_hook()
613
+ return output
614
+
615
+ def _encode_vision_x_anyres(self, samples, device):
616
+ image_raw = samples["image"] # list of per-image patch tensors, each of shape [1, N_patch, C, H, W]
617
+ image_sizes = samples["image_size"]
618
+
619
+ # concatenate the list of patches into one big batch for anyres encoding.
620
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
621
+ image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
622
+ image = image.to(device)
623
+
624
+ with torch.no_grad():
625
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
626
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
627
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
628
+ image_embeds = self.vision_encoder(image).last_hidden_state
629
+ else:
630
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
631
+
632
+ if isinstance(self.vision_encoder, CLIPVisionModel):
633
+ base_img_size = self.vision_encoder.config.image_size
634
+ else:
635
+ base_img_size = self.vision_encoder.image_size[0]
636
+
637
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
638
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
639
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
640
+ grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
641
+ grid_size = (grid_size_base, grid_size_base)
642
+ else:
643
+ grid_size = self.vision_encoder.grid_size
644
+ height, width = grid_size
645
+
646
+ if not image_embeds.shape[1] == height * width:
647
+ assert image_embeds.shape[1] == height * width + 1 # For vision encoders that has [CLS] token.
648
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
649
+ n_vis_token_per_patch = image_embeds.shape[1]
650
+
651
+ # Split encoded patches and merge patch features
652
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
653
+ split_sizes = [image.shape[0] for image in images]
654
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
655
+ # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
656
+ new_image_embeds = []
657
+ patch_attn_masks = []
658
+ max_n_img_token = -1
659
+ for idx, patch_embeds in enumerate(image_embeds):
660
+ if patch_embeds.shape[0] > 1:
661
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
662
+ base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
663
+ patch_embeds = patch_embeds[1:]
664
+
665
+ assert height * width == base_patch_embeds.shape[0]
666
+
667
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
668
+ [[base_img_size,base_img_size*2],
669
+ [base_img_size*2,base_img_size],
670
+ [base_img_size*2,base_img_size*2],
671
+ [base_img_size*3,base_img_size],
672
+ [base_img_size,base_img_size*3]],
673
+ base_img_size) # Hardcoded grid_pinpoints.
674
+ patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
675
+
676
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
677
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
678
+ # TODO: add an option that return masked patch_embeds instead of trimmed.
679
+ patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
680
+ if hasattr(self, 'image_newline'):
681
+ patch_embeds = torch.cat((
682
+ patch_embeds,
683
+ self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
684
+ ), dim=-1)
685
+ if self.anyres_patch_sampling:
686
+ patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
687
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
688
+ assert patch_attn_mask is not None
689
+ patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
690
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
691
+ patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
692
+ patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
693
+ else:
694
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
695
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
696
+ else:
697
+ patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
698
+ patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
699
+ if hasattr(self, 'image_newline'):
700
+ patch_embeds = torch.cat((
701
+ patch_embeds,
702
+ self.image_newline[None]
703
+ ), dim=0)
704
+ if not self.anyres_patch_sampling:
705
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
706
+
707
+ new_image_embeds.append(patch_embeds)
708
+ patch_attn_masks.append(patch_attn_mask)
709
+
710
+ if self.anyres_patch_sampling:
711
+ # Return individual patches for independent token downsampling.
712
+ return new_image_embeds, patch_attn_masks
713
+
714
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
715
+ image_embeds = []
716
+ image_atts = []
717
+ for image_embed in new_image_embeds:
718
+ n_img_token = image_embed.shape[0]
719
+ img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
720
+ if n_img_token < max_n_img_token:
721
+ padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
722
+ padded_embed[:n_img_token, :] = image_embed
723
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
724
+ else:
725
+ padded_embed = image_embed
726
+ image_embeds.append(padded_embed)
727
+ image_atts.append(img_attn)
728
+ image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
729
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
730
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
731
+ image_embeds = image_embeds[:, None, None, :, :]
732
+ # image_atts = image_atts[:, None, None, :, :]
733
+
734
+ return image_embeds, image_atts
735
+
736
+ def _encode_vision_x(self, vision_x: torch.Tensor):
737
+ """
738
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
739
+ Args:
740
+ vision_x: Vision input
741
+ shape (B, T_img, F, C, H, W)
742
+ Images in the same chunk are collated along T_img, and frames are collated along F
743
+ Currently only F=1 is supported (single-frame videos)
744
+
745
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
746
+ """
747
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
748
+ b, T, F = vision_x.shape[:3]
749
+
750
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
751
+ with torch.no_grad():
752
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
753
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
754
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
755
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
756
+ else:
757
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
758
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
759
+ return vision_x
760
+
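+ # Shape sketch (illustrative): for a hypothetical 336-px ViT with 14-px patches wrapped as a
+ # CLIPVisionModel, vision_x of shape (b, T_img, 1, 3, 336, 336) yields features of shape
+ # (b, T_img, 1, 577, hidden_dim) (576 patch tokens plus one [CLS] token).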
761
+ def _concat_vision_cache(
762
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
763
+ ):
764
+ """
765
+ Helper function to include the past vision tokens and past media locations in the output.
766
+ """
767
+ if use_cache:
768
+ if past_media_locations is not None and past_vision_tokens is not None:
769
+ if vision_tokens is not None:
770
+ updated_vision_tokens = torch.cat(
771
+ [
772
+ past_vision_tokens,
773
+ vision_tokens,
774
+ ],
775
+ dim=1,
776
+ )
777
+ else:
778
+ updated_vision_tokens = past_vision_tokens
779
+ updated_media_locations = torch.cat(
780
+ [
781
+ past_media_locations,
782
+ lang_x == self.media_token_id,
783
+ ],
784
+ dim=1,
785
+ )
786
+ else:
787
+ updated_vision_tokens = vision_tokens
788
+ updated_media_locations = lang_x == self.media_token_id
789
+
790
+ else:
791
+ updated_vision_tokens = None
792
+ updated_media_locations = None
793
+
794
+ return updated_vision_tokens, updated_media_locations
795
+
796
+ def generate(
797
+ self,
798
+ vision_x: torch.Tensor,
799
+ lang_x: torch.Tensor,
800
+ attention_mask: torch.Tensor = None,
801
+ past_key_values: Optional[
802
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
803
+ ] = None,
804
+ past_media_locations: Optional[torch.Tensor] = None,
805
+ past_vision_tokens: Optional[torch.Tensor] = None,
806
+ **kwargs,
807
+ ):
808
+ """
809
+ Generate text conditioned on vision and language inputs.
810
+ Args:
811
+ vision_x (torch.Tensor): Vision input
812
+ shape (B, T_img, F, C, H, W)
813
+ see documentation for forward
814
+ lang_x (torch.Tensor): Language input
815
+ shape (B, T_txt)
816
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
817
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
818
+ Returns:
819
+ torch.Tensor: lang_x with generated tokens appended to it
820
+ """
821
+ num_beams = kwargs.pop("num_beams", 1)
822
+
823
+ # convert pixels to vision tokens
824
+ if vision_x is not None:
825
+ vision_features = self._encode_vision_x(vision_x=vision_x)
826
+ vision_tokens = self.vision_tokenizer(vision_features)
827
+ else:
828
+ vision_tokens = None
829
+
830
+ # fuse the vision and language tokens
831
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
832
+ # the total batch size is B * num_beams
833
+ new_inputs = self._prepare_inputs_for_forward(
834
+ vision_tokens=vision_tokens,
835
+ lang_x=lang_x,
836
+ attention_mask=attention_mask,
837
+ past_key_values=past_key_values,
838
+ past_media_locations=past_media_locations,
839
+ past_vision_tokens=past_vision_tokens,
840
+ padding_side="left",
841
+ num_beams=num_beams,
842
+ )
843
+ output = self.lang_model.generate(
844
+ **new_inputs,
845
+ past_key_values=past_key_values,
846
+ num_beams=num_beams,
847
+ use_cache=True,
848
+ **kwargs,
849
+ )
850
+ self._post_forward_hook()
851
+ return output
852
+
853
+ @property
854
+ def num_trainable_params(self):
855
+ """Print the number of trainable parameters"""
856
+ return num_params(self, filter_to_trainable=True)
857
+
858
+ def set_trainable(self):
859
+ """
860
+ Freeze appropriate parameters in the model.
861
+ """
862
+ raise NotImplementedError
863
+
864
+ def group_params_by_weight_decay(self):
865
+ """
866
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
867
+ """
868
+ params_with_wd, params_without_wd = [], []
869
+ for n, p in self.named_parameters():
870
+ if p.requires_grad:
871
+ if self._should_apply_weight_decay(n):
872
+ params_with_wd.append(p)
873
+ else:
874
+ params_without_wd.append(p)
875
+ return params_with_wd, params_without_wd
876
+
877
+ def _should_apply_weight_decay(self, parameter_name):
878
+ """
879
+ Return whether weight decay should be applied to a parameter.
880
+ """
881
+ raise NotImplementedError
882
+
883
+     @property
+     def special_tokens(self):
+         """
+         Returns a dict mapping from the attribute name of a special token to its string format,
+         e.g. "media_token": "<image>"
+         """
+         assert (
+             "media_token" in self._special_tokens
+         ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
+         return self._special_tokens
+
+     @property
+     def special_token_ids(self):
+         """
+         Returns a list of the special token ids
+         """
+         return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
+
+     def set_special_token_ids(self, string_to_ids):
+         """
+         Args:
+             string_to_ids (dict): mapping from token string to id
+         """
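+         # Illustrative usage (assumption): the mapping is typically built from the
+         # tokenizer once the special tokens have been added, e.g.
+         #   ids = {t: tokenizer.convert_tokens_to_ids(t) for t in model.special_tokens.values()}
+         #   model.set_special_token_ids(ids)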
+         assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
+         for att_name, token_str in self.special_tokens.items():
+             token_id = string_to_ids[token_str]
+             setattr(self, f"{att_name}_id", token_id)
+             setattr(self.lang_model, f"{att_name}_id", token_id)
+
+     def init_gradient_checkpointing(self):
+         from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+             checkpoint_wrapper,
+             CheckpointWrapper,
+             CheckpointImpl,
+             apply_activation_checkpointing,
+         )
+         from functools import partial
+
+         non_reentrant_wrapper = partial(
+             checkpoint_wrapper,
+             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+         )
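+         # Wrap only the modules that opted in via `_use_gradient_checkpointing`
+         # and are not already wrapped in a CheckpointWrapper.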
+         apply_activation_checkpointing(
+             self,
+             checkpoint_wrapper_fn=non_reentrant_wrapper,
+             check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
+             and not isinstance(m, CheckpointWrapper),
+         )
+
+ @dataclass
+ class VLMOutputWithPast(CausalLMOutputWithPast):
+     """
+     VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
+         past_media_locations: Optional[torch.Tensor] = None
+         past_vision_tokens: Optional[torch.Tensor] = None
+     """
+
+     past_media_locations: Optional[torch.Tensor] = None
+     past_vision_tokens: Optional[torch.Tensor] = None
+
+
+ def exists(val):
+     return val is not None
+
+
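+ # A pre-LayerNorm MLP block: LayerNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim),
+ # without biases; presumably used as a building block by the vision token samplers elsewhere in this file.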
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+ class VLMWithLanguageStream(VLM):
+     """
+     VLM that fuses modalities by inserting vision tokens directly into the language stream.
+     """
+
+     def __init__(
+         self,
+         vision_encoder: nn.Module,
+         vision_tokenizer: nn.Module,
+         lang_model: nn.Module,
+         initial_tokenizer_len: int,
+         pad_token_id: int,
+         decoder_layers_attr_name: str = None,
+         gradient_checkpointing: bool = False,
+     ):
+         super().__init__(
+             vision_encoder=vision_encoder,
+             vision_tokenizer=vision_tokenizer,
+             lang_model=lang_model,
+             initial_tokenizer_len=initial_tokenizer_len,
+             pad_token_id=pad_token_id,
+             gradient_checkpointing=gradient_checkpointing,
+         )
+         self.decoder_layers_attr_name = decoder_layers_attr_name
+         if decoder_layers_attr_name is not None:
+             for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
+                 block._use_gradient_checkpointing = gradient_checkpointing
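+         # Tagging each decoder block above lets init_gradient_checkpointing (defined on
+         # the base VLM) find and wrap exactly these blocks via its check_fn.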
+
+     def _prepare_inputs_for_forward(
+         self,
+         vision_tokens: torch.Tensor,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         past_key_values=None,
+         past_media_locations: torch.Tensor = None,
+         past_vision_tokens: torch.Tensor = None,
+         padding_side: str = "left",
+         num_beams: int = 1,
+     ):
+         """
+         Insert the vision tokens directly into the language stream.
+         This requires us to modify the input_ids, attention_mask, and labels.
+         """
+         if past_key_values is not None:
+             past_len = past_key_values[0][0].shape[2]
+             assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
+                 "attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
+                 + "Check that you've expanded the attention mask to account for past image tokens."
+             )
+
+         if vision_tokens is None:
+             return {
+                 "input_ids": lang_x,
+                 "attention_mask": attention_mask,
+                 "labels": labels,
+             }
+
+         # get the language embeddings
+         lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
+
+         # build up the multimodal embeddings
+         B = lang_x.shape[0]
+         has_labels = labels is not None
+         multimodal_embeds = []
+         multimodal_attention_mask = []
+         multimodal_labels = [] if has_labels else None
+         for i in range(B):
+             # get index of <image> tokens in lang_x[i]
+             image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
+
+             if len(image_token_idxs) == 0:
+                 multimodal_embeds.append(lang_embeds[i].clone())
+                 multimodal_attention_mask.append(attention_mask[i].clone())
+                 if has_labels:
+                     multimodal_labels.append(labels[i].clone())
+                 continue
+
+             # # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
+             # for j, img_idx in enumerate(image_token_idxs):
+             #     image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j
+
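+             # Note: the slicing below assumes each image already occupies
+             # self.num_tokens_per_vis placeholder positions in lang_x (starting at the
+             # <image> token); those positions are overwritten in place by the vision
+             # tokens, so the sequence length does not change.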
+             # loop through the image_token_idxs and insert the vision tokens
+             new_embed = lang_embeds[i].clone()
+             new_attention_mask = (
+                 attention_mask[i].clone() if attention_mask is not None else None
+             )
+             if has_labels:
+                 new_label = labels[i].clone()
+
+             for img_num, img_idx in enumerate(image_token_idxs):
+                 new_embed = torch.cat(
+                     (
+                         new_embed[:img_idx],
+                         vision_tokens[i][img_num],
+                         new_embed[img_idx + self.num_tokens_per_vis :],
+                     ),
+                     dim=0,
+                 )
+                 new_attention_mask = torch.cat(
+                     (
+                         new_attention_mask[:img_idx],
+                         torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                             attention_mask.device
+                         ),
+                         new_attention_mask[img_idx + self.num_tokens_per_vis :],
+                     ),
+                     dim=0,
+                 )
+                 if has_labels:
+                     new_label = torch.cat(
+                         (
+                             new_label[:img_idx],
+                             torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                                 labels.device
+                             )
+                             * -100,
+                             new_label[img_idx + self.num_tokens_per_vis :],
+                         ),
+                         dim=0,
+                     )
+             multimodal_embeds.append(new_embed)
+             multimodal_attention_mask.append(new_attention_mask)
+             if has_labels:
+                 multimodal_labels.append(new_label)
+
+         # stack
+         multimodal_embeds = stack_with_padding(
+             multimodal_embeds,
+             padding_value=self.pad_token_id,
+             padding_side=padding_side,
+         )
+         multimodal_attention_mask = stack_with_padding(
+             multimodal_attention_mask,
+             padding_value=0,
+             padding_side=padding_side,
+         )
+         if has_labels:
+             multimodal_labels = stack_with_padding(
+                 multimodal_labels,
+                 padding_value=-100,
+                 padding_side=padding_side,
+             )
+
+         return {
+             "inputs_embeds": multimodal_embeds,
+             "attention_mask": multimodal_attention_mask,
+             "labels": multimodal_labels,
+         }
+
+     def _postprocess_outputs_from_forward(
+         self,
+         output: CausalLMOutputWithPast,
+         lang_x: torch.Tensor,
+         vision_tokens: torch.Tensor,
+         past_vision_tokens: torch.Tensor,
+         past_media_locations: torch.Tensor,
+         use_cache: bool = False,
+     ):
+         # Include the past vision tokens and past media locations in the output
+         updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
+             lang_x=lang_x,
+             vision_tokens=vision_tokens,
+             past_vision_tokens=past_vision_tokens,
+             past_media_locations=past_media_locations,
+             use_cache=use_cache,
+         )
+
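+         # The language model ran on the fused vision + text sequence, so its logits are
+         # indexed by fused positions; re-align them so that each position of the original
+         # lang_x gets exactly one logit row (an <image> token keeps the logit of its first
+         # vision position and the remaining vision positions are skipped).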
+         # return logits that are the same shape as the original input_ids
+         logits = output.logits
+         batch_logits = []
+         B, T_txt = lang_x.shape
+         for i in range(B):
+             sequence_logits = []
+             logits_j = 0
+             for j in range(T_txt):
+                 if lang_x[i, j] != self.media_token_id:
+                     sequence_logits.append(logits[i, logits_j])
+                     logits_j += 1
+                 else:
+                     # append the logit for the first image token, then skip over the rest
+                     # note: the model actually learns to predict <im_patch>, not <image>
+                     sequence_logits.append(logits[i, logits_j])
+                     logits_j += self.num_tokens_per_vis
+             sequence_logits = torch.stack(sequence_logits, dim=0)  # (T_txt, vocab_size)
+             batch_logits.append(sequence_logits)
+
+         batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
+         # The final logits shape should be the same as the original input_ids shape
+         assert batch_logits.shape[:2] == (B, T_txt)
+
+         # assemble the output
+         output = VLMOutputWithPast(
+             loss=output.loss,
+             logits=batch_logits,
+             past_key_values=output.past_key_values,
+             hidden_states=output.hidden_states,
+             attentions=output.attentions,
+             past_media_locations=updated_media_locations,
+             past_vision_tokens=updated_vision_tokens,
+         )
+
+         return output
+
+     def _post_forward_hook(self):
+         pass
+
+
+     @property
+     def num_params_per_module(self):
+         """Return a string listing the number of parameters per module in the model"""
+         return "\n".join(
+             [
+                 f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
+                 f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
+                 f"Language model: {num_params(self.lang_model):,} parameters",
+             ]
+         )
+
+     @property
+     def num_trainable_params_per_module(self):
+         """Return a string listing the number of trainable parameters per module in the model"""
+         return "\n".join(
+             [
+                 f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
+                 f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
+                 f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
+             ]
+         )
+
+
+ class Kosmos(VLMWithLanguageStream):
+     def __init__(
+         self,
+         vision_encoder: nn.Module,
+         vision_tokenizer: nn.Module,
+         lang_model: nn.Module,
+         initial_tokenizer_len: int,
+         pad_token_id: int,
+         decoder_layers_attr_name: str = None,
+         gradient_checkpointing: bool = False,
+     ):
+         """
+         Args:
+             vision_encoder (nn.Module): HF CLIPModel
+             vision_tokenizer (nn.Module): module that maps vision features to vision tokens
+             lang_model (nn.Module): HF causal language model
+             initial_tokenizer_len (int): size of the tokenizer vocab
+             pad_token_id (int): id of the padding token. None if no padding token; then a padding token
+                 will be inserted into self.special_tokens, which factory.py fills after creating new tokens
+             decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+             gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+         """
+         self._special_tokens = {
+             "media_token": "<image>",
+             "image_placeholder_token": "<image placeholder>",
+             "end_of_trunk_token": "<|endofchunk|>",
+         }
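+         # media_token marks where vision tokens are spliced into the language stream;
+         # end_of_trunk_token ("<|endofchunk|>") is used as the EOS token during
+         # generation (see generate below).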
+         super().__init__(
+             vision_encoder=vision_encoder,
+             vision_tokenizer=vision_tokenizer,
+             lang_model=lang_model,
+             initial_tokenizer_len=initial_tokenizer_len,
+             gradient_checkpointing=gradient_checkpointing,
+             decoder_layers_attr_name=decoder_layers_attr_name,
+             pad_token_id=pad_token_id,
+         )
+
+     # def set_trainable(self):
+     #     """
+     #     Unfreeze everything except the vision_encoder
+     #     """
+     #     self.requires_grad_(True)
+     #     self.vision_encoder.requires_grad_(False)
+
+     def set_trainable(self, unfreeze_vision_encoder: bool = False):
+         """
+         Unfreeze everything; the vision_encoder is unfrozen only if
+         unfreeze_vision_encoder is True.
+         """
+         self.requires_grad_(True)
+         self.vision_encoder.requires_grad_(unfreeze_vision_encoder)
+
+     def _should_apply_weight_decay(self, parameter_name):
+         """
+         Kosmos applies 0.01 weight decay to everything
+         """
+         return True
+
+     def generate(
+         self,
+         vision_x: torch.Tensor,
+         lang_x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         past_key_values: Optional[
+             List[Union[torch.Tensor, Tuple[torch.Tensor]]]
+         ] = None,
+         past_media_locations: Optional[torch.Tensor] = None,
+         past_vision_tokens: Optional[torch.Tensor] = None,
+         **kwargs
+     ):
+         """
+         Generate text conditioned on vision and language inputs.
+         Args:
+             vision_x (torch.Tensor): Vision input
+                 shape (B, T_img, F, C, H, W)
+                 see documentation for forward
+             lang_x (torch.Tensor): Language input
+                 shape (B, T_txt)
+             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+             **kwargs: see generate documentation in Hugging Face CausalLM models.
+         Returns:
+             torch.Tensor: lang_x with generated tokens appended to it
+         """
+         num_beams = kwargs.pop("num_beams", 1)
+
+         # convert pixels to vision tokens
+         if vision_x is not None:
+             vision_features = self._encode_vision_x(vision_x=vision_x)
+             vision_tokens = self.vision_tokenizer(vision_features)
+         else:
+             vision_tokens = None
+
+         # fuse the vision and language tokens
+         # for xattn, vision_x and media_location are repeat_interleaved s.t.
+         # the total batch size is B * num_beams
+         new_inputs = self._prepare_inputs_for_forward(
+             vision_tokens=vision_tokens,
+             lang_x=lang_x,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             past_media_locations=past_media_locations,
+             past_vision_tokens=past_vision_tokens,
+             padding_side="left",
+             num_beams=num_beams,
+         )
+
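+         # On transformers 4.41.0.dev0, past_key_values is not forwarded explicitly to
+         # lang_model.generate (presumably to accommodate cache-handling changes in that
+         # development version); older versions receive it as an argument, as before.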
+         if transformers.__version__ == '4.41.0.dev0':
+             output = self.lang_model.generate(
+                 **new_inputs,
+                 num_beams=num_beams,
+                 use_cache=True,
+                 eos_token_id=self.end_of_trunk_token_id,
+                 **kwargs)
+         else:
+             output = self.lang_model.generate(
+                 **new_inputs,
+                 past_key_values=past_key_values,
+                 num_beams=num_beams,
+                 use_cache=True,
+                 eos_token_id=self.end_of_trunk_token_id,
+                 **kwargs)
+         return output