gchhablani's picture
Update layout
fb3c77c
raw history blame
No virus
4.59 kB
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
translate_labels,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import json
from mtranslate import translate
from .utils import read_markdown
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForSequenceClassification,
)
def softmax(logits):
    """Convert raw scores into probabilities along the first axis.

    The maximum along axis 0 is subtracted before exponentiating. This does
    not change the mathematical result, but it keeps ``np.exp`` from
    overflowing to ``inf`` (which would turn the division into ``nan``) when
    the scores are large.

    Args:
        logits: Array-like of scores; normalisation runs over axis 0.

    Returns:
        A NumPy array with the same shape as ``logits`` whose entries sum to
        1 along axis 0.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)
def _store_example(vqa_state, frame, row_idx):
    """Copy one example row of ``frame`` into the session state and load its image.

    Args:
        vqa_state: Session-state object whose ``vqa_image_file``, ``question``,
            ``answer_label``, ``question_lang_id``, ``answer_lang_id`` and
            ``vqa_image`` attributes are overwritten.
        frame: DataFrame providing the ``image_file``, ``question``,
            ``answer_label`` and ``lang_id`` columns.
        row_idx: Index label of the row to read from ``frame``.
    """
    vqa_state.vqa_image_file = frame.loc[row_idx, "image_file"]
    # Remove leading/trailing hyphens and spaces from the stored question.
    vqa_state.question = frame.loc[row_idx, "question"].strip("- ")
    vqa_state.answer_label = frame.loc[row_idx, "answer_label"]
    # The answer language starts out equal to the example's question language.
    vqa_state.question_lang_id = frame.loc[row_idx, "lang_id"]
    vqa_state.answer_lang_id = frame.loc[row_idx, "lang_id"]
    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
    vqa_state.vqa_image = plt.imread(image_path)


def app(state):
    """Render the visual question answering page.

    Shows the usage/intro markdown, seeds the session state with a default
    example on first use, loads the model if it is not loaded yet, lets the
    user pick a random example, edit the question and choose the answer
    language, then runs the model and plots the labels and values returned by
    ``get_top_5_predictions``.

    Args:
        state: Session-state object. The attributes ``vqa_image_file``,
            ``vqa_image``, ``question``, ``answer_label``,
            ``question_lang_id``, ``answer_lang_id`` and ``vqa_model`` are
            read and written here.
    """
    vqa_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.write(read_markdown("vqa_intro.md"))

    # NOTE: Streamlit caching is intentionally left disabled for both helpers.
    # @st.cache(persist=False)
    def predict(transformed_image, question_inputs):
        """Run the model held in the state and return element ``[0][0]`` of its output as a NumPy array."""
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    # @st.cache(persist=False)
    def load_model(ckpt):
        """Load the sequence-classification model from checkpoint ``ckpt``."""
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }

    # Explicit encoding so the mapping loads the same way regardless of the
    # platform's default locale encoding.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20

    # Init Session vqa_state
    if vqa_state.vqa_image_file is None:
        _store_example(vqa_state, dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        # reset_index() makes the single sampled row addressable as label 0.
        sample = dummy_data.sample(1).reset_index()
        _store_example(vqa_state, sample, 0)

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(vqa_state.vqa_image, use_column_width="always")

    # Display Question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    # Translation is skipped only when the stored example's language is English.
    new_col2.markdown(
        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
    )
    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    # The mapping is keyed by the label rendered as a string.
    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    translated_answer = translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
    new_col2.markdown(f"**Actual Answer**: {translated_answer} ({actual_answer})")

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)