import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import streamlit as st
from mtranslate import translate
from PIL import Image

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    read_markdown,
    translate_labels,
)

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability;
    # this does not change the normalized result.
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits, axis=0)
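# Illustrative sanity check for softmax (not part of the app flow):
#   softmax(np.array([2.0, 1.0, 0.1])) -> approx. [0.659, 0.242, 0.099]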

def app(state):
    vqa_state = state
    st.header("Visual Question Answering Demo")
    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.info(read_markdown("vqa_intro.md"))
    # @st.cache(persist=False)
    def predict(transformed_image, question_inputs):
        # Run the multimodal model on the image features and the tokenized
        # question, returning the raw logits for the single input example.
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    # @st.cache(persist=False)
    def load_model(ckpt):
        # Load the fine-tuned Flax CLIP-Vision-BERT classifier from a checkpoint.
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }
    with open("answer_reverse_mapping.json") as f:
        answer_reverse_mapping = json.load(f)
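    # answer_reverse_mapping maps stringified label indices to answer strings;
    # it is indexed with str(answer_label) below and passed to
    # get_top_5_predictions to decode the model's top logits.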
    first_index = 20

    # Initialize the session state with a fixed example on first load.
    if vqa_state.vqa_image_file is None:
        vqa_state.vqa_image_file = dummy_data.loc[first_index, "image_file"]
        vqa_state.question = dummy_data.loc[first_index, "question"].strip("- ")
        vqa_state.answer_label = dummy_data.loc[first_index, "answer_label"]
        vqa_state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
        vqa_state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
        image = plt.imread(image_path)
        vqa_state.vqa_image = image

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])
    # Image selection: enter a URL or load a random seeded example.
    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
        vqa_state.vqa_image_file = sample.loc[0, "image_file"]
        vqa_state.question = sample.loc[0, "question"].strip("- ")
        vqa_state.answer_label = sample.loc[0, "answer_label"]
        vqa_state.question_lang_id = sample.loc[0, "lang_id"]
        vqa_state.answer_lang_id = sample.loc[0, "lang_id"]
        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
        image = plt.imread(image_path)
        vqa_state.vqa_image = image

    col2.write("OR")

    if col3.button("Use above URL"):
        image_data = requests.get(query1, stream=True).raw
        image = np.asarray(Image.open(image_data))
        vqa_state.vqa_image = image
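    # The URL fetch above assumes the request succeeds and the response body
    # decodes as an image. A more defensive variant (sketch, not in the
    # original) might look like:
    #
    #   response = requests.get(query1, stream=True, timeout=10)
    #   response.raise_for_status()
    #   image = np.asarray(Image.open(response.raw).convert("RGB"))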
    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display the image.
    new_col1.image(vqa_state.vqa_image, use_column_width="auto")

    # Display the question and its English translation.
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    new_col2.markdown(
        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
    )
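    # Note: mtranslate's translate() calls an online translation service
    # (Google Translate), so rendering the English translation requires
    # network access at runtime.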
    question_inputs = get_text_attributes(question)

    # Select the answer language.
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language used to show the top-5 labels.",
    )
    # Show the ground-truth answer only while the question is unmodified,
    # since the stored label belongs to the original example.
    if question == vqa_state.question:
        actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
        new_col2.markdown(
            "**Actual Answer**: "
            + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
            + " ("
            + actual_answer
            + ")"
        )

    # Predict and display the top-5 answers with their probabilities.
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)
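# Illustrative usage (hypothetical names; this page is meant to be driven by a
# multi-page Streamlit launcher that owns a session-state object exposing the
# vqa_* attributes read and written above):
#
#   from apps import vqa          # hypothetical module path
#   vqa.app(session_state)        # session_state provides vqa_image_file,
#                                 # vqa_image, vqa_model, question, answer_label,
#                                 # question_lang_id and answer_lang_id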