gchhablani's picture
Fix style
e289356
raw
history blame
No virus
5.85 kB
import json
import os
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from mtranslate import translate
from PIL import Image
from streamlit.elements import markdown
from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForSequenceClassification,
)
from session import _get_state
from utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
translate_labels,
)
# Per-session state object obtained from the local `session` module.
# The script reads and writes attributes on it such as image_file, question,
# answer_label, question_lang_id, answer_lang_id and image. Presumably these
# values persist across Streamlit reruns of the same browser session, and
# unset attributes read as None (the `state.image_file is None` check relies
# on that). TODO confirm in session._get_state.
state = _get_state()
@st.cache(persist=True)
def load_model(ckpt):
    """Load the CLIP-Vision-BERT sequence-classification model from ``ckpt``.

    The call is memoized with ``st.cache``. ``persist=True`` also keeps that
    cache on disk, so the checkpoint is not re-read on every script rerun.
    """
    model_class = FlaxCLIPVisionBertForSequenceClassification
    loaded_model = model_class.from_pretrained(ckpt)
    return loaded_model
@st.cache(persist=True)
def predict(transformed_image, question_inputs):
    """Run the VQA model on one image/question pair and return its logits.

    Uses the module-level ``model``, which the script assigns before calling
    this, instead of taking the model as an argument. ``question_inputs`` is
    expanded into keyword arguments next to ``pixel_values=transformed_image``.
    The first element of the model output is taken, then its first row, and
    the result is returned as a NumPy array. The first row is presumably batch
    index 0 (TODO confirm the model output layout).
    """
    model_outputs = model(pixel_values=transformed_image, **question_inputs)
    batch_logits = model_outputs[0]
    first_example_logits = batch_logits[0]
    return np.array(first_example_logits)
def softmax(logits, axis=0):
    """Return the softmax of ``logits`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. Softmax is
    invariant to that shift, but without it large logits overflow ``np.exp``
    to ``inf`` and the division produces ``nan``.

    Args:
        logits: Array-like of scores.
        axis: Axis to normalize over. Defaults to 0, the previously hard-coded
            axis, so existing callers are unaffected.

    Returns:
        A NumPy array with the same shape as ``logits`` whose entries along
        ``axis`` sum to 1. Empty input returns an empty array.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalize; np.max would raise on an empty reduction.
        return np.exp(logits)
    # Shift for numerical stability (see docstring).
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def read_markdown(path, parent="./sections/"):
    """Return the text content of the markdown file ``path`` under ``parent``.

    The file is decoded as UTF-8 explicitly. Without ``encoding``, ``open``
    falls back to the platform locale encoding, so any non-ASCII text in a
    section file could fail to decode, or decode wrongly, on non-UTF-8 hosts.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
# def resize_height(image, new_height):
# h, w, c = image.shape
# new_width = int(w * new_height / h)
# return cv2.resize(image, (new_width, new_height))
# Model checkpoints available to the app; only checkpoints[0] is loaded.
checkpoints = ["./ckpt/ckpt-60k-5999"]  # TODO: Maybe add more checkpoints?

# Example image/question pairs used for the initial and the random examples.
dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")

# Display names for the supported language codes.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}

# Maps a stringified answer-label id to its answer text.
# JSON is UTF-8 by specification, so decode it as such regardless of locale.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
_LOGO_PATH = "./misc/mvqa-logo-white.png"

# Page-wide configuration. Streamlit requires this to be the first st.* call.
st.set_page_config(
    page_title="Multilingual VQA",
    page_icon=_LOGO_PATH,
    layout="wide",
    initial_sidebar_state="collapsed",
)

st.title("Multilingual Visual Question Answering")
authors_markdown = "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
st.write(authors_markdown)

# Header row: logo in a narrow column, introduction text in a wide one.
image_col, intro_col = st.beta_columns([3, 8])
image_col.image(_LOGO_PATH, use_column_width="always")
intro_col.write(read_markdown("intro.md"))
with st.beta_expander("Usage"):
    st.write(read_markdown("usage.md"))

# The article expander renders the section files in a fixed order.
# pretraining.md and finetuning.md go through st.markdown; the rest use st.write.
with st.beta_expander("Article"):
    for _section in ("abstract.md", "caveats.md"):
        st.write(read_markdown(_section))

    st.write("## Methodology")
    st.image(
        "./misc/Multilingual-VQA.png",
        caption="Masked LM model for Image-text Pretraining.",
    )

    for _section in ("pretraining.md", "finetuning.md"):
        st.markdown(read_markdown(_section))

    for _section in (
        "challenges.md",
        "social_impact.md",
        "references.md",
        "checkpoints.md",
        "acknowledgements.md",
    ):
        st.write(read_markdown(_section))
first_index = 20


def _load_example_into_state(frame, row_label):
    """Copy example ``row_label`` of ``frame`` into the session state.

    Sets image_file and question (with leading and trailing '-' and spaces
    stripped). It also sets answer_label, and both the question and answer
    language ids from the row's lang_id. Finally it reads the example image
    from the local ``images`` directory into ``state.image``.
    """
    state.image_file = frame.loc[row_label, "image_file"]
    state.question = frame.loc[row_label, "question"].strip("- ")
    state.answer_label = frame.loc[row_label, "answer_label"]
    state.question_lang_id = frame.loc[row_label, "lang_id"]
    state.answer_lang_id = frame.loc[row_label, "lang_id"]
    state.image = plt.imread(os.path.join("images", state.image_file))


# Init Session State: when no image has been chosen yet, seed the state with
# a fixed example row.
if state.image_file is None:
    _load_example_into_state(dummy_data, first_index)

col1, col2 = st.beta_columns([6, 4])

if col2.button(
    "Get a random example",
    help="Get a random example from the 100 `seeded` image-text pairs.",
):
    # reset_index() relabels the single sampled row as 0.
    _load_example_into_state(dummy_data.sample(1).reset_index(), 0)

col2.write("OR")

uploaded_file = col2.file_uploader(
    "Upload your image",
    type=["png", "jpg", "jpeg"],
    help="Upload a file of your choosing.",
)
if uploaded_file is not None:
    # Uploads are not supported; only an error is shown and the state is left
    # unchanged.
    st.error(
        "Uploading files does not work on HuggingFace spaces. This app only supports random examples for now."
    )
    # state.image_file = os.path.join("images/val2014", uploaded_file.name)
    # state.image = np.array(Image.open(uploaded_file))
# Preprocess the current image for the model, then show the original image.
transformed_image = get_transformed_image(state.image)
col1.image(state.image, use_column_width="auto")

new_col1, new_col2 = st.beta_columns([5, 5])

# Question input, pre-filled with the current example's question.
question = new_col1.text_input(
    label="Question",
    value=state.question,
    help="Type your question regarding the image above in one of the four languages.",
)

# Translation is skipped when the stored example's language id is "en".
# Note this checks state.question_lang_id, not the text the user typed.
if state.question_lang_id == "en":
    english_question = question
else:
    english_question = translate(question, "en")
new_col1.markdown(f"""**English Translation**: {english_question}""")

question_inputs = get_text_attributes(question)
# Answer-language picker. The selection is written back into the session
# state and also drives translation of the top-5 labels.
options = ["en", "de", "es", "fr"]
state.answer_lang_id = new_col2.selectbox(
    "Answer Language",
    options=options,
    index=options.index(state.answer_lang_id),
    format_func=code_to_name.__getitem__,
    help="The language to be used to show the top-5 labels.",
)

# Ground-truth answer: translated text first, untranslated answer in parentheses.
actual_answer = answer_reverse_mapping[str(state.answer_label)]
translated_actual_answer = translate_labels([actual_answer], state.answer_lang_id)[0]
new_col2.markdown(
    "".join(
        ["**Actual Answer**: ", translated_actual_answer, " (", actual_answer, ")"]
    )
)
# Load the cached model into the module-level name that predict() reads.
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

# dict(...) presumably turns the tokenizer output into a plain dict so that
# st.cache can hash it. TODO confirm.
with st.spinner("Predicting..."):
    raw_logits = predict(transformed_image, dict(question_inputs))

# Convert to probabilities, keep the top 5, and chart them with labels in
# the selected answer language.
probabilities = softmax(raw_logits)
top_labels, top_values = get_top_5_predictions(probabilities, answer_reverse_mapping)
localized_labels = translate_labels(top_labels, state.answer_lang_id)
st.plotly_chart(
    plotly_express_horizontal_bar_plot(top_values, localized_labels),
    use_container_width=True,
)