fboeEl commited on
Commit
0bd26a9
1 Parent(s): 29016d4

Removing the if __name__ == "__main__" part which (if I am not mistaken) would never be used in this scenario.

Browse files
Files changed (1) hide show
  1. modeling_mplug_owl2.py +1 -24
modeling_mplug_owl2.py CHANGED
@@ -37,7 +37,6 @@ from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
  IGNORE_INDEX = -100
38
  IMAGE_TOKEN_INDEX = -200
39
  DEFAULT_IMAGE_TOKEN = "<|image|>"
40
- from icecream import ic
41
 
42
  def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
43
  prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
@@ -387,26 +386,4 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
387
  AutoConfig.register("mplug_owl2", MPLUGOwl2Config)
388
  AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM)
389
 
390
- replace_llama_modality_adaptive()
391
-
392
- if __name__ == "__main__":
393
- config = MPLUGOwl2Config.from_pretrained('q-future/one-align')
394
- from icecream import ic
395
- # config = MPLUGOwl2Config()
396
- model = AutoModelForCausalLM(config)
397
-
398
- images = torch.randn(2, 3, 448, 448)
399
- input_ids = torch.cat([
400
- torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long()
401
- ], dim=0).unsqueeze(0)
402
- labels = input_ids.clone()
403
- labels[labels < 0] = -100
404
-
405
- # image_feature = model.encode_images(images)
406
- # ic(image_feature.shape)
407
-
408
- output = model(images=images, input_ids=input_ids, labels=labels)
409
- ic(output.loss)
410
- ic(output.logits.shape)
411
-
412
- model.save_pretrained('/cpfs01/shared/public/test/tmp_owl')
 
37
  IGNORE_INDEX = -100
38
  IMAGE_TOKEN_INDEX = -200
39
  DEFAULT_IMAGE_TOKEN = "<|image|>"
 
40
 
41
  def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
42
  prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
 
386
  AutoConfig.register("mplug_owl2", MPLUGOwl2Config)
387
  AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM)
388
 
389
+ replace_llama_modality_adaptive()