# vqa_demo / app.py
# MinxuanQin
# fix error in vbert; add time counter
# 7b4b5f6
import sys
sys.path.append(".")
import streamlit as st
import pandas as pd
from PIL import Image
import time
from model_loader import *
from datasets import load_dataset
# Load the local VQA example samples. Each row is one example; the columns
# used further down in this script are 'img_path', 'ques' and 'label'.
# (Alternative source, currently unused: the VQAv2 validation split via
# load_dataset("HuggingFaceM4/VQAv2", split="validation", cache_dir="cache").)
df = pd.read_json('vqa_samples.json', orient="columns")
# Sidebar selector for the VQA model used to answer the question.
model_name = st.sidebar.selectbox(
    "Select a model: ",
    ('vilt', 'vilt_finetuned', 'git', 'blip', 'vbert')
)

# Positional row index of the example to display. Valid indices for
# ``df.iloc`` are 0 .. len(df) - 1, and Streamlit's ``max_value`` is
# inclusive, so the upper bound must be len(df) - 1 (allowing len(df) lets
# the user pick an index that makes ``df.iloc`` raise IndexError).
if df.empty:
    # With no samples there is nothing to select or display.
    st.error("No VQA samples available.")
    st.stop()
image_selector_unspecific = st.number_input(
    "Select an question id: ",
    min_value=0,
    max_value=len(df) - 1,
)
# Look up the selected example and display its image.
sample = df.iloc[image_selector_unspecific]
img_path = sample['img_path']
# Load the image fully and normalize it to RGB. The context manager closes the
# underlying file handle (Image.open is lazy), and the conversion makes the
# pixel data match the ``channels="RGB"`` display and ensures the model code
# always receives a 3-channel image, even for grayscale/palette/CMYK JPEGs.
with Image.open(f'images/{img_path}.jpg') as raw_image:
    image = raw_image.convert("RGB")
st.image(image, channels="RGB")
# Ground-truth answer for the example's own question.
label = sample['label']
# inference
# The user may type their own question. If the box is left empty (as it is on
# first page load), fall back to the example question shown in the prompt so
# the model is never queried with an empty string.
question = st.text_input(f"Ask the model a question related to the image: \n"
                         f"(e.g. \"{sample['ques']}\")").strip() or sample['ques']
# The measured span covers load_model() as well as get_answer(). load_model()
# is called on every script run; whether it caches internally is not visible
# here.
t_begin = time.perf_counter()
args = load_model(model_name)  # TODO: cache
answer = get_answer(args, image, question, model_name)
t_end = time.perf_counter()
st.text(f"Answer by {model_name}: {answer}")
st.text(f"Ground truth (of the example): {label}")
st.text(f"Time consumption: {t_end - t_begin:.4f} s")