# Author: gchhablani
# Commit: 405f2d4 ("Add MLM task")
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
translate_labels,
bert_tokenizer
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from session import _get_state
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForMaskedLM,
)
def softmax(logits, axis=0):
    """Return the softmax of ``logits`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This does
    not change the result mathematically, but it prevents ``np.exp`` from
    overflowing to ``inf`` (which would make the output ``nan``) for large
    logits.

    Args:
        logits: Array-like of raw, unnormalised scores.
        axis: Axis to normalise over. Defaults to 0, which matches the
            previous behaviour for 1-D input.

    Returns:
        A float ``np.ndarray`` with the same shape as ``logits``, whose
        entries along ``axis`` sum to 1.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # np.max has no identity for empty input; return an empty float array.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)
def _mask_random_token(caption):
    """Return ``caption`` with one randomly chosen token replaced by the mask token.

    The caption is tokenized with ``add_special_tokens=False``, so only real
    caption tokens (never e.g. [CLS]/[SEP]) can be masked. The ids are then
    decoded back to text. If the caption tokenizes to nothing, it is
    returned unchanged.
    """
    ids = bert_tokenizer(caption, add_special_tokens=False)["input_ids"]
    if not ids:
        return caption
    ids[np.random.randint(0, len(ids))] = bert_tokenizer.mask_token_id
    return bert_tokenizer.decode(ids)


def _load_example(state, row):
    """Store one dataset row (image file, masked caption, language id, image) in ``state``."""
    state.image_file = row["image_file"]
    state.caption = _mask_random_token(row["caption"].strip("- "))
    state.caption_lang_id = row["lang_id"]
    image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
    state.image = plt.imread(image_path)


def _top_k_tokens(probs, k=5):
    """Return the ``k`` most probable vocabulary tokens and their probabilities, best first."""
    top_indices = np.argsort(probs)[::-1][:k]
    tokens = bert_tokenizer.convert_ids_to_tokens(top_indices.tolist())
    return tokens, probs[top_indices]


def app():
    """Render the masked-language-modelling demo page.

    Shows an image next to an editable caption that should contain one
    [MASK] token. Then plots the model's top-5 token predictions, with their
    softmax probabilities, for the first masked position.
    """
    state = _get_state()

    @st.cache(persist=False)
    def predict(transformed_image, caption_inputs):
        """Return the vocabulary logits (as a NumPy array) at the first [MASK] position."""
        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
        # np.where yields (batch_indices, position_indices). Indexing the
        # logits with it gives one vocabulary-sized row per [MASK] token.
        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)
        return np.asarray(outputs.logits[indices][0])

    @st.cache(persist=False)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
    first_index = 20

    # Init session state with a fixed example on the first run.
    if state.image_file is None:
        _load_example(state, dummy_data.loc[first_index])

    # Load the model once and keep it in session state.
    if state.model is None:
        with st.spinner("Loading model..."):
            state.model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(state.image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    caption_inputs = dict(get_text_attributes(caption))

    # With no [MASK] token there is no position to predict, and the logits
    # lookup inside `predict` would raise IndexError.
    input_ids = np.asarray(caption_inputs["input_ids"])
    if not np.any(input_ids == bert_tokenizer.mask_token_id):
        st.warning("Please include a [MASK] token in the text.")
        return

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, caption_inputs)
    # Normalise over the full vocabulary so the plotted values are true
    # probabilities, not scores renormalised over the top 5 only.
    probs = softmax(logits)
    labels, values = _top_k_tokens(probs, k=5)
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.plotly_chart(fig, use_container_width=True)