"""Streamlit demo for Multilingual Visual Question Answering (VQA)."""
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from mtranslate import translate
from PIL import Image  # used by the (currently commented-out) image uploader below

from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)
from session import _get_state
from utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
)

state = _get_state()


@st.cache(persist=True)
def load_model(ckpt):
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)


@st.cache(persist=True)
def predict(transformed_image, question_inputs):
    # Return the logits for the single (image, question) pair in the batch.
    return np.array(model(pixel_values=transformed_image, **question_inputs)[0][0])


def softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability.
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits, axis=0)


def read_markdown(path, parent="./sections/"):
    with open(os.path.join(parent, path)) as f:
        return f.read()


checkpoints = ["./ckpt/vqa/ckpt-60k-5999"]  # TODO: Maybe add more checkpoints?
dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}

with open("answer_reverse_mapping.json") as f:
    answer_reverse_mapping = json.load(f)

st.set_page_config(
    page_title="Multilingual VQA",
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon="./misc/mvqa-logo-3-white.png",
)

st.title("Multilingual Visual Question Answering")
st.write(
    "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
)

image_col, intro_col = st.beta_columns([3, 8])
image_col.image("./misc/mvqa-logo-3-white.png", use_column_width="always")
intro_col.write(read_markdown("intro.md"))

with st.beta_expander("Usage"):
    st.write(read_markdown("usage.md"))

with st.beta_expander("Article"):
    st.write(read_markdown("abstract.md"))
    st.write(read_markdown("caveats.md"))
    st.write("## Methodology")
    st.image(
        "./misc/articles/resized/Multilingual-VQA.png",
        caption="Masked LM model for image-text pretraining.",
    )
    st.markdown(read_markdown("pretraining.md"))
    st.markdown(read_markdown("finetuning.md"))
    st.write(read_markdown("challenges.md"))
    st.write(read_markdown("social_impact.md"))
    st.write(read_markdown("references.md"))
    st.write(read_markdown("checkpoints.md"))
    st.write(read_markdown("acknowledgements.md"))

first_index = 20

# Init Session State
if state.image_file is None:
    state.image_file = dummy_data.loc[first_index, "image_file"]
    state.question = dummy_data.loc[first_index, "question"].strip("- ")
    state.answer_label = dummy_data.loc[first_index, "answer_label"]
    state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
    state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]

    image_path = os.path.join("resized_images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

# col1, col2, col3 = st.beta_columns([3, 3, 3])

if st.button(
    "Get a random example",
    help="Get a random example from the 100 seeded image-text pairs.",
):
    sample = dummy_data.sample(1).reset_index()
    state.image_file = sample.loc[0, "image_file"]
    state.question = sample.loc[0, "question"].strip("- ")
    state.answer_label = sample.loc[0, "answer_label"]
    state.question_lang_id = sample.loc[0, "lang_id"]
    state.answer_lang_id = sample.loc[0, "lang_id"]

    image_path = os.path.join("resized_images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

# col2.write("OR")

# uploaded_file = col2.file_uploader(
#     "Upload your image",
#     type=["png", "jpg", "jpeg"],
#     help="Upload a file of your choosing.",
# )

# if uploaded_file is not None:
#     state.image_file = os.path.join("images/val2014", uploaded_file.name)
#     state.image = np.array(Image.open(uploaded_file))

transformed_image = get_transformed_image(state.image)

new_col1, new_col2 = st.beta_columns([5, 5])

# Display Image
new_col1.image(state.image, use_column_width="always")

# Display Question
question = new_col2.text_input(
    label="Question",
    value=state.question,
    help="Type your question regarding the image above in one of the four languages.",
)
new_col2.markdown(
    f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
)

question_inputs = get_text_attributes(question)

# Select Language
options = ["en", "de", "es", "fr"]
state.answer_lang_id = new_col2.selectbox(
    "Answer Language",
    index=options.index(state.answer_lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
    help="The language to be used to show the top-5 labels.",
)

actual_answer = answer_reverse_mapping[str(state.answer_label)]
new_col2.markdown(
    "**Actual Answer**: "
    + translate_labels([actual_answer], state.answer_lang_id)[0]
    + " ("
    + actual_answer
    + ")"
)

# Display Top-5 Predictions
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

with st.spinner("Predicting..."):
    logits = predict(transformed_image, dict(question_inputs))

probs = softmax(logits)
labels, values = get_top_5_predictions(probs, answer_reverse_mapping)
translated_labels = translate_labels(labels, state.answer_lang_id)
fig = plotly_express_horizontal_bar_plot(values, translated_labels)
st.plotly_chart(fig, use_container_width=True)