gchhablani's picture
Move back to online checkpoints
7fdcddd
raw history blame
No virus
5.77 kB
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
bert_tokenizer,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from mtranslate import translate
from .utils import read_markdown
import requests
from PIL import Image
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForMaskedLM,
)
def softmax(logits, axis=0):
    """Return the softmax of *logits*, normalised along *axis* (default 0).

    The maximum along *axis* is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (which would make the output ``nan``) when logits are large.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def _mask_random_token(mlm_state, caption):
    """Store *caption* in *mlm_state* with one randomly chosen token replaced by [MASK]."""
    mlm_state.unmasked_caption = caption
    ids = bert_tokenizer.encode(caption)
    # Draw from 1..len(ids)-2 so the first and last ids (the special tokens that
    # the [1:-1] slice below also strips) are never chosen for masking.
    mask_index = np.random.randint(1, len(ids) - 1)
    mlm_state.currently_masked_token = bert_tokenizer.convert_ids_to_tokens([ids[mask_index]])[0]
    ids[mask_index] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])


def _load_example(mlm_state, row):
    """Load the image and a randomly masked caption for one dataset *row* into *mlm_state*.

    *row* must expose "image_file", "caption" and "lang_id" entries (one row of
    the cc12m_data/vqa_val.tsv table).
    """
    mlm_state.mlm_image_file = row["image_file"]
    _mask_random_token(mlm_state, row["caption"].strip("- "))
    mlm_state.caption_lang_id = row["lang_id"]
    image_path = os.path.join("cc12m_data/resized_images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def _fetch_image_from_url(url, timeout=10):
    """Download *url* and return it as an RGB numpy array.

    Raises ``requests.RequestException`` on network/HTTP errors (including a
    timeout after *timeout* seconds) and ``OSError`` if the payload cannot be
    decoded as an image.
    """
    import io

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Decode from the fully downloaded body: ``response.content`` applies any
    # Content-Encoding (e.g. gzip), which the unbuffered ``.raw`` stream does not.
    with Image.open(io.BytesIO(response.content)) as img:
        # NOTE(review): converted to 3 channels on the assumption that
        # get_transformed_image expects RGB input - confirm against utils.
        return np.asarray(img.convert("RGB"))


def app(state):
    """Render the visuo-linguistic mask-filling demo page.

    ``state`` is a mutable session-state object. This function reads and writes
    its ``mlm_*`` and caption attributes. ``mlm_image_file`` / ``mlm_model``
    being ``None`` triggers loading of the default example / model, so the
    object is presumably persisted across Streamlit reruns.
    """
    mlm_state = state
    st.header("Visuo-linguistic Mask Filling Demo")
    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.info(read_markdown("mlm_intro.md"))

    # TODO: cache predict/load_model with @st.cache once that works with mlm_state.
    def predict(transformed_image, caption_inputs):
        """Return the model's vocabulary logits at the first [MASK] position."""
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        # [1][0]: column (token position) of the first mask hit in the batch.
        mask_position = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
        return np.array(outputs.logits[0][mask_position])

    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    def to_english(text):
        # Skip the translation round-trip when the caption is already English.
        return text if mlm_state.caption_lang_id == "en" else translate(text, "en")

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
    first_index = 15

    # First run: seed the session with a fixed example.
    if mlm_state.mlm_image_file is None:
        _load_example(mlm_state, dummy_data.loc[first_index])

    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(mlm_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(mlm_state, dummy_data.sample(1).iloc[0])
    col2.write("OR")
    if col3.button("Use above URL"):
        try:
            mlm_state.mlm_image = _fetch_image_from_url(query1)
        except (requests.RequestException, OSError) as err:
            # Keep the previous image so the page still renders.
            st.error(f"Could not load an image from the given URL: {err}")

    transformed_image = get_transformed_image(mlm_state.mlm_image)
    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="auto")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    mlm_state.caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    # Reveal the masked token only while the caption is still the generated one.
    if mlm_state.unmasked_caption == mlm_state.caption.replace("[MASK]", mlm_state.currently_masked_token):
        new_col2.markdown("**Masked Token**: " + mlm_state.currently_masked_token)
        new_col2.markdown("**English Translation**: " + to_english(mlm_state.unmasked_caption))
    else:
        new_col2.markdown(f"**English Translation**: {to_english(mlm_state.caption)}")

    caption_inputs = get_text_attributes(mlm_state.caption)
    # predict() reads the first [MASK] position; with no mask it would raise
    # IndexError, and extra masks would be ignored. Validate before predicting.
    num_masks = int(np.sum(np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id))
    if num_masks != 1:
        st.error(f"Please write exactly one [MASK] token (found {num_masks}).")
        return

    # Display top-5 predictions
    with st.spinner("Predicting..."):
        scores = softmax(predict(transformed_image, dict(caption_inputs)))
    labels, values = get_top_5_predictions(scores)
    # NOTE(review): labels[-1] is used as the best prediction, presumably because
    # get_top_5_predictions returns ascending order for the bar plot - confirm.
    filled_sentence = mlm_state.caption.replace("[MASK]", labels[-1])
    st.write("**Filled Sentence**: " + filled_sentence)
    st.write(f"""**English Translation**: {translate(filled_sentence, 'en')}""")
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.dataframe(
        pd.DataFrame(
            {
                "Tokens": labels,
                "English Translation": [translate(label, "en") for label in labels],
            }
        ).T
    )
    st.plotly_chart(fig, use_container_width=True)