MinxuanQin
adapt dimension of blip inputs
a0bc852
raw
history blame contribute delete
No virus
2.23 kB
import numpy as np
import torch
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
from transformers import BlipProcessor, BlipForQuestionAnswering
import cv2
import streamlit as st
st.title("Live demo of multimodal vqa")

# Streamlit re-executes this whole script on every widget interaction, so the
# checkpoints are loaded through a cached function rather than at top level;
# otherwise every keystroke in the question box would reload all models.
# ``st.cache_resource`` exists from Streamlit 1.18 on; older releases only
# provide ``st.cache``.
_cache_models = getattr(st, "cache_resource", None)
if _cache_models is None:
    _cache_models = st.cache(allow_output_mutation=True)


@_cache_models
def _load_models():
    """Load every processor and model used by the demo exactly once.

    Returns:
        A tuple ``(config, processor, model, orig_model, blip_processor,
        blip_model)``: the ViLT config and processor of the public VQA
        checkpoint, the fine-tuned ViLT model, the original ViLT model, and
        the BLIP VQA processor and model.
    """
    vilt_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    blip_checkpoint = 'Salesforce/blip-vqa-base'
    return (
        ViltConfig.from_pretrained(vilt_checkpoint),
        ViltProcessor.from_pretrained(vilt_checkpoint),
        ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned"),
        ViltForQuestionAnswering.from_pretrained(vilt_checkpoint),
        BlipProcessor.from_pretrained(blip_checkpoint),
        BlipForQuestionAnswering.from_pretrained(blip_checkpoint),
    )


# NOTE: ``config`` is not read anywhere in the visible script; it is kept so
# the module-level names stay the same as before.
config, processor, model, orig_model, blip_processor, blip_model = _load_models()

uploaded_file = st.file_uploader("Please upload one image", type=["jpg", "png", "bmp", "jpeg"])
question = st.text_input("Type here one question on the image")
def _decode_upload(raw_bytes):
    """Decode encoded image bytes into an RGB array.

    Args:
        raw_bytes: Raw contents of the uploaded image file.

    Returns:
        The decoded image as an RGB ``numpy`` array, or ``None`` when the
        buffer is empty or OpenCV cannot decode it.
    """
    if not raw_bytes:
        # Guard explicitly: OpenCV rejects an empty buffer with an error
        # rather than returning None.
        return None
    buffer = np.frombuffer(raw_bytes, dtype=np.uint8)
    bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    # OpenCV decodes to BGR channel order; everything downstream uses RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _vilt_answer(vqa_model, encoding):
    """Return the label of the highest-scoring answer class of a ViLT model.

    Args:
        vqa_model: A ViLT question-answering model.
        encoding: Processor output for one image/question pair.

    Returns:
        The ``id2label`` entry of the arg-max logit.
    """
    logits = vqa_model(**encoding).logits
    return vqa_model.config.id2label[logits.argmax(-1).item()]


def _blip_answers(img, question):
    """Generate BLIP's answer(s) for one image/question pair.

    Args:
        img: The image as a PIL image.
        question: The question text.

    Returns:
        The list of decoded strings produced by ``batch_decode``.
    """
    pixel_values = blip_processor(images=img, return_tensors="pt").pixel_values
    # The question is tokenised with a manually prepended CLS token and no
    # automatically added special tokens; this input layout is kept as-is.
    blip_ques = blip_processor.tokenizer.cls_token + question
    input_ids = blip_processor(text=blip_ques, add_special_tokens=False).input_ids
    # Add the batch dimension expected by ``generate``.
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    generate_ids = blip_model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    return blip_processor.batch_decode(generate_ids, skip_special_tokens=True)


if uploaded_file is not None:
    # ``getvalue()`` returns the full buffer regardless of the current stream
    # position, unlike ``read()``.
    image_cv2 = _decode_upload(uploaded_file.getvalue())

    if image_cv2 is None:
        st.error("The uploaded file could not be decoded as an image.")
    else:
        st.image(image_cv2, channels="RGB")

        if not question.strip():
            # Avoid running three models on an empty prompt.
            st.info("Type a question about the image to get answers.")
        else:
            img = Image.fromarray(image_cv2)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                encoding = processor(images=img, text=question, return_tensors="pt")
                pred = _vilt_answer(model, encoding)
                orig_pred = _vilt_answer(orig_model, encoding)

                ## BLIP
                blip_output = _blip_answers(img, question)

            st.text(f"Answer of ViLT: {orig_pred}")
            st.text(f"Answer after fine-tuning: {pred}")
            st.text(f"Answer of BLIP: {blip_output[0]}")