MinxuanQin committed on
Commit 5487511
1 Parent(s): 5cca687

add BLIP features

Files changed (1)
  1. app.py +15 -0
app.py CHANGED
@@ -1,6 +1,8 @@
 import numpy as np
+import torch
 from PIL import Image
 from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
+from transformers import BlipProcessor, BlipForQuestionAnswering
 import cv2
 import streamlit as st
 
@@ -13,6 +15,9 @@ model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
 
 orig_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
+blip_processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base')
+blip_model = BlipForQuestionAnswering.from_pretrained('Salesforce/blip-vqa-base')
+
 uploaded_file = st.file_uploader("Please upload one image", type=["jpg", "png", "bmp", "jpeg"])
 
 question = st.text_input("Type here one question on the image")
@@ -35,5 +40,15 @@ if uploaded_file is not None:
     orig_logits = orig_outputs.logits
     idx = orig_logits.argmax(-1).item()
     orig_pred = orig_model.config.id2label[idx]
+
+    ## BLIP
+    pixel_values = blip_processor(images=img, return_tensors="pt").pixel_values
+    blip_ques = blip_processor.tokenizer.cls_token + question
+    batch_input_ids = blip_processor(text=blip_ques, add_special_tokens=False).input_ids
+    batch_input_ids = torch.tensor(batch_input_ids)
+
+    generate_ids = blip_model.generate(pixel_values=pixel_values, input_ids=batch_input_ids, max_length=50)
+    blip_output = blip_processor.batch_decode(generate_ids, skip_special_tokens=True)
     st.text(f"Answer of ViLT: {orig_pred}")
+    st.text(f"Answer of BLIP: {blip_output}")
     st.text(f"Answer after fine-tuning: {pred}")