ECOFRI committed on
Commit 14c93b0 (1 parent: 88475f0)

Upload model

CXR_LLAVA_HF.py ADDED
@@ -0,0 +1,642 @@
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ import torch, transformers
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.modeling_outputs import CausalLMOutputWithPast
5
+ from .VisualTransformer import VisionTransformer, LayerNorm
6
+ from functools import partial
7
+ from transformers import TextIteratorStreamer
8
+ from transformers import StoppingCriteria, GenerationConfig
9
+ from threading import Thread
10
+
11
+ # Model Constants
12
+ IGNORE_INDEX = -100
13
+ IMAGE_TOKEN_INDEX = -200
14
+ DEFAULT_IMAGE_TOKEN = "<image>"
15
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
16
+ DEFAULT_IM_START_TOKEN = "<im_start>"
17
+ DEFAULT_IM_END_TOKEN = "<im_end>"
18
+ class AttrDict(dict):
19
+ def __init__(self, *args, **kwargs):
20
+ super(AttrDict, self).__init__(*args, **kwargs)
21
+ self.__dict__ = self
22
+
23
+ class CXRLLAVAConfig(PretrainedConfig):
24
+ model_type = "CXR-LLAVA"
25
+
26
+ def __init__(self, **kwargs,):
27
+
28
+ if 'llama' in kwargs:
29
+ self.llama = AttrDict(kwargs['llama'])
30
+ del kwargs['llama']
31
+
32
+ self.__dict__.update(kwargs)
33
+ super().__init__(**kwargs)
34
+
35
+
36
+ class CXRLLAVAModel(PreTrainedModel):
37
+ config_class = CXRLLAVAConfig
38
+
39
+ def __init__(self, config):
40
+ super().__init__(config)
41
+
42
+ self.tokenizer = transformers.LlamaTokenizer.from_pretrained(config._name_or_path, add_special_tokens=False)
43
+ self.tokenizer.pad_token = self.tokenizer.unk_token
44
+ self.tokenizer.sep_token = self.tokenizer.unk_token
45
+ self.tokenizer.cls_token = self.tokenizer.unk_token
46
+ self.tokenizer.mask_token = self.tokenizer.unk_token
47
+
48
+ from open_clip.model import CLIPVisionCfg
49
+ vision_cfg = CLIPVisionCfg(**config.clip_vision_cfg)
50
+
51
+ self.generation_config = GenerationConfig.from_pretrained(config._name_or_path)
52
+
53
+ vision_heads = vision_cfg.width // vision_cfg.head_width
54
+ norm_layer = LayerNorm
55
+ act_layer = torch.nn.GELU
56
+ if vision_cfg.norm_kwargs:
57
+ norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
58
+ if vision_cfg.act_kwargs is not None:
59
+ act_layer = partial(act_layer, **vision_cfg.act_kwargs)
60
+
61
+ self.vision_tower = VisionTransformer(
62
+ in_channels=1,
63
+ image_size=vision_cfg.image_size,
64
+ patch_size=vision_cfg.patch_size,
65
+ width=vision_cfg.width,
66
+ layers=vision_cfg.layers,
67
+ heads=vision_heads,
68
+ mlp_ratio=vision_cfg.mlp_ratio,
69
+ ls_init_value=vision_cfg.ls_init_value,
70
+ patch_dropout=vision_cfg.patch_dropout,
71
+ attentional_pool=vision_cfg.attentional_pool,
72
+ attn_pooler_queries=vision_cfg.attn_pooler_queries,
73
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
74
+ pos_embed_type=vision_cfg.pos_embed_type,
75
+ no_ln_pre=vision_cfg.no_ln_pre,
76
+ final_ln_after_pool=vision_cfg.final_ln_after_pool,
77
+ pool_type=vision_cfg.pool_type,
78
+ output_tokens=vision_cfg.output_tokens,
79
+ output_dim=config.clip_embed_dim,
80
+ act_layer=act_layer,
81
+ norm_layer=norm_layer,
82
+ )
83
+
84
+ self.vision_tower.image_processor = transformers.CLIPImageProcessor(
85
+ do_resize=True,
86
+ size={'shortest_edge': config.clip_vision_cfg['image_size']},
87
+ resample=True,
88
+ do_center_crop=True,
89
+ crop_size=config.clip_vision_cfg['image_size'],
90
+ do_rescale=True,
91
+ rescale_factor=1 / 255,
92
+ do_normalize=True,
93
+ image_mean=config.image_preprocess_cfg['mean'],
94
+ image_std=config.image_preprocess_cfg['std'],
95
+ do_convert_rgb=False
96
+ )
97
+
98
+ def convert_dtype(dtype):
99
+ if dtype == 'fp32':
100
+ dtype = torch.float32
101
+ elif dtype == 'fp16':
102
+ dtype = torch.float16
103
+ elif dtype == 'bf16':
104
+ dtype = torch.bfloat16
105
+ else:
106
+ raise Exception("Unsupported dtype")
107
+ return dtype
108
+
109
+ self.clip_cast_dtype = convert_dtype(config.clip_vision_tower_dtype)
110
+ self.mm_projector = torch.nn.Linear(config.mm_projector_dim, config.llama['hidden_size'])
111
+ self.lm_head = torch.nn.Linear(config.llama.hidden_size, config.llama.vocab_size, bias=False)
112
+ self.llama = transformers.LlamaModel(transformers.LlamaConfig(**config.llama))
113
+
114
+ self.llama = self.llama.to(torch.bfloat16)
115
+ self.lm_head = self.lm_head.to(torch.bfloat16)
116
+ self.vision_tower = self.vision_tower.to(torch.bfloat16)
117
+ self.mm_projector = self.mm_projector.to(torch.bfloat16)
118
+
119
+ def get_input_embeddings(self):
120
+ return self.llama.get_input_embeddings()
121
+
122
+ def get_vision_tower(self):
123
+ return self.vision_tower
124
+
125
+ def gradient_checkpointing_enable(self):
126
+ return self.llama.gradient_checkpointing_enable()
127
+
128
+ def encode_images(self, images):
129
+ images = images.to(torch.bfloat16)
130
+
131
+ def _expand_token(token, batch_size: int):
132
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
133
+
134
+ # open_clip ViT
135
+ # https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
136
+ x = images
137
+ x = self.vision_tower.conv1(x) # shape = [*, width, grid, grid]
138
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
139
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
140
+
141
+ # class embeddings and positional embeddings
142
+ x = torch.cat([_expand_token(self.vision_tower.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
143
+ # shape = [*, grid ** 2 + 1, width]
144
+ x = x + self.vision_tower.positional_embedding.to(x.dtype)
145
+
146
+ x = self.vision_tower.patch_dropout(x)
147
+ x = self.vision_tower.ln_pre(x)
148
+
149
+ x = x.permute(1, 0, 2) # NLD -> LND
150
+ x = self.vision_tower.transformer(x)
151
+ x = x.permute(1, 0, 2) # LND -> NLD
152
+
153
+ if self.vision_tower.attn_pool is not None:
154
+ if self.vision_tower.attn_pool_contrastive is not None:
155
+ # This is untested, WIP pooling that should match paper
156
+ x = self.vision_tower.ln_post(x) # TBD LN first or separate one after each pool?
157
+ tokens = self.vision_tower.attn_pool(x)
158
+ if self.vision_tower.attn_pool_type == 'parallel':
159
+ pooled = self.vision_tower.attn_pool_contrastive(x)
160
+ else:
161
+ assert self.vision_tower.attn_pool_type == 'cascade'
162
+ pooled = self.vision_tower.attn_pool_contrastive(tokens)
163
+ else:
164
+ # this is the original OpenCLIP CoCa setup, does not match paper
165
+ x = self.vision_tower.attn_pool(x)
166
+ x = self.vision_tower.ln_post(x)
167
+ pooled, tokens = self.vision_tower._global_pool(x)
168
+ elif self.vision_tower.final_ln_after_pool:
169
+ pooled, tokens = self.vision_tower._global_pool(x)
170
+ pooled = self.vision_tower.ln_post(pooled)
171
+ else:
172
+ x = self.vision_tower.ln_post(x)
173
+ pooled, tokens = self.vision_tower._global_pool(x)
174
+
175
+ if self.vision_tower.proj is not None:
176
+ pooled = pooled @ self.vision_tower.proj
177
+
178
+ image_features = tokens
179
+ image_features = image_features.to(torch.bfloat16)
180
+ image_features = self.mm_projector(image_features)
181
+
182
+ image_features = image_features.to(torch.bfloat16)
183
+ return image_features
184
+
185
+ def forward(
186
+ self,
187
+ input_ids: torch.LongTensor = None,
188
+ attention_mask: Optional[torch.Tensor] = None,
189
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
190
+ inputs_embeds: Optional[torch.FloatTensor] = None,
191
+ labels: Optional[torch.LongTensor] = None, # (1,4317)
192
+ use_cache: Optional[bool] = None,
193
+ output_attentions: Optional[bool] = None,
194
+ output_hidden_states: Optional[bool] = None,
195
+ images: Optional[torch.FloatTensor] = None,
196
+ return_dict: Optional[bool] = None,
197
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = (
200
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
201
+ )
202
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
203
+
204
+
205
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
206
+ input_ids, attention_mask, past_key_values, labels, images)
207
+
208
+ outputs = self.llama(
209
+ input_ids=input_ids,
210
+ attention_mask=attention_mask,
211
+ past_key_values=past_key_values,
212
+ inputs_embeds=inputs_embeds,
213
+ use_cache=use_cache,
214
+ output_attentions=output_attentions,
215
+ output_hidden_states=output_hidden_states,
216
+ return_dict=return_dict
217
+ )
218
+
219
+ hidden_states = outputs[0]
220
+ logits = self.lm_head(hidden_states)
221
+
222
+ loss = None
223
+
224
+ return CausalLMOutputWithPast(
225
+ loss=loss,
226
+ logits=logits,
227
+ past_key_values=outputs.past_key_values,
228
+ hidden_states=outputs.hidden_states,
229
+ attentions=outputs.attentions,
230
+ )
231
+
232
+ # original multimodal code
233
+ def prepare_inputs_labels_for_multimodal(
234
+ self, input_ids, attention_mask, past_key_values, labels, images
235
+ ):
236
+ vision_tower = self.vision_tower
237
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
238
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
239
+ 1] == 1:
240
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
241
+ dtype=attention_mask.dtype, device=attention_mask.device)
242
+ return input_ids, attention_mask, past_key_values, None, labels
243
+
244
+ if type(images) is list or images.ndim == 5:
245
+ concat_images = torch.cat([image for image in images], dim=0)
246
+ image_features = self.encode_images(concat_images)
247
+ split_sizes = [image.shape[0] for image in images]
248
+ image_features = torch.split(image_features, split_sizes, dim=0)
249
+ image_features = [x.flatten(0, 1) for x in image_features]
250
+ else:
251
+ image_features = self.encode_images(images)
252
+
253
+ new_input_embeds = []
254
+ new_labels = [] if labels is not None else None
255
+ cur_image_idx = 0
256
+ for batch_idx, cur_input_ids in enumerate(input_ids):
257
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
258
+ # multimodal LLM, but the current sample is not multimodal
259
+ cur_input_embeds = self.llama.embed_tokens(cur_input_ids)
260
+ cur_input_embeds = cur_input_embeds + (0. * self.mm_projector(vision_tower.dummy_feature)).sum()
261
+ new_input_embeds.append(cur_input_embeds)
262
+ if labels is not None:
263
+ new_labels.append(labels[batch_idx])
264
+ cur_image_idx += 1
265
+ continue
266
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
267
+ cur_new_input_embeds = []
268
+ if labels is not None:
269
+ cur_labels = labels[batch_idx]
270
+ cur_new_labels = []
271
+ assert cur_labels.shape == cur_input_ids.shape
272
+ while image_token_indices.numel() > 0:
273
+ cur_image_features = image_features[cur_image_idx]
274
+ image_token_start = image_token_indices[0]
275
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
276
+ False):
277
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start - 1]).detach())
278
+ cur_new_input_embeds.append(
279
+ self.llama.embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
280
+ cur_new_input_embeds.append(cur_image_features)
281
+ cur_new_input_embeds.append(
282
+ self.llama.embed_tokens(cur_input_ids[image_token_start + 1:image_token_start + 2]))
283
+ if labels is not None:
284
+ cur_new_labels.append(cur_labels[:image_token_start])
285
+ cur_new_labels.append(
286
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device,
287
+ dtype=labels.dtype))
288
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
289
+ cur_labels = cur_labels[image_token_start + 2:]
290
+ else:
291
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start]))
292
+ cur_new_input_embeds.append(cur_image_features)
293
+ if labels is not None:
294
+ cur_new_labels.append(cur_labels[:image_token_start])
295
+ cur_new_labels.append(
296
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device,
297
+ dtype=labels.dtype))
298
+ cur_labels = cur_labels[image_token_start + 1:]
299
+ cur_image_idx += 1
300
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
301
+ False):
302
+ cur_input_ids = cur_input_ids[image_token_start + 2:]
303
+ else:
304
+ cur_input_ids = cur_input_ids[image_token_start + 1:]
305
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
306
+ if cur_input_ids.numel() > 0:
307
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
308
+ False):
309
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids).detach())
310
+ else:
311
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids))
312
+ if labels is not None:
313
+ cur_new_labels.append(cur_labels)
314
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
315
+
316
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
317
+ new_input_embeds.append(cur_new_input_embeds)
318
+ if labels is not None:
319
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
320
+ new_labels.append(cur_new_labels)
321
+
322
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
323
+ max_len = max(x.shape[0] for x in new_input_embeds)
324
+
325
+ new_input_embeds_align = []
326
+ for cur_new_embed in new_input_embeds:
327
+ cur_new_embed = torch.cat((cur_new_embed,
328
+ torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
329
+ dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
330
+ new_input_embeds_align.append(cur_new_embed)
331
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
332
+
333
+ if labels is not None:
334
+ new_labels_align = []
335
+ _new_labels = new_labels
336
+ for cur_new_label in new_labels:
337
+ cur_new_label = torch.cat((cur_new_label,
338
+ torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
339
+ dtype=cur_new_label.dtype, device=cur_new_label.device)),
340
+ dim=0)
341
+ new_labels_align.append(cur_new_label)
342
+ new_labels = torch.stack(new_labels_align, dim=0)
343
+
344
+ if attention_mask is not None:
345
+ new_attention_mask = []
346
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels,
347
+ new_labels):
348
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True,
349
+ dtype=attention_mask.dtype, device=attention_mask.device)
350
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
351
+ False, dtype=attention_mask.dtype,
352
+ device=attention_mask.device)
353
+ cur_new_attention_mask = torch.cat(
354
+ (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
355
+ new_attention_mask.append(cur_new_attention_mask)
356
+ attention_mask = torch.stack(new_attention_mask, dim=0)
357
+ assert attention_mask.shape == new_labels.shape
358
+ else:
359
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
360
+ if labels is not None:
361
+ new_labels = torch.stack(new_labels, dim=0)
362
+
363
+ if attention_mask is not None:
364
+ new_attn_mask_pad_left = torch.full(
365
+ (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
366
+ dtype=attention_mask.dtype, device=attention_mask.device)
367
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
368
+ assert attention_mask.shape == new_input_embeds.shape[:2]
369
+
370
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
371
+
372
+ # sw-modified code
373
+
374
+ def prepare_inputs_labels_for_multimodal_use_final_vector(
375
+ self, input_ids, attention_mask, past_key_values, labels, images
376
+ ):
377
+ vision_tower = self.vision_tower
378
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
379
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
380
+ 1] == 1:
381
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
382
+ dtype=attention_mask.dtype, device=attention_mask.device)
383
+ return input_ids, attention_mask, past_key_values, None, labels
384
+
385
+ if type(images) is list or images.ndim == 5:
386
+ concat_images = torch.cat([image for image in images], dim=0)
387
+ image_features = self.encode_images(concat_images)
388
+ split_sizes = [image.shape[0] for image in images]
389
+ image_features = torch.split(image_features, split_sizes, dim=0)
390
+ image_features = [x.flatten(0, 1) for x in image_features]
391
+ else:
392
+ image_features = self.encode_images(images)
393
+
394
+ new_input_embeds = []
395
+ new_labels = [] if labels is not None else None
396
+ cur_image_idx = 0
397
+ for batch_idx, cur_input_ids in enumerate(input_ids):
398
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
399
+ # multimodal LLM, but the current sample is not multimodal
400
+ cur_input_embeds = self.llama.embed_tokens(cur_input_ids)
401
+ cur_input_embeds = cur_input_embeds + (0. * self.mm_projector(vision_tower.dummy_feature)).sum()
402
+ new_input_embeds.append(cur_input_embeds)
403
+ if labels is not None:
404
+ new_labels.append(labels[batch_idx])
405
+ cur_image_idx += 1
406
+ continue
407
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
408
+ cur_new_input_embeds = []
409
+ if labels is not None:
410
+ cur_labels = labels[batch_idx]
411
+ cur_new_labels = []
412
+ assert cur_labels.shape == cur_input_ids.shape
413
+ while image_token_indices.numel() > 0:
414
+ cur_image_features = image_features[cur_image_idx]
415
+ image_token_start = image_token_indices[0]
416
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
417
+ False):
418
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start - 1]).detach())
419
+ cur_new_input_embeds.append(
420
+ self.llama.embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
421
+ cur_new_input_embeds.append(cur_image_features)
422
+ cur_new_input_embeds.append(
423
+ self.llama.embed_tokens(cur_input_ids[image_token_start + 1:image_token_start + 2]))
424
+ if labels is not None:
425
+ cur_new_labels.append(cur_labels[:image_token_start])
426
+ cur_new_labels.append(
427
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device,
428
+ dtype=labels.dtype))
429
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
430
+ cur_labels = cur_labels[image_token_start + 2:]
431
+ else:
432
+ cur_new_input_embeds.append(
433
+ self.llama.embed_tokens(cur_input_ids[:image_token_start].to(self.device)))
434
+ cur_new_input_embeds.append(cur_image_features)
435
+ if labels is not None:
436
+ cur_new_labels.append(cur_labels[:image_token_start])
437
+ cur_new_labels.append(
438
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device,
439
+ dtype=labels.dtype))
440
+ cur_labels = cur_labels[image_token_start + 1:]
441
+ cur_image_idx += 1
442
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
443
+ False):
444
+ cur_input_ids = cur_input_ids[image_token_start + 2:]
445
+ else:
446
+ cur_input_ids = cur_input_ids[image_token_start + 1:]
447
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
448
+ if cur_input_ids.numel() > 0:
449
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
450
+ False):
451
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids).detach())
452
+ else:
453
+ cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids.to(self.device)))
454
+ if labels is not None:
455
+ # seowoo-edit
456
+ cur_labels = labels[batch_idx]
457
+ cur_new_labels.append(cur_labels)
458
+ # [5120] -> [1, 5120]
459
+ cur_new_input_embeds[1] = torch.unsqueeze(cur_new_input_embeds[1], dim=0)
460
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
461
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
462
+ new_input_embeds.append(cur_new_input_embeds)
463
+ if labels is not None:
464
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
465
+ new_labels.append(cur_new_labels)
466
+
467
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
468
+ # print("if 204")
469
+ max_len = max(x.shape[0] for x in new_input_embeds)
470
+
471
+ new_input_embeds_align = []
472
+ for cur_new_embed in new_input_embeds:
473
+ cur_new_embed = torch.cat((cur_new_embed,
474
+ torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
475
+ dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
476
+ new_input_embeds_align.append(cur_new_embed)
477
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
478
+
479
+ if labels is not None:
480
+ new_labels_align = []
481
+ _new_labels = new_labels
482
+ for cur_new_label in new_labels:
483
+ cur_new_label = torch.cat((cur_new_label,
484
+ torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
485
+ dtype=cur_new_label.dtype, device=cur_new_label.device)),
486
+ dim=0)
487
+ new_labels_align.append(cur_new_label)
488
+ new_labels = torch.stack(new_labels_align, dim=0)
489
+
490
+ if attention_mask is not None:
491
+ new_attention_mask = []
492
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels,
493
+ new_labels):
494
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True,
495
+ dtype=attention_mask.dtype, device=attention_mask.device)
496
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
497
+ False, dtype=attention_mask.dtype,
498
+ device=attention_mask.device)
499
+ cur_new_attention_mask = torch.cat(
500
+ (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
501
+ new_attention_mask.append(cur_new_attention_mask)
502
+ attention_mask = torch.stack(new_attention_mask, dim=0)
503
+ assert attention_mask.shape == new_labels.shape
504
+ else:
505
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
506
+ if labels is not None:
507
+ new_labels = torch.stack(new_labels, dim=0)
508
+
509
+ if attention_mask is not None:
510
+ new_attn_mask_pad_left = torch.full(
511
+ (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
512
+ dtype=attention_mask.dtype, device=attention_mask.device)
513
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
514
+ assert attention_mask.shape == new_input_embeds.shape[:2]
515
+
516
+ return None, attention_mask, past_key_values, new_input_embeds, labels
517
+
518
+ def prepare_inputs_for_generation(
519
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
520
+ ):
521
+ if past_key_values:
522
+ input_ids = input_ids[:, -1:]
523
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
524
+ if inputs_embeds is not None and past_key_values is None:
525
+ model_inputs = {"inputs_embeds": inputs_embeds}
526
+ else:
527
+ model_inputs = {"input_ids": input_ids}
528
+ model_inputs.update(
529
+ {
530
+ "past_key_values": past_key_values,
531
+ "use_cache": kwargs.get("use_cache"),
532
+ "attention_mask": attention_mask,
533
+ "images": kwargs.get("images", None),
534
+ }
535
+ )
536
+ return model_inputs
537
+
538
+ def apply_chat_template(self, chat):
539
+ return self.tokenizer.apply_chat_template(chat, tokenize=False)
540
+
541
+ def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
542
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
543
+
544
+ def insert_separator(X, sep):
545
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
546
+
547
+ input_ids = []
548
+ offset = 0
549
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
550
+ offset = 1
551
+ input_ids.append(prompt_chunks[0][0])
552
+
553
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
554
+ input_ids.extend(x[offset:])
555
+
556
+ if return_tensors is not None:
557
+ if return_tensors == 'pt':
558
+ return torch.tensor(input_ids, dtype=torch.long)
559
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
560
+ return input_ids
561
+
562
+ def generate_cxr_repsonse(self, chat, pil_image, temperature=0.2, top_p=0.8):
563
+ with torch.no_grad():
564
+ streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
565
+ import numpy as np
566
+ pil_image = np.expand_dims(pil_image,axis=-1)
567
+ prompt = self.apply_chat_template(chat)
568
+ images = self.vision_tower.image_processor(pil_image, return_tensors='pt')['pixel_values']
569
+ images = images.to(self.device)
570
+ input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
571
+ stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
572
+
573
+ image_args = {"images": images}
574
+ do_sample = True if temperature > 0.001 else False
575
+ num_image_tokens = 1
576
+ max_context_length = getattr(self.config, 'max_position_embeddings', 2048)
577
+
578
+ max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
579
+
580
+ thread = Thread(target=self.generate, kwargs=dict(
581
+ inputs=input_ids,
582
+ do_sample=do_sample,
583
+ temperature=temperature,
584
+ top_p=top_p,
585
+ max_new_tokens=max_new_tokens,
586
+ streamer=streamer,
587
+ stopping_criteria=[stopping_criteria],
588
+ use_cache=True,
589
+ generation_config=self.generation_config,
590
+ **image_args
591
+ ))
592
+ thread.start()
593
+ generated_text = ""
594
+ for new_text in streamer:
595
+ generated_text += new_text
596
+
597
+ return generated_text
598
+
599
+ def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
600
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
601
+
602
+ def insert_separator(X, sep):
603
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
604
+
605
+ input_ids = []
606
+ offset = 0
607
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
608
+ offset = 1
609
+ input_ids.append(prompt_chunks[0][0])
610
+
611
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
612
+ input_ids.extend(x[offset:])
613
+
614
+ if return_tensors is not None:
615
+ if return_tensors == 'pt':
616
+ return torch.tensor(input_ids, dtype=torch.long)
617
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
618
+ return input_ids
619
+ class KeywordsStoppingCriteria(StoppingCriteria):
620
+ def __init__(self, keywords, tokenizer, input_ids):
621
+ self.keywords = keywords
622
+ self.keyword_ids = []
623
+ for keyword in keywords:
624
+ cur_keyword_ids = tokenizer(keyword).input_ids
625
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
626
+ cur_keyword_ids = cur_keyword_ids[1:]
627
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
628
+ self.tokenizer = tokenizer
629
+ self.start_len = input_ids.shape[1]
630
+
631
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
632
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
633
+ offset = min(output_ids.shape[1] - self.start_len, 3)
634
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
635
+ for keyword_id in self.keyword_ids:
636
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
637
+ return True
638
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
639
+ for keyword in self.keywords:
640
+ if keyword in outputs:
641
+ return True
642
+ return False
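
Usage note (not part of the uploaded file): a minimal sketch of how the class above is intended to be driven, assuming the repository's config maps AutoModel to CXRLLAVAModel via trust_remote_code. The repository id below is a placeholder, the chat content is illustrative, and the method name generate_cxr_repsonse is spelled exactly as defined above; that method moves the prompt ids to CUDA internally, so a GPU is assumed.

# Hedged usage sketch: illustrative only, not part of CXR_LLAVA_HF.py.
import torch
from PIL import Image
from transformers import AutoModel

repo_id = "ECOFRI/CXR-LLAVA-v2"  # placeholder: substitute the actual Hub repository id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model = model.to("cuda")  # generate_cxr_repsonse calls .cuda() on the prompt ids internally
model.eval()

# The CLIPImageProcessor configured above expects single-channel (grayscale) input.
cxr_image = Image.open("chest_xray.png").convert("L")

chat = [
    {"role": "system", "content": "You are a helpful radiology assistant."},
    {"role": "user", "content": "<image>\nWrite a radiologic report of this chest X-ray."},
]

# Method name matches the definition above (including its spelling).
report = model.generate_cxr_repsonse(chat=chat, pil_image=cxr_image, temperature=0.2, top_p=0.8)
print(report)
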
VisualTransformer.py ADDED
@@ -0,0 +1,917 @@
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, Optional, Sequence, Tuple
4
+ from functools import partial
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from itertools import repeat
12
+ import collections.abc
13
+
14
+ # From PyTorch internals
15
+ def _ntuple(n):
16
+ def parse(x):
17
+ if isinstance(x, collections.abc.Iterable):
18
+ return x
19
+ return tuple(repeat(x, n))
20
+ return parse
21
+
22
+ to_1tuple = _ntuple(1)
23
+ to_2tuple = _ntuple(2)
24
+ to_3tuple = _ntuple(3)
25
+ to_4tuple = _ntuple(4)
26
+ to_ntuple = lambda n, x: _ntuple(n)(x)
27
+
28
+ class LayerNormFp32(nn.LayerNorm):
29
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
30
+
31
+ def forward(self, x: torch.Tensor):
32
+ orig_type = x.dtype
33
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
34
+
35
+ #x = F.layer_norm(x.to(torch.bfloat16), self.normalized_shape, self.weight, self.bias, self.eps)
36
+ return x.to(orig_type)
37
+
38
+
39
+ class LayerNorm(nn.LayerNorm):
40
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ orig_type = x.dtype
44
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
45
+ return x.to(orig_type)
46
+
47
+
48
+ class QuickGELU(nn.Module):
49
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
50
+ def forward(self, x: torch.Tensor):
51
+ return x * torch.sigmoid(1.702 * x)
52
+
53
+
54
+ class LayerScale(nn.Module):
55
+ def __init__(self, dim, init_values=1e-5, inplace=False):
56
+ super().__init__()
57
+ self.inplace = inplace
58
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
59
+
60
+ def forward(self, x):
61
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
62
+
63
+
64
+ class PatchDropout(nn.Module):
65
+ """
66
+ https://arxiv.org/abs/2212.00794
67
+ """
68
+
69
+ def __init__(self, prob, exclude_first_token=True):
70
+ super().__init__()
71
+ assert 0 <= prob < 1.
72
+ self.prob = prob
73
+ self.exclude_first_token = exclude_first_token # exclude CLS token
74
+
75
+ def forward(self, x):
76
+ if not self.training or self.prob == 0.:
77
+ return x
78
+
79
+ if self.exclude_first_token:
80
+ cls_tokens, x = x[:, :1], x[:, 1:]
81
+ else:
82
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
83
+
84
+ batch = x.size()[0]
85
+ num_tokens = x.size()[1]
86
+
87
+ batch_indices = torch.arange(batch)
88
+ batch_indices = batch_indices[..., None]
89
+
90
+ keep_prob = 1 - self.prob
91
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
92
+
93
+ rand = torch.randn(batch, num_tokens)
94
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
95
+
96
+ x = x[batch_indices, patch_indices_keep]
97
+
98
+ if self.exclude_first_token:
99
+ x = torch.cat((cls_tokens, x), dim=1)
100
+
101
+ return x
102
+
103
+
104
+ class Attention(nn.Module):
105
+ def __init__(
106
+ self,
107
+ dim,
108
+ num_heads=8,
109
+ qkv_bias=True,
110
+ scaled_cosine=False,
111
+ scale_heads=False,
112
+ logit_scale_max=math.log(1. / 0.01),
113
+ attn_drop=0.,
114
+ proj_drop=0.
115
+ ):
116
+ super().__init__()
117
+ self.scaled_cosine = scaled_cosine
118
+ self.scale_heads = scale_heads
119
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
120
+ self.num_heads = num_heads
121
+ self.head_dim = dim // num_heads
122
+ self.scale = self.head_dim ** -0.5
123
+ self.logit_scale_max = logit_scale_max
124
+
125
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
126
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
127
+ if qkv_bias:
128
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
129
+ else:
130
+ self.in_proj_bias = None
131
+
132
+ if self.scaled_cosine:
133
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
134
+ else:
135
+ self.logit_scale = None
136
+ self.attn_drop = nn.Dropout(attn_drop)
137
+ if self.scale_heads:
138
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
139
+ else:
140
+ self.head_scale = None
141
+ self.out_proj = nn.Linear(dim, dim)
142
+ self.out_drop = nn.Dropout(proj_drop)
143
+
144
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
145
+ L, N, C = x.shape
146
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
147
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
148
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
149
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
150
+
151
+ if self.logit_scale is not None:
152
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
153
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
154
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
155
+ attn = attn.view(-1, L, L)
156
+ else:
157
+ q = q * self.scale
158
+ attn = torch.bmm(q, k.transpose(-1, -2))
159
+
160
+ if attn_mask is not None:
161
+ if attn_mask.dtype == torch.bool:
162
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
163
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
164
+ attn_mask = new_attn_mask
165
+ attn += attn_mask
166
+
167
+ attn = attn.softmax(dim=-1)
168
+ attn = self.attn_drop(attn)
169
+
170
+ x = torch.bmm(attn, v)
171
+ if self.head_scale is not None:
172
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
173
+ x = x.view(-1, L, C)
174
+ x = x.transpose(0, 1).reshape(L, N, C)
175
+ x = self.out_proj(x)
176
+ x = self.out_drop(x)
177
+ return x
178
+
179
+
180
+ class AttentionalPooler(nn.Module):
181
+ def __init__(
182
+ self,
183
+ d_model: int,
184
+ context_dim: int,
185
+ n_head: int = 8,
186
+ n_queries: int = 256,
187
+ norm_layer: Callable = LayerNorm
188
+ ):
189
+ super().__init__()
190
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
191
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
192
+ self.ln_q = norm_layer(d_model)
193
+ self.ln_k = norm_layer(context_dim)
194
+
195
+ def forward(self, x: torch.Tensor):
196
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
197
+ N = x.shape[1]
198
+ q = self.ln_q(self.query)
199
+ out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
200
+ return out.permute(1, 0, 2) # LND -> NLD
201
+
202
+
203
+ class ResidualAttentionBlock(nn.Module):
204
+ def __init__(
205
+ self,
206
+ d_model: int,
207
+ n_head: int,
208
+ mlp_ratio: float = 4.0,
209
+ ls_init_value: float = None,
210
+ act_layer: Callable = nn.GELU,
211
+ norm_layer: Callable = LayerNorm,
212
+ is_cross_attention: bool = False,
213
+ ):
214
+ super().__init__()
215
+
216
+ self.ln_1 = norm_layer(d_model)
217
+ self.attn = nn.MultiheadAttention(d_model, n_head)
218
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
219
+ if is_cross_attention:
220
+ self.ln_1_kv = norm_layer(d_model)
221
+
222
+ self.ln_2 = norm_layer(d_model)
223
+ mlp_width = int(d_model * mlp_ratio)
224
+ self.mlp = nn.Sequential(OrderedDict([
225
+ ("c_fc", nn.Linear(d_model, mlp_width)),
226
+ ("gelu", act_layer()),
227
+ ("c_proj", nn.Linear(mlp_width, d_model))
228
+ ]))
229
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
230
+
231
+ def attention(
232
+ self,
233
+ q_x: torch.Tensor,
234
+ k_x: Optional[torch.Tensor] = None,
235
+ v_x: Optional[torch.Tensor] = None,
236
+ attn_mask: Optional[torch.Tensor] = None,
237
+ ):
238
+ k_x = k_x if k_x is not None else q_x
239
+ v_x = v_x if v_x is not None else q_x
240
+
241
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
242
+ return self.attn(
243
+ q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
244
+ )[0]
245
+
246
+ def forward(
247
+ self,
248
+ q_x: torch.Tensor,
249
+ k_x: Optional[torch.Tensor] = None,
250
+ v_x: Optional[torch.Tensor] = None,
251
+ attn_mask: Optional[torch.Tensor] = None,
252
+ ):
253
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
254
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
255
+
256
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
257
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
258
+ return x
259
+
260
+
261
+ class CustomResidualAttentionBlock(nn.Module):
262
+ def __init__(
263
+ self,
264
+ d_model: int,
265
+ n_head: int,
266
+ mlp_ratio: float = 4.0,
267
+ ls_init_value: float = None,
268
+ act_layer: Callable = nn.GELU,
269
+ norm_layer: Callable = LayerNorm,
270
+ scale_cosine_attn: bool = False,
271
+ scale_heads: bool = False,
272
+ scale_attn: bool = False,
273
+ scale_fc: bool = False,
274
+ ):
275
+ super().__init__()
276
+
277
+ self.ln_1 = norm_layer(d_model)
278
+ self.attn = Attention(
279
+ d_model, n_head,
280
+ scaled_cosine=scale_cosine_attn,
281
+ scale_heads=scale_heads,
282
+ )
283
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
284
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
285
+
286
+ self.ln_2 = norm_layer(d_model)
287
+ mlp_width = int(d_model * mlp_ratio)
288
+ self.mlp = nn.Sequential(OrderedDict([
289
+ ("c_fc", nn.Linear(d_model, mlp_width)),
290
+ ("gelu", act_layer()),
291
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
292
+ ("c_proj", nn.Linear(mlp_width, d_model))
293
+ ]))
294
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
295
+
296
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
297
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
298
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
299
+ return x
300
+
301
+
302
+ def _expand_token(token, batch_size: int):
303
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
304
+
305
+
306
+ class Transformer(nn.Module):
307
+ def __init__(
308
+ self,
309
+ width: int,
310
+ layers: int,
311
+ heads: int,
312
+ mlp_ratio: float = 4.0,
313
+ ls_init_value: float = None,
314
+ act_layer: Callable = nn.GELU,
315
+ norm_layer: Callable = LayerNorm,
316
+ ):
317
+ super().__init__()
318
+ self.width = width
319
+ self.layers = layers
320
+ self.grad_checkpointing = False
321
+
322
+ self.resblocks = nn.ModuleList([
323
+ ResidualAttentionBlock(
324
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
325
+ for _ in range(layers)
326
+ ])
327
+
328
+ def get_cast_dtype(self) -> torch.dtype:
329
+ if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
330
+ return self.resblocks[0].mlp.c_fc.int8_original_dtype
331
+ return self.resblocks[0].mlp.c_fc.weight.dtype
332
+
333
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
334
+ for r in self.resblocks:
335
+ if self.grad_checkpointing and not torch.jit.is_scripting():
336
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
337
+ x = checkpoint(r, x, None, None, attn_mask)
338
+ else:
339
+ x = r(x, attn_mask=attn_mask)
340
+ return x
341
+
342
+
343
+ class VisionTransformer(nn.Module):
344
+ output_tokens: torch.jit.Final[bool]
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels:int,
349
+ image_size: int,
350
+ patch_size: int,
351
+ width: int,
352
+ layers: int,
353
+ heads: int,
354
+ mlp_ratio: float,
355
+ ls_init_value: float = None,
356
+ attentional_pool: bool = False,
357
+ attn_pooler_queries: int = 256,
358
+ attn_pooler_heads: int = 8,
359
+ output_dim: int = 512,
360
+ patch_dropout: float = 0.,
361
+ no_ln_pre: bool = False,
362
+ pos_embed_type: str = 'learnable',
363
+ pool_type: str = 'tok',
364
+ final_ln_after_pool: bool = False,
365
+ act_layer: Callable = nn.GELU,
366
+ norm_layer: Callable = LayerNorm,
367
+ output_tokens: bool = False,
368
+ ):
369
+ super().__init__()
370
+ assert pool_type in ('tok', 'avg', 'none')
371
+ self.output_tokens = output_tokens
372
+ image_height, image_width = self.image_size = to_2tuple(image_size)
373
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
374
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
375
+ self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
376
+ self.output_dim = output_dim
377
+
378
+ self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
379
+
380
+ # class embeddings and positional embeddings
381
+ scale = width ** -0.5
382
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
383
+ if pos_embed_type == 'learnable':
384
+ self.positional_embedding = nn.Parameter(
385
+ scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
386
+ elif pos_embed_type == 'sin_cos_2d':
387
+ # fixed sin-cos embedding
388
+ assert self.grid_size[0] == self.grid_size[1], \
389
+ 'currently sin cos 2d pos embedding only supports square input'
390
+ self.positional_embedding = nn.Parameter(
391
+ torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
392
+ pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
393
+ self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())
394
+ else:
395
+ raise ValueError
396
+
397
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
398
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
399
+
400
+ self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
401
+ self.transformer = Transformer(
402
+ width,
403
+ layers,
404
+ heads,
405
+ mlp_ratio,
406
+ ls_init_value=ls_init_value,
407
+ act_layer=act_layer,
408
+ norm_layer=norm_layer,
409
+ )
410
+
411
+ if attentional_pool:
412
+ if isinstance(attentional_pool, str):
413
+ self.attn_pool_type = attentional_pool
414
+ self.pool_type = 'none'
415
+ if attentional_pool in ('parallel', 'cascade'):
416
+ self.attn_pool = AttentionalPooler(
417
+ output_dim,
418
+ width,
419
+ n_head=attn_pooler_heads,
420
+ n_queries=attn_pooler_queries,
421
+ )
422
+ self.attn_pool_contrastive = AttentionalPooler(
423
+ output_dim,
424
+ width,
425
+ n_head=attn_pooler_heads,
426
+ n_queries=1,
427
+ )
428
+ else:
429
+ assert False
430
+ else:
431
+ self.attn_pool_type = ''
432
+ self.pool_type = pool_type
433
+ self.attn_pool = AttentionalPooler(
434
+ output_dim,
435
+ width,
436
+ n_head=attn_pooler_heads,
437
+ n_queries=attn_pooler_queries,
438
+ )
439
+ self.attn_pool_contrastive = None
440
+ pool_dim = output_dim
441
+ else:
442
+ self.attn_pool = None
443
+ pool_dim = width
444
+ self.pool_type = pool_type
445
+
446
+ self.ln_post = norm_layer(pool_dim)
447
+ self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
448
+
449
+ self.init_parameters()
450
+
451
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
452
+ for param in self.parameters():
453
+ param.requires_grad = False
454
+
455
+ if unlocked_groups != 0:
456
+ groups = [
457
+ [
458
+ self.conv1,
459
+ self.class_embedding,
460
+ self.positional_embedding,
461
+ self.ln_pre,
462
+ ],
463
+ *self.transformer.resblocks[:-1],
464
+ [
465
+ self.transformer.resblocks[-1],
466
+ self.ln_post,
467
+ ],
468
+ self.proj,
469
+ ]
470
+
471
+ def _unlock(x):
472
+ if isinstance(x, Sequence):
473
+ for g in x:
474
+ _unlock(g)
475
+ else:
476
+ if isinstance(x, torch.nn.Parameter):
477
+ x.requires_grad = True
478
+ else:
479
+ for p in x.parameters():
480
+ p.requires_grad = True
481
+
482
+ _unlock(groups[-unlocked_groups:])
483
+
484
+ def init_parameters(self):
485
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
486
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
487
+
488
+ # nn.init.normal_(self.class_embedding, std=self.scale)
489
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
490
+ #
491
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
492
+ # attn_std = self.transformer.width ** -0.5
493
+ # fc_std = (2 * self.transformer.width) ** -0.5
494
+ # for block in self.transformer.resblocks:
495
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
496
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
497
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
498
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
499
+ #
500
+ # if self.text_projection is not None:
501
+ # nn.init.normal_(self.text_projection, std=self.scale)
502
+ pass
503
+
504
+ @torch.jit.ignore
505
+ def set_grad_checkpointing(self, enable=True):
506
+ self.transformer.grad_checkpointing = enable
507
+
508
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
509
+ if self.pool_type == 'avg':
510
+ pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
511
+ elif self.pool_type == 'tok':
512
+ pooled, tokens = x[:, 0], x[:, 1:]
513
+ else:
514
+ pooled = tokens = x
515
+
516
+ return pooled, tokens
517
+
518
+ def forward(self, x: torch.Tensor):
519
+ x = self.conv1(x) # shape = [*, width, grid, grid]
520
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
521
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
522
+
523
+ # class embeddings and positional embeddings
524
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
525
+ # shape = [*, grid ** 2 + 1, width]
526
+ x = x + self.positional_embedding.to(x.dtype)
527
+
528
+ x = self.patch_dropout(x)
529
+ x = self.ln_pre(x)
530
+
531
+ x = x.permute(1, 0, 2) # NLD -> LND
532
+ x = self.transformer(x)
533
+ x = x.permute(1, 0, 2) # LND -> NLD
534
+
535
+ if self.attn_pool is not None:
536
+ if self.attn_pool_contrastive is not None:
537
+ # This is untested, WIP pooling that should match paper
538
+ x = self.ln_post(x) # TBD LN first or separate one after each pool?
539
+ tokens = self.attn_pool(x)
540
+ if self.attn_pool_type == 'parallel':
541
+ pooled = self.attn_pool_contrastive(x)
542
+ else:
543
+ assert self.attn_pool_type == 'cascade'
544
+ pooled = self.attn_pool_contrastive(tokens)
545
+ else:
546
+ # this is the original OpenCLIP CoCa setup, does not match paper
547
+ x = self.attn_pool(x)
548
+ x = self.ln_post(x)
549
+ pooled, tokens = self._global_pool(x)
550
+ elif self.final_ln_after_pool:
551
+ pooled, tokens = self._global_pool(x)
552
+ pooled = self.ln_post(pooled)
553
+ else:
554
+ x = self.ln_post(x)
555
+ pooled, tokens = self._global_pool(x)
556
+
557
+ if self.proj is not None:
558
+ pooled = pooled @ self.proj
559
+
560
+ if self.output_tokens:
561
+ return pooled, tokens
562
+
563
+ return pooled
564
+
565
+
566
+ def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
567
+ if pool_type == 'first':
568
+ pooled, tokens = x[:, 0], x[:, 1:]
569
+ elif pool_type == 'last':
570
+ pooled, tokens = x[:, -1], x[:, :-1]
571
+ elif pool_type == 'argmax':
572
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
573
+ assert text is not None
574
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
575
+ else:
576
+ pooled = tokens = x
577
+
578
+ return pooled, tokens
579
+
580
+
581
+ class TextTransformer(nn.Module):
582
+ output_tokens: torch.jit.Final[bool]
583
+
584
+ def __init__(
585
+ self,
586
+ context_length: int = 77,
587
+ vocab_size: int = 49408,
588
+ width: int = 512,
589
+ heads: int = 8,
590
+ layers: int = 12,
591
+ mlp_ratio: float = 4.0,
592
+ ls_init_value: float = None,
593
+ output_dim: int = 512,
594
+ embed_cls: bool = False,
595
+ no_causal_mask: bool = False,
596
+ pad_id: int = 0,
597
+ pool_type: str = 'argmax',
598
+ proj_bias: bool = False,
599
+ act_layer: Callable = nn.GELU,
600
+ norm_layer: Callable = LayerNorm,
601
+ output_tokens: bool = False,
602
+ ):
603
+ super().__init__()
604
+ assert pool_type in ('first', 'last', 'argmax', 'none')
605
+ self.output_tokens = output_tokens
606
+ self.num_pos = self.context_length = context_length
607
+ self.vocab_size = vocab_size
608
+ self.width = width
609
+ self.output_dim = output_dim
610
+ self.heads = heads
611
+ self.pad_id = pad_id
612
+ self.pool_type = pool_type
613
+
614
+ self.token_embedding = nn.Embedding(vocab_size, width)
615
+ if embed_cls:
616
+ self.cls_emb = nn.Parameter(torch.empty(width))
617
+ self.num_pos += 1
618
+ else:
619
+ self.cls_emb = None
620
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
621
+ self.transformer = Transformer(
622
+ width=width,
623
+ layers=layers,
624
+ heads=heads,
625
+ mlp_ratio=mlp_ratio,
626
+ ls_init_value=ls_init_value,
627
+ act_layer=act_layer,
628
+ norm_layer=norm_layer,
629
+ )
630
+ self.ln_final = norm_layer(width)
631
+
632
+ if no_causal_mask:
633
+ self.attn_mask = None
634
+ else:
635
+ self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
636
+
637
+ if proj_bias:
638
+ self.text_projection = nn.Linear(width, output_dim)
639
+ else:
640
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
641
+
642
+ self.init_parameters()
643
+
644
+ def init_parameters(self):
645
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
646
+ nn.init.normal_(self.positional_embedding, std=0.01)
647
+ if self.cls_emb is not None:
648
+ nn.init.normal_(self.cls_emb, std=0.01)
649
+
650
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
651
+ attn_std = self.transformer.width ** -0.5
652
+ fc_std = (2 * self.transformer.width) ** -0.5
653
+ for block in self.transformer.resblocks:
654
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
655
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
656
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
657
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
658
+
659
+ if self.text_projection is not None:
660
+ if isinstance(self.text_projection, nn.Linear):
661
+ nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
662
+ if self.text_projection.bias is not None:
663
+ nn.init.zeros_(self.text_projection.bias)
664
+ else:
665
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
666
+
667
+ @torch.jit.ignore
668
+ def set_grad_checkpointing(self, enable=True):
669
+ self.transformer.grad_checkpointing = enable
670
+
671
+ def build_causal_mask(self):
672
+ # lazily create causal attention mask, with full attention between the tokens
673
+ # pytorch uses additive attention mask; fill with -inf
674
+ mask = torch.empty(self.num_pos, self.num_pos)
675
+ mask.fill_(float("-inf"))
676
+ mask.triu_(1) # zero out the lower diagonal
677
+ return mask
678
+
679
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
680
+ cls_mask = (text != self.pad_id).unsqueeze(1)
681
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
682
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
683
+ additive_mask.fill_(0)
684
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
685
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
686
+ return additive_mask
687
+
688
+ def forward(self, text):
689
+ cast_dtype = self.transformer.get_cast_dtype()
690
+ seq_len = text.shape[1]
691
+
692
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
693
+ attn_mask = self.attn_mask
694
+ if self.cls_emb is not None:
695
+ seq_len += 1
696
+ x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
697
+ cls_mask = self.build_cls_mask(text, cast_dtype)
698
+ if attn_mask is not None:
699
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
700
+
701
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
702
+ x = x.permute(1, 0, 2) # NLD -> LND
703
+ x = self.transformer(x, attn_mask=attn_mask)
704
+ x = x.permute(1, 0, 2) # LND -> NLD
705
+
706
+ # x.shape = [batch_size, n_ctx, transformer.width]
707
+ if self.cls_emb is not None:
708
+ # presence of appended cls embed (CoCa) overrides pool_type, always take last token
709
+ pooled, tokens = text_global_pool(x, pool_type='last')
710
+ pooled = self.ln_final(pooled) # final LN applied after pooling in this case
711
+ else:
712
+ x = self.ln_final(x)
713
+ pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
714
+
715
+ if self.text_projection is not None:
716
+ if isinstance(self.text_projection, nn.Linear):
717
+ pooled = self.text_projection(pooled)
718
+ else:
719
+ pooled = pooled @ self.text_projection
720
+
721
+ if self.output_tokens:
722
+ return pooled, tokens
723
+
724
+ return pooled
725
+
726
+
727
+ class MultimodalTransformer(Transformer):
728
+ def __init__(
729
+ self,
730
+ width: int,
731
+ layers: int,
732
+ heads: int,
733
+ context_length: int = 77,
734
+ mlp_ratio: float = 4.0,
735
+ ls_init_value: float = None,
736
+ act_layer: Callable = nn.GELU,
737
+ norm_layer: Callable = LayerNorm,
738
+ output_dim: int = 512,
739
+ ):
740
+
741
+ super().__init__(
742
+ width=width,
743
+ layers=layers,
744
+ heads=heads,
745
+ mlp_ratio=mlp_ratio,
746
+ ls_init_value=ls_init_value,
747
+ act_layer=act_layer,
748
+ norm_layer=norm_layer,
749
+ )
750
+ self.context_length = context_length
751
+ self.cross_attn = nn.ModuleList([
752
+ ResidualAttentionBlock(
753
+ width,
754
+ heads,
755
+ mlp_ratio,
756
+ ls_init_value=ls_init_value,
757
+ act_layer=act_layer,
758
+ norm_layer=norm_layer,
759
+ is_cross_attention=True,
760
+ )
761
+ for _ in range(layers)
762
+ ])
763
+
764
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
765
+
766
+ self.ln_final = norm_layer(width)
767
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
768
+
769
+ def init_parameters(self):
770
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
771
+ attn_std = self.transformer.width ** -0.5
772
+ fc_std = (2 * self.transformer.width) ** -0.5
773
+ for block in self.transformer.resblocks:
774
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
775
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
776
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
777
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
778
+ for block in self.transformer.cross_attn:
779
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
780
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
781
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
782
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
783
+
784
+ if self.text_projection is not None:
785
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
786
+
787
+ def build_attention_mask(self):
788
+ # lazily create causal attention mask, with full attention between the tokens
789
+ # pytorch uses additive attention mask; fill with -inf
790
+ mask = torch.empty(self.context_length, self.context_length)
791
+ mask.fill_(float("-inf"))
792
+ mask.triu_(1) # -inf strictly above the diagonal blocks future positions; diagonal and below become 0
793
+ return mask
794
+
795
+ def forward(self, image_embs, text_embs):
796
+ text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
797
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
798
+ seq_len = text_embs.shape[0]
799
+
800
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
801
+ if self.grad_checkpointing and not torch.jit.is_scripting():
802
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
803
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
804
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
805
+ else:
806
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
807
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
808
+
809
+ x = text_embs.permute(1, 0, 2) # LND -> NLD
810
+ x = self.ln_final(x)
811
+
812
+ if self.text_projection is not None:
813
+ x = x @ self.text_projection
814
+
815
+ return x
816
+
817
+ @torch.jit.ignore
818
+ def set_grad_checkpointing(self, enable=True):
819
+ self.grad_checkpointing = enable
820
+
821
+
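Editor's note: MultimodalTransformer.forward interleaves one causal self-attention block over the text tokens with one cross-attention block whose keys and values are the image tokens, repeated for every layer. The pattern (residual connections, layer norms and MLPs omitted) can be sketched with plain nn.MultiheadAttention; all sizes below are arbitrary and this is an illustration, not the class above:

    import torch
    import torch.nn as nn

    width, heads, layers = 64, 4, 2
    self_attn = nn.ModuleList([nn.MultiheadAttention(width, heads) for _ in range(layers)])
    cross_attn = nn.ModuleList([nn.MultiheadAttention(width, heads) for _ in range(layers)])

    text_embs = torch.randn(16, 2, width)     # LND: 16 text tokens, batch of 2
    image_embs = torch.randn(196, 2, width)   # LND: 196 image tokens, batch of 2
    causal = torch.full((16, 16), float("-inf")).triu_(1)

    for sa, ca in zip(self_attn, cross_attn):
        text_embs = sa(text_embs, text_embs, text_embs, attn_mask=causal)[0]  # causal self-attention over text
        text_embs = ca(text_embs, image_embs, image_embs)[0]                  # text queries attend to image keys/values
    print(text_embs.shape)   # torch.Size([16, 2, 64])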
822
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
823
+ # All rights reserved.
824
+
825
+ # This source code is licensed under the license found in the
826
+ # LICENSE file in the root directory of this source tree.
827
+ # --------------------------------------------------------
828
+ # Position embedding utils
829
+ # --------------------------------------------------------
830
+
831
+ import numpy as np
832
+
833
+ import torch
834
+
835
+ # --------------------------------------------------------
836
+ # 2D sine-cosine position embedding
837
+ # References:
838
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
839
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
840
+ # --------------------------------------------------------
841
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
842
+ """
843
+ grid_size: int of the grid height and width
844
+ return:
845
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
846
+ """
847
+ grid_h = np.arange(grid_size, dtype=np.float32)
848
+ grid_w = np.arange(grid_size, dtype=np.float32)
849
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
850
+ grid = np.stack(grid, axis=0)
851
+
852
+ grid = grid.reshape([2, 1, grid_size, grid_size])
853
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
854
+ if cls_token:
855
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
856
+ return pos_embed
857
+
858
+
859
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
860
+ assert embed_dim % 2 == 0
861
+
862
+ # use half of dimensions to encode grid_h
863
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
864
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
865
+
866
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
867
+ return emb
868
+
869
+
870
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
871
+ """
872
+ embed_dim: output dimension for each position
873
+ pos: a list of positions to be encoded: size (M,)
874
+ out: (M, D)
875
+ """
876
+ assert embed_dim % 2 == 0
877
+ omega = np.arange(embed_dim // 2, dtype=float)
878
+ omega /= embed_dim / 2.
879
+ omega = 1. / 10000**omega # (D/2,)
880
+
881
+ pos = pos.reshape(-1) # (M,)
882
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
883
+
884
+ emb_sin = np.sin(out) # (M, D/2)
885
+ emb_cos = np.cos(out) # (M, D/2)
886
+
887
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
888
+ return emb
889
+
890
+
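Editor's note: a quick shape check for the helpers above. With the 512-pixel, patch-16 vision tower configured in this upload, the patch grid is 32 x 32, so a width-1024 embedding with a class token gives a (1025, 1024) table (illustrative call, not part of the uploaded file):

    pos_embed = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=32, cls_token=True)
    print(pos_embed.shape)   # (1 + 32 * 32, 1024) == (1025, 1024)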
891
+ # --------------------------------------------------------
892
+ # Interpolate position embeddings for high-resolution
893
+ # References:
894
+ # DeiT: https://github.com/facebookresearch/deit
895
+ # --------------------------------------------------------
896
+ def interpolate_pos_embed(model, checkpoint_model):
897
+ if 'pos_embed' in checkpoint_model:
898
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
899
+ embedding_size = pos_embed_checkpoint.shape[-1]
900
+ num_patches = model.patch_embed.num_patches
901
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
902
+ # height (== width) for the checkpoint position embedding
903
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
904
+ # height (== width) for the new position embedding
905
+ new_size = int(num_patches ** 0.5)
906
+ # class_token and dist_token are kept unchanged
907
+ if orig_size != new_size:
908
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
909
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
910
+ # only the position tokens are interpolated
911
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
912
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
913
+ pos_tokens = torch.nn.functional.interpolate(
914
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
915
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
916
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
917
+ checkpoint_model['pos_embed'] = new_pos_embed
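Editor's note: interpolate_pos_embed expects a timm-style ViT (model.patch_embed.num_patches, model.pos_embed) and a checkpoint state dict containing a 'pos_embed' entry; it rewrites that entry in place so the checkpoint can then be loaded at the new resolution. A hedged usage sketch with hypothetical names:

    # `model` and "vit_checkpoint.pth" are hypothetical; layout assumptions as above.
    checkpoint_model = torch.load("vit_checkpoint.pth", map_location="cpu")
    interpolate_pos_embed(model, checkpoint_model)        # resizes checkpoint_model['pos_embed'] if grid sizes differ
    model.load_state_dict(checkpoint_model, strict=False)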
config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "_name_or_path": "G:\\Temp\\finetune_result\\LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN_25000_DIST",
3
+ "architectures": [
4
+ "CXRLLAVAModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "CXR_LLAVA_HF.CXRLLAVAConfig",
8
+ "AutoModel": "CXR_LLAVA_HF.CXRLLAVAModel"
9
+ },
10
+ "clip_embed_dim": 128,
11
+ "clip_quick_gelu": false,
12
+ "clip_vision_cfg": {
13
+ "image_size": 512,
14
+ "layers": 24,
15
+ "patch_size": 16,
16
+ "width": 1024
17
+ },
18
+ "clip_vision_tower_dtype": "bf16",
19
+ "clip_vision_tower_path": null,
20
+ "freeze_mm_mlp_adapter": false,
21
+ "image_aspect_ratio": "square",
22
+ "image_grid_pinpoints": null,
23
+ "image_preprocess_cfg": {
24
+ "mean": 0.5518136078431373,
25
+ "std": 0.3821719215686275
26
+ },
27
+ "llama": {
28
+ "_name_or_path": "/home/jovyan/llava/SW_LLAVA/LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN",
29
+ "add_cross_attention": false,
30
+ "architectures": [
31
+ "LlamaForCausalLM"
32
+ ],
33
+ "bad_words_ids": null,
34
+ "begin_suppress_tokens": null,
35
+ "bos_token_id": 1,
36
+ "chunk_size_feed_forward": 0,
37
+ "cross_attention_hidden_size": null,
38
+ "decoder_start_token_id": null,
39
+ "diversity_penalty": 0.0,
40
+ "do_sample": false,
41
+ "early_stopping": false,
42
+ "encoder_no_repeat_ngram_size": 0,
43
+ "eos_token_id": 2,
44
+ "exponential_decay_length_penalty": null,
45
+ "finetuning_task": null,
46
+ "forced_bos_token_id": null,
47
+ "forced_eos_token_id": null,
48
+ "hidden_act": "silu",
49
+ "hidden_size": 4096,
50
+ "id2label": {
51
+ "0": "LABEL_0",
52
+ "1": "LABEL_1"
53
+ },
54
+ "initializer_range": 0.02,
55
+ "intermediate_size": 11008,
56
+ "is_decoder": false,
57
+ "is_encoder_decoder": false,
58
+ "label2id": {
59
+ "LABEL_0": 0,
60
+ "LABEL_1": 1
61
+ },
62
+ "length_penalty": 1.0,
63
+ "max_length": 20,
64
+ "max_position_embeddings": 4096,
65
+ "min_length": 0,
66
+ "model_type": "llama",
67
+ "no_repeat_ngram_size": 0,
68
+ "num_attention_heads": 32,
69
+ "num_beam_groups": 1,
70
+ "num_beams": 1,
71
+ "num_hidden_layers": 32,
72
+ "num_key_value_heads": 32,
73
+ "num_return_sequences": 1,
74
+ "output_attentions": false,
75
+ "output_hidden_states": false,
76
+ "output_scores": false,
77
+ "pad_token_id": null,
78
+ "prefix": null,
79
+ "pretraining_tp": 1,
80
+ "problem_type": null,
81
+ "pruned_heads": {},
82
+ "remove_invalid_values": false,
83
+ "repetition_penalty": 1.0,
84
+ "return_dict": true,
85
+ "return_dict_in_generate": false,
86
+ "rms_norm_eps": 1e-06,
87
+ "rope_scaling": null,
88
+ "rope_theta": 10000.0,
89
+ "sep_token_id": null,
90
+ "suppress_tokens": null,
91
+ "task_specific_params": null,
92
+ "temperature": 1.0,
93
+ "tf_legacy_loss": false,
94
+ "tie_encoder_decoder": false,
95
+ "tie_word_embeddings": false,
96
+ "tokenizer_class": null,
97
+ "top_k": 50,
98
+ "top_p": 1.0,
99
+ "torch_dtype": "float16",
100
+ "torchscript": false,
101
+ "typical_p": 1.0,
102
+ "use_bfloat16": false,
103
+ "use_cache": true,
104
+ "vocab_size": 32000
105
+ },
106
+ "llama_model_dtype": "bf16",
107
+ "llama_model_path": "/home/jovyan/llava/SW_LLAVA/LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN",
108
+ "mm_projector_dim": 1024,
109
+ "mm_projector_dtype": "fp32",
110
+ "mm_projector_path": null,
111
+ "mm_use_im_patch_token": false,
112
+ "mm_use_im_start_end": false,
113
+ "model_type": "CXR-LLAVA",
114
+ "torch_dtype": "bfloat16",
115
+ "transformers_version": "4.34.0",
116
+ "tune_mm_mlp_adapter": false,
117
+ "use_cache": false
118
+ }
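Editor's note: the auto_map block above is what lets transformers resolve the custom classes shipped in CXR_LLAVA_HF.py, so loading requires trust_remote_code=True; torch_dtype is declared as bfloat16, and images are normalized with the mean/std given in image_preprocess_cfg. A minimal loading sketch (the repo id is an assumption and may differ):

    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "ECOFRI/CXR-LLAVA",        # assumed repo id
        torch_dtype="auto",        # picks up the bfloat16 declared in config.json
        trust_remote_code=True,    # needed so AutoModel imports CXR_LLAVA_HF.CXRLLAVAModel
    )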
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.34.0"
4
+ }
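Editor's note: this generation_config.json carries no decoding defaults beyond _from_model_config, so sampling parameters have to be supplied at generation time; the values below are illustrative only:

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig(do_sample=True, temperature=0.3, top_p=0.9, max_new_tokens=512)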
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8687c706d0d68518b4636a40eeaafed137946620c164f77473484e86f20c540a
3
+ size 9955046591
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1179867a0a8cb96669b8c46670bd82928c6597e805982c623652a0bda776b4da
3
+ size 4137901382
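Editor's note: the two *.bin entries above are Git LFS pointer files; the version/oid/size triplets stand in for the roughly 10 GB and 4.1 GB weight shards stored in LFS. If a shard has been downloaded locally, its integrity can be checked against the recorded oid (sketch; the local file path is assumed):

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            while chunk := f.read(chunk_size):
                digest.update(chunk)
        return digest.hexdigest()

    print(sha256_of("pytorch_model-00001-of-00002.bin"))
    # expected: 8687c706d0d68518b4636a40eeaafed137946620c164f77473484e86f20c540a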
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,596 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 14092742656
4
+ },
5
+ "weight_map": {
6
+ "llama.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
7
+ "llama.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
8
+ "llama.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
9
+ "llama.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "llama.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "llama.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
12
+ "llama.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
13
+ "llama.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
14
+ "llama.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "llama.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "llama.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
17
+ "llama.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "llama.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
19
+ "llama.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
20
+ "llama.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
21
+ "llama.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "llama.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
23
+ "llama.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
24
+ "llama.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
25
+ "llama.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
26
+ "llama.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
27
+ "llama.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
28
+ "llama.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
29
+ "llama.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
30
+ "llama.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
31
+ "llama.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "llama.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
33
+ "llama.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "llama.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
35
+ "llama.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
36
+ "llama.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
37
+ "llama.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "llama.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "llama.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "llama.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "llama.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "llama.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
43
+ "llama.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
44
+ "llama.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "llama.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
46
+ "llama.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
47
+ "llama.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
48
+ "llama.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
49
+ "llama.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "llama.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "llama.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "llama.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "llama.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "llama.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "llama.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "llama.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
57
+ "llama.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "llama.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
59
+ "llama.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "llama.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "llama.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
62
+ "llama.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
63
+ "llama.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
64
+ "llama.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "llama.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
66
+ "llama.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
67
+ "llama.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
68
+ "llama.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
69
+ "llama.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "llama.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
71
+ "llama.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
72
+ "llama.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
73
+ "llama.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "llama.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
75
+ "llama.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "llama.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
77
+ "llama.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "llama.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
79
+ "llama.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
80
+ "llama.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "llama.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "llama.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
83
+ "llama.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
84
+ "llama.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "llama.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
86
+ "llama.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
87
+ "llama.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "llama.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
+ "llama.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "llama.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
91
+ "llama.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "llama.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "llama.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
94
+ "llama.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "llama.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "llama.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
97
+ "llama.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
98
+ "llama.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
99
+ "llama.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "llama.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "llama.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
102
+ "llama.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
103
+ "llama.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "llama.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
105
+ "llama.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "llama.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
107
+ "llama.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
108
+ "llama.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
109
+ "llama.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "llama.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
111
+ "llama.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
112
+ "llama.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
113
+ "llama.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "llama.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "llama.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
116
+ "llama.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
117
+ "llama.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "llama.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
119
+ "llama.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
120
+ "llama.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "llama.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "llama.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
123
+ "llama.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
124
+ "llama.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
125
+ "llama.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "llama.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
127
+ "llama.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "llama.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
129
+ "llama.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
130
+ "llama.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
131
+ "llama.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
132
+ "llama.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
133
+ "llama.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
134
+ "llama.layers.21.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
135
+ "llama.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
136
+ "llama.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
137
+ "llama.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
138
+ "llama.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
139
+ "llama.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
140
+ "llama.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
141
+ "llama.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "llama.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
143
+ "llama.layers.22.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
144
+ "llama.layers.22.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
145
+ "llama.layers.22.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
146
+ "llama.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
147
+ "llama.layers.22.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
148
+ "llama.layers.22.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
149
+ "llama.layers.22.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
150
+ "llama.layers.22.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
151
+ "llama.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
152
+ "llama.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
153
+ "llama.layers.23.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
154
+ "llama.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
155
+ "llama.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
156
+ "llama.layers.23.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
157
+ "llama.layers.23.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
158
+ "llama.layers.23.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
159
+ "llama.layers.23.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
160
+ "llama.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
161
+ "llama.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
162
+ "llama.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
163
+ "llama.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
164
+ "llama.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
165
+ "llama.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
166
+ "llama.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
167
+ "llama.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
168
+ "llama.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
169
+ "llama.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
170
+ "llama.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
171
+ "llama.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
172
+ "llama.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
173
+ "llama.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
174
+ "llama.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
175
+ "llama.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
176
+ "llama.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
177
+ "llama.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
178
+ "llama.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "llama.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "llama.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "llama.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
182
+ "llama.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
183
+ "llama.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "llama.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
185
+ "llama.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "llama.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
187
+ "llama.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
188
+ "llama.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
189
+ "llama.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
190
+ "llama.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "llama.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
192
+ "llama.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
193
+ "llama.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "llama.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "llama.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
196
+ "llama.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
197
+ "llama.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "llama.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
199
+ "llama.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
200
+ "llama.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
201
+ "llama.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "llama.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
203
+ "llama.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
204
+ "llama.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "llama.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
206
+ "llama.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
207
+ "llama.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "llama.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
209
+ "llama.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
210
+ "llama.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
211
+ "llama.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
212
+ "llama.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
213
+ "llama.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "llama.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
215
+ "llama.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
216
+ "llama.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
217
+ "llama.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
218
+ "llama.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
219
+ "llama.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
220
+ "llama.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
221
+ "llama.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
222
+ "llama.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
223
+ "llama.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
224
+ "llama.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "llama.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
226
+ "llama.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
227
+ "llama.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
228
+ "llama.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
229
+ "llama.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "llama.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "llama.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "llama.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "llama.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
234
+ "llama.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "llama.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "llama.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
237
+ "llama.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "llama.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
239
+ "llama.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
240
+ "llama.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
241
+ "llama.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
242
+ "llama.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
243
+ "llama.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
244
+ "llama.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "llama.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
246
+ "llama.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
247
+ "llama.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
248
+ "llama.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
249
+ "llama.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
250
+ "llama.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
251
+ "llama.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
252
+ "llama.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
253
+ "llama.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
254
+ "llama.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
255
+ "llama.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
256
+ "llama.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
257
+ "llama.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
258
+ "llama.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
259
+ "llama.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
260
+ "llama.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
261
+ "llama.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
262
+ "llama.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
263
+ "llama.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
264
+ "llama.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
265
+ "llama.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
266
+ "llama.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
267
+ "llama.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
268
+ "llama.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
269
+ "llama.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
270
+ "llama.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "llama.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "llama.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "llama.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "llama.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "llama.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "llama.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
277
+ "llama.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
278
+ "llama.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
279
+ "llama.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "llama.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
281
+ "llama.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
282
+ "llama.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
283
+ "llama.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
284
+ "llama.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "llama.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "llama.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
287
+ "llama.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
288
+ "llama.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
289
+ "llama.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "llama.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
291
+ "llama.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
292
+ "llama.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
293
+ "llama.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "llama.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
295
+ "llama.norm.weight": "pytorch_model-00002-of-00002.bin",
296
+ "lm_head.weight": "pytorch_model-00001-of-00002.bin",
297
+ "mm_projector.bias": "pytorch_model-00001-of-00002.bin",
298
+ "mm_projector.weight": "pytorch_model-00001-of-00002.bin",
299
+ "vision_tower.class_embedding": "pytorch_model-00001-of-00002.bin",
300
+ "vision_tower.conv1.weight": "pytorch_model-00001-of-00002.bin",
301
+ "vision_tower.ln_post.bias": "pytorch_model-00001-of-00002.bin",
302
+ "vision_tower.ln_post.weight": "pytorch_model-00001-of-00002.bin",
303
+ "vision_tower.ln_pre.bias": "pytorch_model-00001-of-00002.bin",
304
+ "vision_tower.ln_pre.weight": "pytorch_model-00001-of-00002.bin",
305
+ "vision_tower.positional_embedding": "pytorch_model-00001-of-00002.bin",
306
+ "vision_tower.proj": "pytorch_model-00001-of-00002.bin",
307
+ "vision_tower.transformer.resblocks.0.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
308
+ "vision_tower.transformer.resblocks.0.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
309
+ "vision_tower.transformer.resblocks.0.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
310
+ "vision_tower.transformer.resblocks.0.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
311
+ "vision_tower.transformer.resblocks.0.ln_1.bias": "pytorch_model-00001-of-00002.bin",
312
+ "vision_tower.transformer.resblocks.0.ln_1.weight": "pytorch_model-00001-of-00002.bin",
313
+ "vision_tower.transformer.resblocks.0.ln_2.bias": "pytorch_model-00001-of-00002.bin",
314
+ "vision_tower.transformer.resblocks.0.ln_2.weight": "pytorch_model-00001-of-00002.bin",
315
+ "vision_tower.transformer.resblocks.0.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
316
+ "vision_tower.transformer.resblocks.0.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
317
+ "vision_tower.transformer.resblocks.0.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
318
+ "vision_tower.transformer.resblocks.0.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
319
+ "vision_tower.transformer.resblocks.1.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
320
+ "vision_tower.transformer.resblocks.1.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
321
+ "vision_tower.transformer.resblocks.1.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
322
+ "vision_tower.transformer.resblocks.1.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
323
+ "vision_tower.transformer.resblocks.1.ln_1.bias": "pytorch_model-00001-of-00002.bin",
324
+ "vision_tower.transformer.resblocks.1.ln_1.weight": "pytorch_model-00001-of-00002.bin",
325
+ "vision_tower.transformer.resblocks.1.ln_2.bias": "pytorch_model-00001-of-00002.bin",
326
+ "vision_tower.transformer.resblocks.1.ln_2.weight": "pytorch_model-00001-of-00002.bin",
327
+ "vision_tower.transformer.resblocks.1.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
328
+ "vision_tower.transformer.resblocks.1.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
329
+ "vision_tower.transformer.resblocks.1.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
330
+ "vision_tower.transformer.resblocks.1.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
331
+ "vision_tower.transformer.resblocks.10.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
332
+ "vision_tower.transformer.resblocks.10.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
333
+ "vision_tower.transformer.resblocks.10.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
334
+ "vision_tower.transformer.resblocks.10.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
335
+ "vision_tower.transformer.resblocks.10.ln_1.bias": "pytorch_model-00001-of-00002.bin",
336
+ "vision_tower.transformer.resblocks.10.ln_1.weight": "pytorch_model-00001-of-00002.bin",
337
+ "vision_tower.transformer.resblocks.10.ln_2.bias": "pytorch_model-00001-of-00002.bin",
338
+ "vision_tower.transformer.resblocks.10.ln_2.weight": "pytorch_model-00001-of-00002.bin",
339
+ "vision_tower.transformer.resblocks.10.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
340
+ "vision_tower.transformer.resblocks.10.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
341
+ "vision_tower.transformer.resblocks.10.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
342
+ "vision_tower.transformer.resblocks.10.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
343
+ "vision_tower.transformer.resblocks.11.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
344
+ "vision_tower.transformer.resblocks.11.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
345
+ "vision_tower.transformer.resblocks.11.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
346
+ "vision_tower.transformer.resblocks.11.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
347
+ "vision_tower.transformer.resblocks.11.ln_1.bias": "pytorch_model-00001-of-00002.bin",
348
+ "vision_tower.transformer.resblocks.11.ln_1.weight": "pytorch_model-00001-of-00002.bin",
349
+ "vision_tower.transformer.resblocks.11.ln_2.bias": "pytorch_model-00001-of-00002.bin",
350
+ "vision_tower.transformer.resblocks.11.ln_2.weight": "pytorch_model-00001-of-00002.bin",
351
+ "vision_tower.transformer.resblocks.11.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
352
+ "vision_tower.transformer.resblocks.11.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
353
+ "vision_tower.transformer.resblocks.11.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
354
+ "vision_tower.transformer.resblocks.11.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
355
+ "vision_tower.transformer.resblocks.12.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
356
+ "vision_tower.transformer.resblocks.12.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
357
+ "vision_tower.transformer.resblocks.12.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
358
+ "vision_tower.transformer.resblocks.12.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
359
+ "vision_tower.transformer.resblocks.12.ln_1.bias": "pytorch_model-00001-of-00002.bin",
360
+ "vision_tower.transformer.resblocks.12.ln_1.weight": "pytorch_model-00001-of-00002.bin",
361
+ "vision_tower.transformer.resblocks.12.ln_2.bias": "pytorch_model-00001-of-00002.bin",
362
+ "vision_tower.transformer.resblocks.12.ln_2.weight": "pytorch_model-00001-of-00002.bin",
363
+ "vision_tower.transformer.resblocks.12.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
364
+ "vision_tower.transformer.resblocks.12.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
365
+ "vision_tower.transformer.resblocks.12.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
366
+ "vision_tower.transformer.resblocks.12.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
367
+ "vision_tower.transformer.resblocks.13.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
368
+ "vision_tower.transformer.resblocks.13.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
369
+ "vision_tower.transformer.resblocks.13.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
370
+ "vision_tower.transformer.resblocks.13.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
371
+ "vision_tower.transformer.resblocks.13.ln_1.bias": "pytorch_model-00001-of-00002.bin",
372
+ "vision_tower.transformer.resblocks.13.ln_1.weight": "pytorch_model-00001-of-00002.bin",
373
+ "vision_tower.transformer.resblocks.13.ln_2.bias": "pytorch_model-00001-of-00002.bin",
374
+ "vision_tower.transformer.resblocks.13.ln_2.weight": "pytorch_model-00001-of-00002.bin",
375
+ "vision_tower.transformer.resblocks.13.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
376
+ "vision_tower.transformer.resblocks.13.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
377
+ "vision_tower.transformer.resblocks.13.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
378
+ "vision_tower.transformer.resblocks.13.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
379
+ "vision_tower.transformer.resblocks.14.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
380
+ "vision_tower.transformer.resblocks.14.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
381
+ "vision_tower.transformer.resblocks.14.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
382
+ "vision_tower.transformer.resblocks.14.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
383
+ "vision_tower.transformer.resblocks.14.ln_1.bias": "pytorch_model-00001-of-00002.bin",
384
+ "vision_tower.transformer.resblocks.14.ln_1.weight": "pytorch_model-00001-of-00002.bin",
385
+ "vision_tower.transformer.resblocks.14.ln_2.bias": "pytorch_model-00001-of-00002.bin",
386
+ "vision_tower.transformer.resblocks.14.ln_2.weight": "pytorch_model-00001-of-00002.bin",
387
+ "vision_tower.transformer.resblocks.14.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
388
+ "vision_tower.transformer.resblocks.14.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
389
+ "vision_tower.transformer.resblocks.14.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
390
+ "vision_tower.transformer.resblocks.14.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
391
+ "vision_tower.transformer.resblocks.15.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
392
+ "vision_tower.transformer.resblocks.15.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
393
+ "vision_tower.transformer.resblocks.15.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
394
+ "vision_tower.transformer.resblocks.15.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
395
+ "vision_tower.transformer.resblocks.15.ln_1.bias": "pytorch_model-00001-of-00002.bin",
396
+ "vision_tower.transformer.resblocks.15.ln_1.weight": "pytorch_model-00001-of-00002.bin",
397
+ "vision_tower.transformer.resblocks.15.ln_2.bias": "pytorch_model-00001-of-00002.bin",
398
+ "vision_tower.transformer.resblocks.15.ln_2.weight": "pytorch_model-00001-of-00002.bin",
399
+ "vision_tower.transformer.resblocks.15.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
400
+ "vision_tower.transformer.resblocks.15.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
401
+ "vision_tower.transformer.resblocks.15.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
402
+ "vision_tower.transformer.resblocks.15.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
403
+ "vision_tower.transformer.resblocks.16.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
404
+ "vision_tower.transformer.resblocks.16.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
405
+ "vision_tower.transformer.resblocks.16.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
406
+ "vision_tower.transformer.resblocks.16.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
407
+ "vision_tower.transformer.resblocks.16.ln_1.bias": "pytorch_model-00001-of-00002.bin",
408
+ "vision_tower.transformer.resblocks.16.ln_1.weight": "pytorch_model-00001-of-00002.bin",
409
+ "vision_tower.transformer.resblocks.16.ln_2.bias": "pytorch_model-00001-of-00002.bin",
410
+ "vision_tower.transformer.resblocks.16.ln_2.weight": "pytorch_model-00001-of-00002.bin",
411
+ "vision_tower.transformer.resblocks.16.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
412
+ "vision_tower.transformer.resblocks.16.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
413
+ "vision_tower.transformer.resblocks.16.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
414
+ "vision_tower.transformer.resblocks.16.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
415
+ "vision_tower.transformer.resblocks.17.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
416
+ "vision_tower.transformer.resblocks.17.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
417
+ "vision_tower.transformer.resblocks.17.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
418
+ "vision_tower.transformer.resblocks.17.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
419
+ "vision_tower.transformer.resblocks.17.ln_1.bias": "pytorch_model-00001-of-00002.bin",
420
+ "vision_tower.transformer.resblocks.17.ln_1.weight": "pytorch_model-00001-of-00002.bin",
421
+ "vision_tower.transformer.resblocks.17.ln_2.bias": "pytorch_model-00001-of-00002.bin",
422
+ "vision_tower.transformer.resblocks.17.ln_2.weight": "pytorch_model-00001-of-00002.bin",
423
+ "vision_tower.transformer.resblocks.17.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
424
+ "vision_tower.transformer.resblocks.17.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
425
+ "vision_tower.transformer.resblocks.17.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
426
+ "vision_tower.transformer.resblocks.17.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
427
+ "vision_tower.transformer.resblocks.18.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
428
+ "vision_tower.transformer.resblocks.18.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
429
+ "vision_tower.transformer.resblocks.18.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
430
+ "vision_tower.transformer.resblocks.18.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
431
+ "vision_tower.transformer.resblocks.18.ln_1.bias": "pytorch_model-00001-of-00002.bin",
432
+ "vision_tower.transformer.resblocks.18.ln_1.weight": "pytorch_model-00001-of-00002.bin",
433
+ "vision_tower.transformer.resblocks.18.ln_2.bias": "pytorch_model-00001-of-00002.bin",
434
+ "vision_tower.transformer.resblocks.18.ln_2.weight": "pytorch_model-00001-of-00002.bin",
435
+ "vision_tower.transformer.resblocks.18.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
436
+ "vision_tower.transformer.resblocks.18.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
437
+ "vision_tower.transformer.resblocks.18.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
438
+ "vision_tower.transformer.resblocks.18.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
439
+ "vision_tower.transformer.resblocks.19.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
440
+ "vision_tower.transformer.resblocks.19.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
441
+ "vision_tower.transformer.resblocks.19.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
442
+ "vision_tower.transformer.resblocks.19.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
443
+ "vision_tower.transformer.resblocks.19.ln_1.bias": "pytorch_model-00001-of-00002.bin",
444
+ "vision_tower.transformer.resblocks.19.ln_1.weight": "pytorch_model-00001-of-00002.bin",
445
+ "vision_tower.transformer.resblocks.19.ln_2.bias": "pytorch_model-00001-of-00002.bin",
446
+ "vision_tower.transformer.resblocks.19.ln_2.weight": "pytorch_model-00001-of-00002.bin",
447
+ "vision_tower.transformer.resblocks.19.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
448
+ "vision_tower.transformer.resblocks.19.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
449
+ "vision_tower.transformer.resblocks.19.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
450
+ "vision_tower.transformer.resblocks.19.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
451
+ "vision_tower.transformer.resblocks.2.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
452
+ "vision_tower.transformer.resblocks.2.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
453
+ "vision_tower.transformer.resblocks.2.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
454
+ "vision_tower.transformer.resblocks.2.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
455
+ "vision_tower.transformer.resblocks.2.ln_1.bias": "pytorch_model-00001-of-00002.bin",
456
+ "vision_tower.transformer.resblocks.2.ln_1.weight": "pytorch_model-00001-of-00002.bin",
457
+ "vision_tower.transformer.resblocks.2.ln_2.bias": "pytorch_model-00001-of-00002.bin",
458
+ "vision_tower.transformer.resblocks.2.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.2.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.2.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.2.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.2.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.20.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.21.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.22.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.23.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.3.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.4.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.5.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.6.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.7.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.8.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.attn.in_proj_bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.attn.in_proj_weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.attn.out_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.attn.out_proj.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.ln_1.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.ln_1.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.ln_2.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.ln_2.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.mlp.c_fc.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.mlp.c_fc.weight": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.mlp.c_proj.bias": "pytorch_model-00001-of-00002.bin",
+ "vision_tower.transformer.resblocks.9.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin"
+ }
+ }
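
The entries above form the "weight_map" of pytorch_model.bin.index.json: for a checkpoint split into two shards, the index records which .bin file stores each parameter (here, all vision_tower.transformer.resblocks.* tensors live in pytorch_model-00001-of-00002.bin), and from_pretrained uses this map to load only the shards it needs. As a minimal sketch of how such an index can be read by hand (the local directory "./CXR-LLAVA" and the chosen parameter name are illustrative assumptions, not part of the commit):

# Sketch: resolve a parameter name to its shard file via the index and load it.
import json
import os
import torch

checkpoint_dir = "./CXR-LLAVA"  # assumed local snapshot of this repository

# The index maps every parameter name to the shard file that stores it.
with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
    index = json.load(f)
weight_map = index["weight_map"]

# Look up which shard holds one of the vision-tower parameters listed above.
name = "vision_tower.transformer.resblocks.23.mlp.c_proj.weight"
shard_file = weight_map[name]  # e.g. "pytorch_model-00001-of-00002.bin"

# Load only that shard's state dict and pull out the single tensor.
shard = torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu")
tensor = shard[name]
print(name, tuple(tensor.shape))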