bingwork committed on
Commit
82a5b0d
1 Parent(s): 4bd1add

Upload mmalaya_arch.py

Files changed (1)
  1. mmalaya_arch.py +4 -44
mmalaya_arch.py CHANGED
@@ -3,7 +3,7 @@ import re
 import torch
 import torch.nn as nn
 from transformers import Blip2Model, Blip2Processor, Blip2Config
-from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN


 class BLIP2VisionTower(nn.Module):
@@ -265,46 +265,6 @@ class MMAlayaMetaForCausalLM(ABC):

         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

-    def initialize_vision_tokenizer(self, model_args, tokenizer):
-        if model_args.mm_use_im_patch_token:
-            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
-            self.resize_token_embeddings(len(tokenizer))
-
-        if model_args.mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
-            self.resize_token_embeddings(len(tokenizer))
-
-            if num_new_tokens > 0:
-                input_embeddings = self.get_input_embeddings().weight.data
-                output_embeddings = self.get_output_embeddings().weight.data
-
-                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
-
-                input_embeddings[-num_new_tokens:] = input_embeddings_avg
-                output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-            if model_args.tune_mm_mlp_adapter:
-                for p in self.get_input_embeddings().parameters():
-                    p.requires_grad = True
-                for p in self.get_output_embeddings().parameters():
-                    p.requires_grad = False
-
-            if model_args.pretrain_mm_mlp_adapter:
-                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
-                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
-                assert num_new_tokens == 2
-                if input_embeddings.shape == embed_tokens_weight.shape:
-                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
-                elif embed_tokens_weight.shape[0] == num_new_tokens:
-                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
-                else:
-                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
-        elif model_args.mm_use_im_patch_token:
-            if model_args.tune_mm_mlp_adapter:
-                for p in self.get_input_embeddings().parameters():
-                    p.requires_grad = False
-                for p in self.get_output_embeddings().parameters():
-                    p.requires_grad = False
+    def initialize_vision_tokenizer(self, tokenizer):
+        tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
+        self.resize_token_embeddings(len(tokenizer))
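For context, a minimal, self-contained sketch of what the simplified initialize_vision_tokenizer now does: register a single image placeholder as a special token and grow the embedding matrix to match. The model name ("gpt2") and the "<image>" token string are illustrative assumptions, not values taken from this commit; in the repo the token comes from .mm_utils as DEFAULT_IMAGE_TOKEN.

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "gpt2"               # hypothetical stand-in for the MMAlaya language model
DEFAULT_IMAGE_TOKEN = "<image>"   # assumed value; defined in .mm_utils in the actual repo

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Same two steps as the new method: add the image placeholder as a special
# token, then resize the input/output embeddings to the new vocabulary size.
num_added = tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

print(f"added {num_added} token(s); new vocab size: {len(tokenizer)}")
print("image token id:", tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN))

Compared with the removed code, this drops the model_args branches (patch/start/end tokens, mean-of-existing-embeddings initialization, adapter freezing) and keeps only the unconditional add-token-and-resize path.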