MinxuanQin committed
Commit 58c2c99 · Parent(s): 7b4b5f6

add display visualbert

model_loader.py: +3 -1
model_loader.py CHANGED
@@ -5,6 +5,7 @@ from datasets import load_dataset, get_dataset_split_names
 import numpy as np
 
 import requests
+import streamlit as st
 from transformers import ViltProcessor, ViltForQuestionAnswering
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers import BlipProcessor, BlipForQuestionAnswering
@@ -87,6 +88,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
     )
     visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
         .squeeze(2, 3).unsqueeze(0)
+    st.text(f"ques embed: {inputs.shape}, visual: {visual_embeds.shape}")
     visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
     visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
     upd_dict = {
@@ -95,7 +97,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
         "visual_attention_mask": visual_attention_mask,
     }
     inputs.update(upd_dict)
-
+
     return upd_dict, inputs
 
 
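For context, a minimal sketch (not part of the commit) of what the new Streamlit debug line displays. The tensors below are hypothetical stand-ins for the tokenizer output and the pooled image features that get_item actually derives; only the st.text call and the two VisualBERT-style mask lines mirror the diff above.

import streamlit as st
import torch

# Hypothetical stand-ins for get_item's real inputs:
inputs = torch.zeros(1, 16, dtype=torch.long)   # question token ids from the tokenizer
visual_embeds = torch.zeros(1, 1, 2048)         # image features after squeeze(2, 3).unsqueeze(0)

# The debug display added by this commit: print both shapes in the Streamlit app.
st.text(f"ques embed: {inputs.shape}, visual: {visual_embeds.shape}")

# VisualBERT-style companions built right after, shaped to match the visual embeddings.
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)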