MinxuanQin committed on
Commit
1ec4aa4
1 Parent(s): ddd38e0

add additional feature: show the original ViLT model's answer alongside the fine-tuned model's answer

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -11,6 +11,8 @@ config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
11
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
12
  model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
13
 
 
 
14
  uploaded_file = st.file_uploader("Please upload one image (jpg)", type="jpg")
15
 
16
  question = st.text_input("Type here one question on the image")
@@ -29,4 +31,9 @@ if uploaded_file is not None:
29
  idx = logits.argmax(-1).item()
30
  pred = model.config.id2label[idx]
31
 
32
- st.text(f"Answer: {pred}")
 
 
 
 
 
 
11
  processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
12
  model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
13
 
14
+ orig_model = ViltForQuestionAnswering("dandelin/vilt-b32-finetuned-vqa")
15
+
16
  uploaded_file = st.file_uploader("Please upload one image (jpg)", type="jpg")
17
 
18
  question = st.text_input("Type here one question on the image")
 
31
  idx = logits.argmax(-1).item()
32
  pred = model.config.id2label[idx]
33
 
34
+ orig_outputs = orig_model(**encoding)
35
+ orig_logits = orig_outputs.logits
36
+ idx = orig_logits.argmax(-1).item()
37
+ orig_pred = orig_model.config.id2label[idx]
38
+ st.text(f"Answer of ViLT: {orig_pred}")
39
+ st.text(f"Answer after fine-tuning: {pred}")