williamium committed
Commit 768c6c0 · verified · 1 Parent(s): d435bb4

Upload TinyLlavaForConditionalGeneration

Files changed (3)
  1. config.json +4 -0
  2. configuration_tinyllava.py +146 -0
  3. modeling_tinyllava.py +397 -0
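
Note on loading: the main functional change in this commit is the auto_map block added to config.json below, which routes the Transformers Auto classes to the bundled remote-code files. A minimal sketch of how a checkpoint with such an auto_map is typically loaded (the repository id here is a hypothetical placeholder, not taken from this commit):

    # Hedged sketch: "user/tinyllava-repo" is a placeholder repo id.
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "user/tinyllava-repo"

    # auto_map sends AutoConfig to configuration_tinyllava.TinyLlavaConfig and
    # AutoModelForCausalLM to modeling_tinyllava.TinyLlavaForConditionalGeneration;
    # trust_remote_code=True is required to execute those files.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
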
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "TinyLlavaForConditionalGeneration"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_tinyllava.TinyLlavaConfig",
+    "AutoModelForCausalLM": "modeling_tinyllava.TinyLlavaForConditionalGeneration"
+  },
   "cache_dir": null,
   "connector_type": "mlp2x_gelu",
   "hidden_size": 4096,
configuration_tinyllava.py ADDED
@@ -0,0 +1,146 @@
+ from transformers import PretrainedConfig, LlavaConfig
+ from transformers import CONFIG_MAPPING
+ from transformers import AutoConfig
+
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+ class TinyLlavaConfig(PretrainedConfig):
+
+     model_type = "tinyllava"
+     def __init__(
+         self,
+         llm_model_name_or_path = '',
+         tokenizer_name_or_path = None,
+         vision_model_name_or_path = '',
+         vision_model_name_or_path2 = '',
+         connector_type = None,
+         text_config=None,
+         hidden_size=2048,
+         vocab_size=32000,
+         ignore_index=-100,
+         image_token_index=32000,
+         pad_token = None,
+         pad_token_id = None,
+         tokenizer_padding_side = 'right',
+         tokenizer_model_max_length = 2048,
+         vision_config = None,
+         vision_hidden_size = None,
+         vision_feature_layer = -2,
+         vision_feature_select_strategy = 'patch',
+         image_aspect_ratio = 'square',
+         resampler_hidden_size = None,
+         num_queries = None,
+         num_resampler_layers = None,
+         use_cache = False,
+         cache_dir = None,
+         tokenizer_use_fast = False,
+         tune_type_llm = 'frozen',
+         tune_type_connector = 'frozen',
+         tune_type_vision_tower = 'frozen',
+         tune_vision_tower_from_layer = -1,
+
+         **kwargs
+
+     ):
+         self.llm_model_name_or_path = llm_model_name_or_path
+         self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path
+         self.vision_model_name_or_path = vision_model_name_or_path
+         self.vision_model_name_or_path2 = vision_model_name_or_path2
+         self.connector_type = connector_type
+         self.tune_type_llm = tune_type_llm
+         self.tune_type_connector = tune_type_connector
+         self.tune_type_vision_tower = tune_type_vision_tower
+         self.tune_vision_tower_from_layer = tune_vision_tower_from_layer
+
+         self.ignore_index = IGNORE_INDEX
+         self.image_token_index = IMAGE_TOKEN_INDEX
+         self.pad_token = pad_token
+         self.pad_token_id = pad_token_id
+         self.tokenizer_padding_side = tokenizer_padding_side
+         self.tokenizer_model_max_length = tokenizer_model_max_length
+         self.vision_feature_layer = vision_feature_layer
+         self.vision_feature_select_strategy = vision_feature_select_strategy
+         self.image_aspect_ratio = image_aspect_ratio
+         self.resampler_hidden_size = resampler_hidden_size
+         self.num_queries = num_queries
+         self.num_resampler_layers = num_resampler_layers
+         self.use_cache = use_cache
+         self.cache_dir = cache_dir
+         self.tokenizer_use_fast = tokenizer_use_fast
+         self._load_text_config(text_config)
+         self._load_vision_config(vision_config)
+
+         super().__init__(**kwargs)
+
+     def load_from_config(self, config):
+         self.llm_model_name_or_path = getattr(config, 'model_name_or_path', '')
+         self.tokenizer_name_or_path = getattr(config, 'tokenizer_name_or_path', None) or self.llm_model_name_or_path
+         self.vision_model_name_or_path = getattr(config, 'vision_tower', '')
+         self.vision_model_name_or_path2 = getattr(config, 'vision_tower2', '')
+         self.connector_type = getattr(config, 'connector_type', None)
+         self.vision_feature_layer = getattr(config, 'mm_vision_select_layer', -2)
+         self.vision_feature_select_strategy = getattr(config, 'mm_vision_select_feature', "patch")
+         self.image_aspect_ratio = getattr(config, 'image_aspect_ratio', "pad")
+         self.resampler_hidden_size = getattr(config, 'resampler_hidden_size', None)
+         self.num_queries = getattr(config, 'num_queries', None)
+         self.num_resampler_layers = getattr(config, 'num_resampler_layers', None)
+
+         self.cache_dir = getattr(config, 'cache_dir', None)
+         self.tokenizer_use_fast = getattr(config, 'tokenizer_use_fast', False)
+         self.tokenizer_model_max_length = getattr(config, 'model_max_length', 2048)
+         self.tokenizer_padding_side = getattr(config, 'tokenizer_padding_side', 'right')
+
+         self._load_text_config()
+         self._load_vision_config()
+
+
+     def _load_text_config(self, text_config=None):
+         if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
+             self.text_config = CONFIG_MAPPING['llama']()
+
+         else:
+             self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
+             if text_config is not None:
+                 self.text_config = self.text_config.from_dict(text_config)
+
+         self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None))
+         self.vocab_size = getattr(self.text_config, 'vocab_size', None)
+
+
+
+     def _load_vision_config(self, vision_config=None):
+         if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '':
+             self.vision_config = CONFIG_MAPPING['clip_vision_model'](
+                 intermediate_size=4096,
+                 hidden_size=1024,
+                 patch_size=14,
+                 image_size=336,
+                 num_hidden_layers=24,
+                 num_attention_heads=16,
+                 vocab_size=32000,
+                 projection_dim=768,
+             )
+
+         else:
+             self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1])
+             self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config)
+             if vision_config is not None:
+                 self.vision_config = self.vision_config.from_dict(vision_config)
+
+         self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
+         self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
+         self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
+
+
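
As a reading aid for configuration_tinyllava.py above: when no LLM or vision paths are given, _load_text_config and _load_vision_config fall back to default llama and clip_vision_model configs from CONFIG_MAPPING. A small hedged sketch of that fallback behavior (assumes the file above is importable from the working directory):

    # Hedged sketch: exercises only the fallback path of TinyLlavaConfig above.
    from configuration_tinyllava import TinyLlavaConfig

    cfg = TinyLlavaConfig()  # empty llm_model_name_or_path and vision_model_name_or_path
    print(type(cfg.text_config).__name__)    # LlamaConfig, via CONFIG_MAPPING['llama']
    print(type(cfg.vision_config).__name__)  # CLIPVisionConfig, via CONFIG_MAPPING['clip_vision_model']
    print(cfg.hidden_size, cfg.vision_hidden_size)  # pulled from those configs, e.g. 4096 1024
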
modeling_tinyllava.py ADDED
@@ -0,0 +1,397 @@
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+ import ast
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import GenerateOutput
+
+ from . import LLMFactory, ConnectorFactory, VisionTowerFactory
+ from .configuration_tinyllava import TinyLlavaConfig
+
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+ # from tinyllava.utils.data_utils import get_value_from_kwargs
+
+ def get_value_from_kwargs(kwargs, name):
+     if name in kwargs:
+         return kwargs.pop(name)
+     else:
+         return None
+
+
+
+ class TinyLlavaPreTrainedModel(PreTrainedModel):
+     config_class = TinyLlavaConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["LlavaVisionAttention"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+
+     def _init_weights(self, module):
+         std = (
+             self.config.initializer_range
+             if hasattr(self.config, "initializer_range")
+             else self.config.text_config.initializer_range
+         )
+
+         if hasattr(module, "class_embedding"):
+             module.class_embedding.data.normal_(mean=0.0, std=std)
+
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     @property
+     def _supports_sdpa(self):
+         return self.language_model._supports_sdpa
+
+
+ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
+     def __init__(self, config: TinyLlavaConfig):
+
+         super().__init__(config)
+
+         self.language_model = LLMFactory(config.llm_model_name_or_path)[0](config.text_config)
+         self.vision_tower = VisionTowerFactory(config.vision_model_name_or_path)(config.vision_config)
+         self.connector = ConnectorFactory(config.connector_type)(config)
+
+         (Tokenizer, post_load) = LLMFactory(config.llm_model_name_or_path)[1]
+         self.tokenizer = post_load(Tokenizer.from_pretrained(
+             config.tokenizer_name_or_path,
+             cache_dir = config.cache_dir,
+             model_max_length = config.tokenizer_model_max_length,
+             padding_side = config.tokenizer_padding_side,
+             use_fast = config.tokenizer_use_fast,
+         ))
+         self.post_init()
+
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.language_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def set_decoder(self, decoder):
+         self.language_model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.language_model.get_decoder()
+
+     def tie_weights(self):
+         return self.language_model.tie_weights()
+
+     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+         # update vocab size
+         self.config.text_config.vocab_size = model_embeds.num_embeddings
+         self.config.vocab_size = model_embeds.num_embeddings
+         self.vocab_size = model_embeds.num_embeddings
+         return model_embeds
+
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[List[List[int]]] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if inputs_embeds is None:
+             (
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 inputs_embeds,
+                 labels
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 past_key_values,
+                 labels,
+                 images,
+                 image_sizes
+             )
+         return self.language_model.forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         images: Optional[torch.Tensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         position_ids = kwargs.pop("position_ids", None)
+         attention_mask = kwargs.pop("attention_mask", None)
+         if "inputs_embeds" in kwargs:
+             raise NotImplementedError("`inputs_embeds` is not supported")
+
+         if images is not None:
+             (
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 _,
+                 inputs_embeds,
+                 _
+             ) = self.prepare_inputs_labels_for_multimodal(
+                 inputs,
+                 position_ids,
+                 attention_mask,
+                 None,
+                 None,
+                 images,
+                 image_sizes=image_sizes
+             )
+         else:
+             inputs_embeds = self.language_model.get_input_embeddings()(inputs)
+
+         return self.language_model.generate(
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             **kwargs
+         )
+
+     def encode_images(self, images):
+         kwargs = {}
+         kwargs['vision_feature_layer'] = self.config.vision_feature_layer
+         kwargs['vision_feature_select_strategy'] = self.config.vision_feature_select_strategy
+         images = images.to(device=self.device, dtype=self.dtype)
+         image_features = self.vision_tower(images, **kwargs)
+         image_features = self.connector(image_features)
+         return image_features
+
+
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+                                       inputs_embeds=None, **kwargs):
+         images = kwargs.pop("images", None)
+         image_sizes = kwargs.pop("image_sizes", None)
+         inputs = self.language_model.prepare_inputs_for_generation(
+             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+         )
+         if images is not None:
+             inputs['images'] = images
+         if image_sizes is not None:
+             inputs['image_sizes'] = image_sizes
+         return inputs
+
+     def prepare_inputs_labels_for_multimodal(
+         self, input_ids, position_ids, attention_mask, past_key_values, labels,
+         images, image_sizes=None
+     ):
+         vision_tower = self.vision_tower
+         if vision_tower is None or images is None or input_ids.shape[1] == 1:
+             return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+
+         image_features = self.encode_images(images)
+
+         # TODO: image start / end is not implemented here to support pretraining.
+         if getattr(self.config, 'tune_mm_mlp_adapter', False):
+             raise NotImplementedError
+
+         # Let's just add dummy tensors if they do not exist,
+         # it is a headache to deal with None all the time.
+         # But it is not ideal, and if you have a better idea,
+         # please open an issue / submit a PR, thanks.
+         _labels = labels
+         _position_ids = position_ids
+         _attention_mask = attention_mask
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+         if position_ids is None:
+             position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+         if labels is None:
+             labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+         # remove the padding using attention_mask -- FIXME
+         _input_ids = input_ids
+         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+         new_input_embeds = []
+         new_labels = []
+         cur_image_idx = 0
+         for batch_idx, cur_input_ids in enumerate(input_ids):
+             num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+             if num_images == 0:
+                 cur_image_features = image_features[cur_image_idx]
+                 cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids)
+                 cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+                 new_input_embeds.append(cur_input_embeds)
+                 new_labels.append(labels[batch_idx])
+                 cur_image_idx += 1
+                 continue
+
+             image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+             cur_input_ids_noim = []
+             cur_labels = labels[batch_idx]
+             cur_labels_noim = []
+             for i in range(len(image_token_indices) - 1):
+                 cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
+                 cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
+             split_sizes = [x.shape[0] for x in cur_labels_noim]
+             cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim))
+             cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+             cur_new_input_embeds = []
+             cur_new_labels = []
+
+             for i in range(num_images + 1):
+                 cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                 cur_new_labels.append(cur_labels_noim[i])
+                 if i < num_images:
+                     cur_image_features = image_features[cur_image_idx]
+                     cur_image_idx += 1
+                     cur_new_input_embeds.append(cur_image_features)
+                     cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+             cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+             cur_new_labels = torch.cat(cur_new_labels)
+
+             new_input_embeds.append(cur_new_input_embeds)
+             new_labels.append(cur_new_labels)
+
+         # Truncate sequences to max length as image embeddings can make the sequence longer
+         tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
+         if tokenizer_model_max_length is not None:
+             new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+             new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+         # Combine them
+         max_len = max(x.shape[0] for x in new_input_embeds)
+         batch_size = len(new_input_embeds)
+
+         new_input_embeds_padded = []
+         new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+         attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+         position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+         for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+             cur_len = cur_new_embed.shape[0]
+             if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
+                 new_input_embeds_padded.append(torch.cat((
+                     torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
+                     cur_new_embed
+                 ), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, -cur_len:] = cur_new_labels
+                     attention_mask[i, -cur_len:] = True
+                     position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+             else:
+                 new_input_embeds_padded.append(torch.cat((
+                     cur_new_embed,
+                     torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
+                 ), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, :cur_len] = cur_new_labels
+                     attention_mask[i, :cur_len] = True
+                     position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+         new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+         if _labels is None:
+             new_labels = None
+         else:
+             new_labels = new_labels_padded
+
+         if _attention_mask is None:
+             attention_mask = None
+         else:
+             attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+         if _position_ids is None:
+             position_ids = None
+
+         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+
+
+
+     def load_llm(self, **kwargs):
+         language_model_name = get_value_from_kwargs(kwargs, 'model_name_or_path')
+         pretrained_llm_path = get_value_from_kwargs(kwargs, 'pretrained_llm_path')
+         if pretrained_llm_path is not None:
+             language_model_name = pretrained_llm_path
+         if language_model_name is not None:
+             self.language_model = self.language_model.from_pretrained(
+                 language_model_name, **kwargs
+             )
+         print('loading language model from ', language_model_name)
+         self.language_model.requires_grad_(False)
+
+         self.config.text_config.torch_dtype = kwargs.get('torch_dtype', None)
+         self.config.pad_token = getattr(self.tokenizer, 'pad_token', None)
+         self.config.pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
+         #self.config.tokenizer_padding_side = getattr(self.tokenizer, 'padding_side', None)
+         #self.config.tokenizer_model_max_length = getattr(self.tokenizer, 'model_max_length', None)
+
+
+     def load_vision_tower(self, **kwargs):
+         vision_tower_name = get_value_from_kwargs(kwargs, 'model_name_or_path')
+         self.vision_tower.load_model(vision_tower_name, **kwargs)
+
+
+     def load_connector(self, **kwargs):
+         self.connector.load_model(**kwargs)
+
+
+
+
+
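
The core of prepare_inputs_labels_for_multimodal above is splitting each prompt around IMAGE_TOKEN_INDEX (-200) so image features can be spliced between the text segments before embedding. A standalone toy illustration of that indexing step (a sketch only, not part of the uploaded files):

    # Hedged sketch: reproduces only the segment-splitting trick from the method above.
    import torch

    IMAGE_TOKEN_INDEX = -200
    cur_input_ids = torch.tensor([1, 15, 22, IMAGE_TOKEN_INDEX, 87, 90, 2])

    # Sentinels -1 and len(ids) bound the text runs on either side of each image token.
    image_token_indices = (
        [-1]
        + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
        + [cur_input_ids.shape[0]]
    )
    segments = [
        cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
        for i in range(len(image_token_indices) - 1)
    ]
    print(segments)  # [tensor([ 1, 15, 22]), tensor([87, 90,  2])]

In the full method, each text segment is embedded with the language model's input embeddings and the connector-projected image features are interleaved at the split points, with IGNORE_INDEX labels over the image positions.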