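"""Streamlit demo that compares VQA answers from ViLT (original and fine-tuned) and BLIP."""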
import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import (
    BlipForQuestionAnswering,
    BlipProcessor,
    ViltForQuestionAnswering,
    ViltProcessor,
)
st.title("Live demo of multimodal VQA")
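# Load the three VQA models: the original ViLT checkpoint, a fine-tuned
# variant of it, and BLIP. The same ViLT processor serves both ViLT models.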
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
orig_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
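# Streamlit widgets for the two inputs: an image file and a free-form question.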
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "bmp", "jpeg"])
question = st.text_input("Ask a question about the image")
if uploaded_file is not None:
    # Decode the uploaded bytes with OpenCV and convert BGR -> RGB for display.
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    opencv_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    image_cv2 = cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)
    st.image(image_cv2, channels="RGB")
    img = Image.fromarray(image_cv2)

    if question:
        # ViLT treats VQA as classification over a fixed answer vocabulary:
        # take the argmax over the logits and map it back to an answer string.
        encoding = processor(images=img, text=question, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        idx = outputs.logits.argmax(-1).item()
        pred = model.config.id2label[idx]

        with torch.no_grad():
            orig_outputs = orig_model(**encoding)
        orig_idx = orig_outputs.logits.argmax(-1).item()
        orig_pred = orig_model.config.id2label[orig_idx]

        # BLIP instead generates the answer autoregressively, so the question
        # is prefixed with the CLS token and the answer decoded via generate().
        pixel_values = blip_processor(images=img, return_tensors="pt").pixel_values
        blip_ques = blip_processor.tokenizer.cls_token + question
        batch_input_ids = blip_processor(text=blip_ques, add_special_tokens=False).input_ids
        batch_input_ids = torch.tensor(batch_input_ids).unsqueeze(0)
        generate_ids = blip_model.generate(
            pixel_values=pixel_values, input_ids=batch_input_ids, max_length=50
        )
        blip_output = blip_processor.batch_decode(generate_ids, skip_special_tokens=True)

        st.text(f"Answer of ViLT (original): {orig_pred}")
        st.text(f"Answer of ViLT (fine-tuned): {pred}")
        st.text(f"Answer of BLIP: {blip_output[0]}")