File size: 1,435 Bytes
de05d04
 
 
c40a6be
de05d04
 
 
 
 
 
 
 
 
6f67cca
1ec4aa4
5cca687
de05d04
30b0855
de05d04
30b0855
c40a6be
 
 
 
 
de05d04
c40a6be
de05d04
 
 
 
 
 
1ec4aa4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering

st.title("Live demo of multimodal vqa")

# Hugging Face Hub identifiers of the checkpoints used by this demo.
BASE_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
FINETUNED_CHECKPOINT = "Minqin/carets_vqa_finetuned"

# Streamlit re-executes this script from the top on every widget interaction,
# so un-cached from_pretrained() calls would reload every checkpoint each time
# the user types or uploads something. Use Streamlit's resource cache when the
# installed release provides one; otherwise fall back to a no-op decorator,
# which keeps the original (un-cached) behavior.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_resource
def _load_vqa_resources():
    """Load the config, processor and both VQA models.

    Returns:
        A ``(config, processor, model, orig_model)`` tuple, where ``model`` is
        loaded from ``FINETUNED_CHECKPOINT`` and the other three objects are
        loaded from ``BASE_CHECKPOINT``.
    """
    return (
        ViltConfig.from_pretrained(BASE_CHECKPOINT),
        ViltProcessor.from_pretrained(BASE_CHECKPOINT),
        ViltForQuestionAnswering.from_pretrained(FINETUNED_CHECKPOINT),
        ViltForQuestionAnswering.from_pretrained(BASE_CHECKPOINT),
    )


# NOTE: `config` is not referenced again in this script; the name is kept so
# the module-level namespace is unchanged.
config, processor, model, orig_model = _load_vqa_resources()

uploaded_file = st.file_uploader("Please upload one image", type=["jpg", "png", "bmp", "jpeg"])

question = st.text_input("Type here one question on the image")
def _predict_answer(vqa_model, inputs):
    """Return the answer label that ``vqa_model`` scores highest for ``inputs``.

    Args:
        vqa_model: Model called as ``vqa_model(**inputs)``; its output must
            expose ``.logits`` and its ``config.id2label`` maps class index to
            answer text.
        inputs: Mapping of keyword arguments produced by the processor.

    Returns:
        The ``id2label`` entry for the arg-max class index.
    """
    # Only the forward pass is needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = vqa_model(**inputs).logits
    return vqa_model.config.id2label[logits.argmax(-1).item()]


if uploaded_file is not None:
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    # An empty upload is rejected up front; for non-empty data that is not a
    # readable image, imdecode returns None instead of raising.
    opencv_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None

    if opencv_img is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        # OpenCV decodes to BGR channel order; convert to RGB for display and PIL.
        image_cv2 = cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)
        st.image(image_cv2, channels="RGB")

        if not question.strip():
            # Nothing to answer yet: show the image but skip both forward passes.
            st.info("Type a question about the image to get an answer.")
        else:
            img = Image.fromarray(image_cv2)
            encoding = processor(images=img, text=question, return_tensors="pt")

            # The same encoded inputs are fed to both models so the answers
            # are directly comparable.
            pred = _predict_answer(model, encoding)
            orig_pred = _predict_answer(orig_model, encoding)

            st.text(f"Answer of ViLT: {orig_pred}")
            st.text(f"Answer after fine-tuning: {pred}")