MinxuanQin commited on
Commit
d43497c
1 Parent(s): 2e4b982

fix model load error

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. model_loader.py +15 -1
app.py CHANGED
@@ -17,7 +17,7 @@ df = pd.read_json('vqa_samples.json', orient="columns")
17
  # define selector
18
  model_name = st.sidebar.selectbox(
19
  "Select a model: ",
20
- ('vilt', 'git', 'blip', 'vbert')
21
  )
22
 
23
  image_selector_unspecific = st.number_input(
@@ -41,4 +41,4 @@ question = st.text_input(f"Ask the model a question related to the image: \n"
41
  args = load_model(model_name) # TODO: cache
42
  answer = get_answer(args, image, question, model_name)
43
  st.text(f"Answer by {model_name}: {answer}")
44
- st.text(f"Ground truth: {label}")
 
17
  # define selector
18
  model_name = st.sidebar.selectbox(
19
  "Select a model: ",
20
+ ('vilt', 'vilt_finetuned', 'git', 'blip', 'vbert')
21
  )
22
 
23
  image_selector_unspecific = st.number_input(
 
41
  args = load_model(model_name) # TODO: cache
42
  answer = get_answer(args, image, question, model_name)
43
  st.text(f"Answer by {model_name}: {answer}")
44
+ st.text(f"Ground truth (of the example): {label}")
model_loader.py CHANGED
@@ -33,7 +33,10 @@ VQA_URL = "https://dl.fbaipublicfiles.com/pythia/data/answers_vqa.txt"
33
  def load_model(name):
34
  if name == "vilt":
35
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
36
- model = ViltForQuestionAnswering.from_pretrained("CARETS/vilt_neg_model")
 
 
 
37
  elif name == "git":
38
  processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
39
  model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
@@ -155,6 +158,17 @@ def get_answer(model_loader_args, img, question, model_name):
155
  logits = outputs.logits
156
  idx = logits.argmax(-1).item()
157
  pred = model.config.id2label[idx]
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  elif model_name == "git":
160
  try:
 
33
  def load_model(name):
34
  if name == "vilt":
35
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
36
+ model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
37
+ elif name == "vilt_finetuned":
38
+ processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
39
+ model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
40
  elif name == "git":
41
  processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
42
  model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
 
158
  logits = outputs.logits
159
  idx = logits.argmax(-1).item()
160
  pred = model.config.id2label[idx]
161
+
162
+ elif model_name == "vilt_finetuned":
163
+ try:
164
+ encoding = processor(images=img, text=question, return_tensors="pt")
165
+ except Exception:
166
+ return err_msg()
167
+ else:
168
+ outputs = model(**encoding)
169
+ logits = outputs.logits
170
+ idx = logits.argmax(-1).item()
171
+ pred = model.config.id2label[idx]
172
 
173
  elif model_name == "git":
174
  try: