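# Hugging Face file-viewer residue (not part of the program):
#   author: gchhablani · commit 7fdcddd — "Move back to online checkpoints"
#   viewer links: raw / history / blame · scan: no virus · size: 5.19 kB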
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
translate_labels,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import requests
from PIL import Image
import matplotlib.pyplot as plt
import json
from mtranslate import translate
from .utils import read_markdown
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForSequenceClassification,
)
def softmax(logits):
    """Return the softmax of ``logits``, normalised along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This does
    not change the mathematical result, but it keeps ``np.exp`` from
    overflowing to ``inf`` (which would turn the quotient into ``nan``) when
    the logits are large.

    Args:
        logits: Array-like of scores. Normalisation runs over axis 0.

    Returns:
        A NumPy array with the same shape as ``logits`` whose entries sum to
        1 along axis 0.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(logits)

    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)
def _select_example(vqa_state, examples, index):
    """Make row ``index`` of ``examples`` the active example in ``vqa_state``.

    Copies the row's image file name, question, answer label and language id
    into the session state, then reads the image from ``resized_images``.

    Args:
        vqa_state: Session-state object whose attributes are overwritten.
        examples: DataFrame with ``image_file``, ``question``,
            ``answer_label`` and ``lang_id`` columns.
        index: Row label to read via ``examples.loc``.
    """
    vqa_state.vqa_image_file = examples.loc[index, "image_file"]
    vqa_state.question = examples.loc[index, "question"].strip("- ")
    vqa_state.answer_label = examples.loc[index, "answer_label"]
    # The answer language starts out equal to the question language.
    lang_id = examples.loc[index, "lang_id"]
    vqa_state.question_lang_id = lang_id
    vqa_state.answer_lang_id = lang_id
    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
    vqa_state.vqa_image = plt.imread(image_path)


def _fetch_image(url, timeout=10):
    """Download ``url`` and return the image as an RGB NumPy array.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded image as a NumPy array with three colour channels.

    Raises:
        requests.RequestException: On an invalid URL, connection error,
            timeout or non-success HTTP status.
        OSError: If Pillow cannot decode the downloaded bytes as an image.
    """
    import io  # Local import so the block does not depend on a file-level change.

    # Without a timeout an unresponsive host would hang the page forever.
    response = requests.get(url, timeout=timeout)
    # Fail early on 4xx/5xx instead of handing an HTML error page to Pillow.
    response.raise_for_status()
    # ``response.content`` is fully downloaded and content-decoded, unlike the
    # undecoded ``response.raw`` stream.
    with Image.open(io.BytesIO(response.content)) as img:
        # Palette, alpha and greyscale images would otherwise produce arrays
        # of a different shape than a plain RGB image.
        return np.asarray(img.convert("RGB"))


def app(state):
    """Render the Visual Question Answering demo page.

    Shows an image and a question, lets the user pick a random seeded example
    or load an image from a URL, runs the model and plots the top-5 answers.

    Args:
        state: Session-state object. The attributes ``vqa_image_file``,
            ``question``, ``answer_label``, ``question_lang_id``,
            ``answer_lang_id``, ``vqa_image`` and ``vqa_model`` are read and
            written. ``vqa_image_file`` and ``vqa_model`` are initialised
            here when they are ``None``.
    """
    vqa_state = state
    st.header("Visual Question Answering Demo")

    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.info(read_markdown("vqa_intro.md"))

    # Streamlit caching is left disabled for the two helpers below; the loaded
    # model is kept on the session state instead.
    def predict(transformed_image, question_inputs):
        """Call the model and return element ``[0][0]`` of its output as an array.

        NOTE(review): presumably the logits of the single example; confirm
        against the model's output structure.
        """
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    def load_model(ckpt):
        """Load the sequence-classification model from checkpoint ``ckpt``."""
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    # Local alternative: vqa_checkpoints = ["./ckpt/vqa/ckpt-60k-5999"]

    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20

    # Initialise the session state with a fixed example on the first run.
    if vqa_state.vqa_image_file is None:
        _select_example(vqa_state, dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    # Image source controls: a URL field plus "random example" / "use URL" buttons.
    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        # reset_index() makes the sampled row addressable as label 0.
        sample = dummy_data.sample(1).reset_index()
        _select_example(vqa_state, sample, 0)

    col2.write("OR")

    if col3.button("Use above URL"):
        try:
            vqa_state.vqa_image = _fetch_image(query1)
        except (requests.RequestException, OSError) as err:
            # Keep the previously shown image and report the problem instead
            # of crashing the page with a traceback.
            st.error(f"Could not load an image from the given URL: {err}")
        # NOTE(review): the question and answer label still belong to the last
        # seeded example, so the "Actual Answer" shown below may not match a
        # URL image. Confirm whether that is intended.

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(vqa_state.vqa_image, use_column_width="auto")

    # Display question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    if vqa_state.question_lang_id == "en":
        english_question = question
    else:
        english_question = translate(question, "en")
    new_col2.markdown(f"""**English Translation**: {english_question}""")

    question_inputs = get_text_attributes(question)

    # Select answer language
    options = ["en", "de", "es", "fr"]
    # Fall back to the first option if the stored id is not one of the four
    # supported codes; list.index would raise ValueError otherwise.
    if vqa_state.answer_lang_id in options:
        default_index = options.index(vqa_state.answer_lang_id)
    else:
        default_index = 0
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=default_index,
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    # The reference answer is shown only while the question is unedited.
    if question == vqa_state.question:
        actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
        new_col2.markdown(
            "**Actual Answer**: "
            + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
            + " ("
            + actual_answer
            + ")"
        )

    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))

    # Display top-5 predictions
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)