Removing the if __name__ == "__main__" block, which (if I am not mistaken) would never be executed in this scenario, since the file is only ever imported as a module; see the sketch after the diff.
#4 by fboeEl - opened
modeling_mplug_owl2.py +1 -24

modeling_mplug_owl2.py CHANGED
@@ -37,7 +37,6 @@ from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 IGNORE_INDEX = -100
 IMAGE_TOKEN_INDEX = -200
 DEFAULT_IMAGE_TOKEN = "<|image|>"
-from icecream import ic
 
 def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
     prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

@@ -387,26 +386,4 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
 AutoConfig.register("mplug_owl2", MPLUGOwl2Config)
 AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM)
 
-replace_llama_modality_adaptive()
-
-if __name__ == "__main__":
-    config = MPLUGOwl2Config.from_pretrained('q-future/one-align')
-    from icecream import ic
-    # config = MPLUGOwl2Config()
-    model = AutoModelForCausalLM(config)
-
-    images = torch.randn(2, 3, 448, 448)
-    input_ids = torch.cat([
-        torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long()
-    ], dim=0).unsqueeze(0)
-    labels = input_ids.clone()
-    labels[labels < 0] = -100
-
-    # image_feature = model.encode_images(images)
-    # ic(image_feature.shape)
-
-    output = model(images=images, input_ids=input_ids, labels=labels)
-    ic(output.loss)
-    ic(output.logits.shape)
-
-    model.save_pretrained('/cpfs01/shared/public/test/tmp_owl')
+replace_llama_modality_adaptive()
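For context on why the guard is dead code here: when transformers loads this file via trust_remote_code=True, it imports modeling_mplug_owl2.py as a module (under a transformers_modules.* package name), so inside the file __name__ is the module's dotted path, never "__main__". A minimal sketch of the mechanism, with hypothetical file names demo.py and run.py:

# demo.py -- stand-in for a remote-code modeling file
print(f"__name__ is {__name__!r}")

if __name__ == "__main__":
    # Only runs under `python demo.py`; never when the file is imported.
    print("executed as a script")

# run.py -- imports the module, as transformers does for remote code
import demo  # prints: __name__ is 'demo'; the guarded block is skipped

Since the Hub integration only ever imports the file, the guarded block (and the icecream import it relied on) can be dropped without changing behavior, while replace_llama_modality_adaptive() is kept so the monkey-patch still runs at import time.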
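If the smoke test is still useful during development, it could live in a standalone script instead of the modeling file. A rough sketch, not part of this PR: the name smoke_test.py is hypothetical; it instantiates the model class directly because the removed block's AutoModelForCausalLM(config) would raise (Auto classes are constructed via from_config/from_pretrained, not called directly); and it uses IMAGE_TOKEN_INDEX (-200) as the image placeholder where the original test used -1.

# smoke_test.py -- hypothetical standalone version of the removed test;
# assumes it sits next to modeling_mplug_owl2.py in a checkout of the repo.
import torch

from modeling_mplug_owl2 import (
    IMAGE_TOKEN_INDEX,
    MPLUGOwl2Config,
    MPLUGOwl2LlamaForCausalLM,
)

config = MPLUGOwl2Config.from_pretrained('q-future/one-align')
model = MPLUGOwl2LlamaForCausalLM(config)  # randomly initialized; shape/loss check only

# Two images and a prompt with two image placeholder tokens.
images = torch.randn(2, 3, 448, 448)
image_tok = torch.tensor([IMAGE_TOKEN_INDEX]).long()
input_ids = torch.cat([
    torch.ones(8).long(), image_tok,
    torch.ones(8).long(), image_tok,
    torch.ones(8).long(),
], dim=0).unsqueeze(0)
labels = input_ids.clone()
labels[labels < 0] = -100  # mask image placeholder positions out of the loss

output = model(images=images, input_ids=input_ids, labels=labels)
print(output.loss, output.logits.shape)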