MinxuanQin commited on
Commit
0c9e22d
1 Parent(s): a5ab0ec

fix error in visualbert

Browse files
Files changed (1) hide show
  1. model_loader.py +16 -8
model_loader.py CHANGED
@@ -62,13 +62,20 @@ def load_dataset(type):
62
  raise ValueError("invalid dataset: ", type)
63
  '''
64
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- def tokenize_function(examples, processor):
67
- sample = {}
68
- sample['inputs'] = processor(images=examples['image'], text=examples['question'], return_tensors="pt")
69
- sample['outputs'] = examples['multiple_choice_answer']
70
- return sample
71
-
72
 
73
  def label_count_list(labels):
74
  res = {}
@@ -88,7 +95,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
88
  )
89
  visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
90
  .squeeze(2, 3).unsqueeze(0)
91
- st.text(f"ques embed: {inputs.shape}, visual: {visual_embeds.shape}")
92
  visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
93
  visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
94
  upd_dict = {
@@ -192,7 +199,8 @@ def get_answer(model_loader_args, img, question, model_name):
192
 
193
  # load question and image (processor = tokenizer)
194
  ## MOD Minxuan: fix error
195
- _, inputs = get_item(img, question, processor, "resnet50")
 
196
  outputs = model(**inputs)
197
  #except Exception:
198
  # return err_msg()
 
62
  raise ValueError("invalid dataset: ", type)
63
  '''
64
 
65
+ def load_img_model(name):
66
+ """
67
+ loads image models for feature extraction
68
+ returns model name and the loaded model
69
+ """
70
+ if name == "resnet50":
71
+ model = resnet50(weights='DEFAULT')
72
+ elif name == "vitb16":
73
+ ## MOD Minxuan: add param
74
+ model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
75
+ else:
76
+ raise ValueError("undefined model name: ", name)
77
 
78
+ return model, name
 
 
 
 
 
79
 
80
  def label_count_list(labels):
81
  res = {}
 
95
  )
96
  visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
97
  .squeeze(2, 3).unsqueeze(0)
98
+
99
  visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
100
  visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
101
  upd_dict = {
 
199
 
200
  # load question and image (processor = tokenizer)
201
  ## MOD Minxuan: fix error
202
+ img_model, name = load_img_model("resnet50")
203
+ _, inputs = get_item(img, question, processor, img_model, name)
204
  outputs = model(**inputs)
205
  #except Exception:
206
  # return err_msg()