"""Streamlit page: visuo-linguistic mask-filling demo.

Shows an image next to a caption containing one ``[MASK]`` token and asks a
CLIP-Vision-BERT masked-LM checkpoint for its top-5 fillers, together with
English translations of the caption, the filled sentence and the candidates.
"""

import io
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import streamlit as st
from mtranslate import translate
from PIL import Image

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)
from .utils import (
    bert_tokenizer,
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    read_markdown,
)

# Example pairs; the page reads its "image_file", "caption" and "lang_id" columns.
_DATA_PATH = "cc12m_data/vqa_val.tsv"
# Directory containing the files named in the "image_file" column.
_IMAGE_DIR = "cc12m_data/resized_images_vqa"
# Row of _DATA_PATH shown when a session starts.
_FIRST_INDEX = 15
# Seconds to wait for a user-supplied image URL before giving up.
_URL_TIMEOUT_SECONDS = 10
_MLM_CHECKPOINTS = ["flax-community/clip-vision-bert-cc12m-70k"]


def softmax(logits):
    """Return the softmax of ``logits`` along axis 0.

    The per-column maximum is subtracted before exponentiating so large
    logits cannot overflow to ``inf`` (which would produce ``nan``
    probabilities); the result is mathematically unchanged.
    """
    logits = np.asarray(logits)
    exps = np.exp(logits - np.max(logits, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0)


def _to_english(text, lang_id):
    """Return ``text`` unchanged if ``lang_id`` is "en", else its English translation."""
    return text if lang_id == "en" else translate(text, "en")


def _load_example(mlm_state, row):
    """Fill ``mlm_state`` from one dataset row: image, masked caption and language.

    One randomly chosen token of the caption is replaced by the mask token.
    The unmasked caption and the removed token are kept so the page can
    reveal them while the user has not edited the caption.
    """
    mlm_state.mlm_image_file = row["image_file"]
    caption = row["caption"].strip("- ")
    mlm_state.unmasked_caption = caption

    ids = bert_tokenizer.encode(caption)
    # The first and last ids are never masked and are dropped before decoding
    # (presumably the special tokens added by ``encode``).
    mask_index = np.random.randint(1, len(ids) - 1)
    mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens(
        [ids[mask_index]]
    )[0]
    ids[mask_index] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
    mlm_state.caption_lang_id = row["lang_id"]

    image_path = os.path.join(_IMAGE_DIR, mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def app(state):
    """Render the mask-filling demo page.

    Args:
        state: Per-session state object. This page reads and writes its
            ``mlm_image_file``, ``mlm_image``, ``mlm_model``, ``caption``,
            ``unmasked_caption``, ``currently_masked_token`` and
            ``caption_lang_id`` attributes; ``mlm_image_file`` and
            ``mlm_model`` are expected to read as ``None`` until initialised.
    """
    mlm_state = state
    st.header("Visuo-linguistic Mask Filling Demo")

    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.info(read_markdown("mlm_intro.md"))

    # @st.cache(persist=False)
    # TODO: Make this work with mlm_state. Currently not supported.
    def predict(transformed_image, caption_inputs):
        """Return the model's vocabulary logits at the first [MASK] position."""
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        input_ids = np.asarray(caption_inputs["input_ids"])
        mask_position = np.where(input_ids == bert_tokenizer.mask_token_id)[1][0]
        return np.array(outputs.logits[0][mask_position])

    # @st.cache(persist=False)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    dummy_data = pd.read_csv(_DATA_PATH, sep="\t")

    # Initialise the session with a fixed example on first load.
    if mlm_state.mlm_image_file is None:
        _load_example(mlm_state, dummy_data.loc[_FIRST_INDEX])

    # Load the model once per session and keep it on the state object.
    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(_MLM_CHECKPOINTS[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(mlm_state, dummy_data.sample(1).iloc[0])

    col2.write("OR")

    if col3.button("Use above URL"):
        try:
            response = requests.get(query1, timeout=_URL_TIMEOUT_SECONDS)
            response.raise_for_status()
            # Normalise to 3-channel RGB so grayscale/RGBA images do not reach
            # the image transform with an unexpected channel count.
            url_image = Image.open(io.BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load an image from that URL: {err}")
        else:
            mlm_state.mlm_image = np.asarray(url_image)

    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="auto")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    mlm_state.caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    caption_unchanged = mlm_state.unmasked_caption == mlm_state.caption.replace(
        "[MASK]", mlm_state.currently_masked_token
    )
    if caption_unchanged:
        # The user kept the sampled caption: reveal the hidden token and translate
        # the full (unmasked) caption.
        new_col2.markdown("**Masked Token**: " + mlm_state.currently_masked_token)
        new_col2.markdown(
            "**English Translation**: "
            + _to_english(mlm_state.unmasked_caption, mlm_state.caption_lang_id)
        )
    else:
        new_col2.markdown(
            "**English Translation**: "
            + _to_english(mlm_state.caption, mlm_state.caption_lang_id)
        )

    caption_inputs = dict(get_text_attributes(mlm_state.caption))
    mask_count = int(
        np.sum(np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id)
    )
    if mask_count != 1:
        # predict() would raise IndexError with no mask, and with several masks
        # the filled sentence below would be misleading.
        st.warning(
            f"Please write your text with exactly one [MASK] token (found {mask_count})."
        )
        return

    # Display top-5 predictions
    with st.spinner("Predicting..."):
        scores = softmax(predict(transformed_image, caption_inputs))
        labels, values = get_top_5_predictions(scores)

    # labels[-1] is used as the best candidate.
    # NOTE(review): assumes get_top_5_predictions returns ascending scores
    # (bar-plot order) — confirm against its implementation.
    filled_sentence = mlm_state.caption.replace("[MASK]", labels[-1])
    st.write("**Filled Sentence**: " + filled_sentence)
    st.write(f"""**English Translation**: {translate(filled_sentence, 'en')}""")

    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.dataframe(
        pd.DataFrame(
            {
                "Tokens": labels,
                "English Translation": [translate(label, "en") for label in labels],
            }
        ).T
    )
    st.plotly_chart(fig, use_container_width=True)