"""Streamlit page for the multilingual visual question answering (VQA) demo."""

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import streamlit as st
from mtranslate import translate
from PIL import Image

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    read_markdown,
    translate_labels,
)


def softmax(logits):
    """Return the softmax of *logits* computed along axis 0.

    The per-column maximum is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming ``nan``) for large logits.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise; keep the empty-in/empty-out behaviour.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)


def _load_example(vqa_state, examples, index):
    """Copy the example at row label *index* of *examples* into *vqa_state*.

    Sets the image file name, question, answer label, both language ids and
    the image array read from the ``resized_images`` directory.
    """
    image_file = examples.loc[index, "image_file"]
    lang_id = examples.loc[index, "lang_id"]

    vqa_state.vqa_image_file = image_file
    vqa_state.question = examples.loc[index, "question"].strip("- ")
    vqa_state.answer_label = examples.loc[index, "answer_label"]
    vqa_state.question_lang_id = lang_id
    vqa_state.answer_lang_id = lang_id
    vqa_state.vqa_image = plt.imread(os.path.join("resized_images", image_file))


def _fetch_image(url, timeout=10):
    """Download *url* and return the decoded image as a NumPy array.

    Raises:
        requests.RequestException: on connection problems, timeouts or a
            non-success HTTP status.
        OSError: if the payload cannot be opened as an image by Pillow.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    response.raise_for_status()
    return np.asarray(Image.open(response.raw))


def app(state):
    """Render the VQA page.

    Args:
        state: session-state object; this page reads and writes its
            ``vqa_*``, ``question``, ``answer_label`` and ``*_lang_id``
            attributes.
    """
    vqa_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.write(read_markdown("vqa_intro.md"))

    # NOTE: `st.cache` is intentionally not applied to the two helpers below;
    # the model is kept in the session state instead.
    def predict(transformed_image, question_inputs):
        """Run the model and return the logits of the first item as an array."""
        outputs = vqa_state.vqa_model(
            pixel_values=transformed_image, **question_inputs
        )
        return np.array(outputs[0][0])

    def load_model(ckpt):
        """Load the sequence-classification model from checkpoint *ckpt*."""
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }

    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20

    # Initialise the session state with a fixed example on first load.
    if vqa_state.vqa_image_file is None:
        _load_example(vqa_state, dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    # Replace the current example with a randomly drawn one.
    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        random_index = dummy_data.sample(1).index[0]
        _load_example(vqa_state, dummy_data, random_index)

    st.write("OR")

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    if st.button("Use this URL"):
        try:
            # Store the image under `vqa_image` (not `mlm_image`) so that this
            # page actually displays it and predicts on it.
            # NOTE(review): the question and answer label still belong to the
            # previously loaded example, so "Actual Answer" may not match.
            vqa_state.vqa_image = _fetch_image(query1)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load an image from that URL: {err}")

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(vqa_state.vqa_image, use_column_width="auto")

    # Display Question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    new_col2.markdown(
        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
    )

    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    new_col2.markdown(
        "**Actual Answer**: "
        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
        + " ("
        + actual_answer
        + ")"
    )

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)