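# Streamlit demo: compare visual question answering (VQA) answers on a
# user-uploaded image from three models -- the original ViLT checkpoint,
# a fine-tuned ViLT checkpoint, and BLIP.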
import numpy as np
import torch
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
from transformers import BlipProcessor, BlipForQuestionAnswering
import cv2
import streamlit as st

st.title("Live demo of multimodal vqa")

# ViLT VQA config (answer label space); each model's own config provides id2label below
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# ViLT processor plus two checkpoints: the fine-tuned model and the original baseline
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")

orig_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# BLIP answers generatively rather than classifying over a fixed answer set
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

uploaded_file = st.file_uploader("Please upload an image", type=["jpg", "png", "bmp", "jpeg"])

question = st.text_input("Type a question about the image")
if uploaded_file is not None:
    # Decode the uploaded bytes with OpenCV, convert BGR -> RGB, and display the image
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    opencv_img = cv2.imdecode(file_bytes, 1)
    image_cv2 = cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)
    st.image(image_cv2, channels="RGB")

    # PIL image for the Hugging Face processors
    img = Image.fromarray(image_cv2)

    # ViLT: classification over a fixed answer vocabulary; the same encoding
    # is reused for the fine-tuned and the original checkpoint
    encoding = processor(images=img, text=question, return_tensors="pt")

    outputs = model(**encoding)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    pred = model.config.id2label[idx]

    orig_outputs = orig_model(**encoding)
    orig_logits = orig_outputs.logits
    idx = orig_logits.argmax(-1).item()
    orig_pred = orig_model.config.id2label[idx]

    ## BLIP: generate a free-form answer
    pixel_values = blip_processor(images=img, return_tensors="pt").pixel_values
    # Prepend the CLS token manually since add_special_tokens is disabled below
    blip_ques = blip_processor.tokenizer.cls_token + question
    batch_input_ids = blip_processor(text=blip_ques, add_special_tokens=False).input_ids
    batch_input_ids = torch.tensor(batch_input_ids).unsqueeze(0)

    generate_ids = blip_model.generate(pixel_values=pixel_values, input_ids=batch_input_ids, max_length=50)
    blip_output = blip_processor.batch_decode(generate_ids, skip_special_tokens=True)

    st.text(f"Answer of ViLT: {orig_pred}")
    st.text(f"Answer after fine-tuning: {pred}")
    st.text(f"Answer of BLIP: {blip_output[0]}")