abreza committed
Commit
4e54f49
1 Parent(s): 471609c

use mg-llava instead of llava in AutoConfig.register

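Background for the rename: recent transformers releases ship their own LLaVA integration and already claim the "llava" model type, so re-registering that name from this file can raise a "model type already used" ValueError; giving the custom config its own "mg-llava" type avoids the collision (the exact transformers-version behaviour is an assumption, not something stated in this commit). A minimal sketch of the renamed registration:

from transformers import AutoConfig, LlamaConfig


class LlavaConfig(LlamaConfig):
    # Must match the name passed to AutoConfig.register below; "llava" would
    # collide with the model type that newer transformers versions register natively.
    model_type = "mg-llava"


# AutoConfig.register raises a ValueError if the name is already taken, and also
# checks that the config class's model_type equals the registered name, which is
# why both places change together in this commit.
AutoConfig.register("mg-llava", LlavaConfig)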
Files changed (1)
  1. ml_mgie/mgie_llava.py +91 -47
ml_mgie/mgie_llava.py CHANGED
@@ -12,12 +12,12 @@ import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss

 from transformers import AutoConfig, AutoModelForCausalLM, \
-                         LlamaConfig, LlamaModel, LlamaForCausalLM, \
-                         CLIPVisionModel, CLIPImageProcessor
+    LlamaConfig, LlamaModel, LlamaForCausalLM, \
+    CLIPVisionModel, CLIPImageProcessor

 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

-import os, diffusers
+import os

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -26,7 +26,7 @@ DEFAULT_IM_END_TOKEN = "<im_end>"


 class LlavaConfig(LlamaConfig):
-    model_type = "llava"
+    model_type = "mg-llava"


 class LlavaLlamaModel(LlamaModel):
@@ -37,11 +37,13 @@ class LlavaLlamaModel(LlamaModel):

         if hasattr(config, "mm_vision_tower"):
             # HACK: for FSDP
-            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
+            self.vision_tower = [
+                CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
             # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)

         if hasattr(config, "use_mm_proj"):
-            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
+            self.mm_projector = nn.Linear(
+                config.mm_hidden_size, config.hidden_size)

     def get_vision_tower(self):
         vision_tower = getattr(self, 'vision_tower', None)
@@ -67,18 +69,22 @@ class LlavaLlamaModel(LlamaModel):
         self.vision_tower = vision_tower

         vision_config = vision_tower.config
-        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
+        num_patches = (vision_config.image_size //
+                       vision_config.patch_size) ** 2

         self.config.use_mm_proj = True
         self.config.mm_hidden_size = vision_config.hidden_size
         self.config.mm_vision_select_layer = mm_vision_select_layer

         if not hasattr(self, 'mm_projector'):
-            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
+            self.mm_projector = nn.Linear(
+                vision_config.hidden_size, self.config.hidden_size)

         if pretrain_mm_mlp_adapter is not None:
-            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
-            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
+            mm_projector_weights = torch.load(
+                pretrain_mm_mlp_adapter, map_location='cpu')
+            self.mm_projector.load_state_dict(
+                {k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

         return dict(
             image_processor=image_processor,
@@ -117,21 +123,28 @@ class LlavaLlamaModel(LlamaModel):
                     # variable length images
                     image_features = []
                     for image in images:
-                        image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
-                        select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
+                        image_forward_out = vision_tower(
+                            image.unsqueeze(0), output_hidden_states=True)
+                        select_hidden_state_layer = getattr(
+                            self.config, "mm_vision_select_layer", -1)
                         select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                         image_feature = select_hidden_state[:, 1:]
                         image_features.append(image_feature)
                 else:
-                    image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
-                    select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
+                    image_forward_outs = vision_tower(
+                        images.to(vision_tower.dtype), output_hidden_states=True)
+                    select_hidden_state_layer = getattr(
+                        self.config, "mm_vision_select_layer", -1)
                     select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
-                    image_features = select_hidden_state[:, 1:].to(images.dtype)
+                    image_features = select_hidden_state[:, 1:].to(
+                        images.dtype)
             if type(images) is list:
-                image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
+                image_features = [self.mm_projector(
+                    image_feature)[0] for image_feature in image_features]
             else:
                 image_features = self.mm_projector(image_features)
-            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+            dummy_image_features = torch.zeros(
+                256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
             dummy_image_features = self.mm_projector(dummy_image_features)

             new_input_embeds = []
@@ -139,7 +152,8 @@ class LlavaLlamaModel(LlamaModel):
             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                 if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                     # multimodal LLM, but the current sample is not multimodal
-                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
+                    cur_input_embeds = cur_input_embeds + \
+                        (0. * dummy_image_features).sum()
                     new_input_embeds.append(cur_input_embeds)
                     cur_image_idx += 1
                     continue
@@ -147,32 +161,43 @@ class LlavaLlamaModel(LlamaModel):
                     cur_image_features = image_features[cur_image_idx]
                     num_patches = cur_image_features.shape[0]
                     if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
-                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
-                    image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
+                        raise ValueError(
+                            "The number of image start tokens and image end tokens should be the same.")
+                    image_start_tokens = torch.where(
+                        cur_input_ids == vision_tower.config.im_start_token)[0]
                     for image_start_token_pos in image_start_tokens:
-                        cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
+                        cur_image_features = image_features[cur_image_idx].to(
+                            device=cur_input_embeds.device)
                         num_patches = cur_image_features.shape[0]
                         if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
-                            raise ValueError("The image end token should follow the image start token.")
+                            raise ValueError(
+                                "The image end token should follow the image start token.")
                         if orig_embeds_params is not None:
-                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
+                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
+                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                         else:
-                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
+                            cur_new_input_embeds = torch.cat(
+                                (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                         cur_image_idx += 1
                     new_input_embeds.append(cur_new_input_embeds)
                 else:
                     cur_image_features = image_features[cur_image_idx]
                     num_patches = cur_image_features.shape[0]
                     if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
-                        raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
-                    masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
+                        raise ValueError(
+                            "The number of image patch tokens should be the same as the number of image patches.")
+                    masked_indices = torch.where(
+                        cur_input_ids == vision_tower.config.im_patch_token)[0]
                     mask_index_start = masked_indices[0]
                     if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
-                        raise ValueError("The image patch tokens should be consecutive.")
+                        raise ValueError(
+                            "The image patch tokens should be consecutive.")
                     if orig_embeds_params is not None:
-                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
+                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(
+                        ), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                     else:
-                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
+                        cur_new_input_embeds = torch.cat(
+                            (cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                     new_input_embeds.append(cur_new_input_embeds)
                     cur_image_idx += 1
             inputs_embeds = torch.stack(new_input_embeds, dim=0)
@@ -184,6 +209,7 @@ class LlavaLlamaModel(LlamaModel):
             return_dict=return_dict
         )

+
 class EditMapper(nn.Module):
     def __init__(self):
         super().__init__()
@@ -202,6 +228,7 @@ class EditMapper(nn.Module):

         return feat

+
 class LlavaLlamaForCausalLM(LlamaForCausalLM):
     config_class = LlavaConfig

@@ -209,7 +236,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         super(LlamaForCausalLM, self).__init__(config)
         self.model = LlavaLlamaModel(config)

-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)

         self.edit_head = EditMapper()

@@ -292,12 +320,15 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         if labels is not None:
             llm = []
             for i in range(labels.shape[0]):
-                try: p = labels[i].data.cpu().tolist().index(32003)-1
-                except: p = len(labels[i])-9
+                try:
+                    p = labels[i].data.cpu().tolist().index(32003)-1
+                except:
+                    p = len(labels[i])-9
                 p = min(len(hidden_states[i])-9, p)
                 llm.append(hidden_states[i][p:p+8].unsqueeze(0))
             llm = torch.cat(llm, dim=0)
-            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
+            hid_edit = self.edit_head(
+                llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

             B, DROP = labels.shape[0], 0.05

@@ -305,24 +336,30 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                                       self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

             with torch.no_grad():
-                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
+                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample(
+                )*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
                 lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                     torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]

             noise = torch.randn_like(lat_ans)
-            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
+            ts = torch.randint(
+                0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
             lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)

             prob = torch.rand(B, device=lat_ans.device)
-            mask = (prob<(DROP*2)).reshape(B, 1, 1)
+            mask = (prob < (DROP*2)).reshape(B, 1, 1)
             hid_edit = torch.where(mask, hid_null, hid_edit)
-            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
+            mask = (1.0-((prob >= DROP).to(lat_inp.dtype) *
+                         (prob < (DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
             lat_inp *= mask

-            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
+            out = self.unet(
+                torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample

-            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
-            if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
+            loss_ce, loss_edit = loss, nn.functional.mse_loss(
+                out, noise, reduction='mean')
+            if int(os.environ['LOCAL_RANK']) == 0:
+                print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
             loss = loss_ce+loss_edit*0.5

             if not return_dict:
@@ -367,9 +404,11 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         self.resize_token_embeddings(len(tokenizer))

         if mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            num_new_tokens = tokenizer.add_tokens(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
             self.resize_token_embeddings(len(tokenizer))
-            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

             if num_new_tokens > 0:
                 input_embeddings = self.get_input_embeddings().weight.data
@@ -384,14 +423,16 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 output_embeddings[-num_new_tokens:] = output_embeddings_avg

             if tune_mm_mlp_adapter:
-                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
+                self.get_model().orig_embeds_params = [
+                    self.get_input_embeddings().weight.data.clone().to(device=device)]
                 for p in self.get_input_embeddings().parameters():
                     p.requires_grad = True
                 for p in self.get_output_embeddings().parameters():
                     p.requires_grad = False

             if pretrain_mm_mlp_adapter:
-                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+                mm_projector_weights = torch.load(
+                    pretrain_mm_mlp_adapter, map_location='cpu')
                 embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                 assert num_new_tokens == 2
                 if input_embeddings.shape == embed_tokens_weight.shape:
@@ -399,9 +440,12 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 elif embed_tokens_weight.shape[0] == num_new_tokens:
                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
                 else:
-                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+                    raise ValueError(
+                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+
+        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+            [DEFAULT_IMAGE_PATCH_TOKEN])[0]

-        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

-AutoConfig.register("llava", LlavaConfig)
+AutoConfig.register("mg-llava", LlavaConfig)
 AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
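With the registration renamed, loading through the Auto classes would look roughly like the sketch below. The checkpoint path is a placeholder and the module import path is assumed from the repo layout (ml_mgie/mgie_llava.py); the key point is that the checkpoint's config.json must now declare "model_type": "mg-llava".

from transformers import AutoConfig, AutoModelForCausalLM

# Importing the module runs the register() calls above (import path assumed).
import ml_mgie.mgie_llava  # noqa: F401

# "path/to/mgie-checkpoint" is a placeholder, not a real checkpoint or repo id.
config = AutoConfig.from_pretrained("path/to/mgie-checkpoint")
assert config.model_type == "mg-llava"

model = AutoModelForCausalLM.from_pretrained(
    "path/to/mgie-checkpoint", config=config)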