Spaces:
Runtime error
Runtime error
from .utils import ( | |
get_text_attributes, | |
get_top_5_predictions, | |
get_transformed_image, | |
plotly_express_horizontal_bar_plot, | |
bert_tokenizer, | |
) | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import os | |
import matplotlib.pyplot as plt | |
from mtranslate import translate | |
from .utils import read_markdown | |
import requests | |
from PIL import Image | |
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import ( | |
FlaxCLIPVisionBertForMaskedLM, | |
) | |
def softmax(logits):
    """Return the softmax of ``logits`` normalised along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This does
    not change the result mathematically, but it stops ``np.exp`` from
    overflowing to ``inf`` (which would turn the output into ``nan``) when
    the model produces large scores.

    Args:
        logits: array-like of raw scores. In this app it is the score array
            returned by ``predict`` for the masked position.

    Returns:
        np.ndarray with the same shape as ``logits`` whose entries sum to 1
        along axis 0. An empty input yields an empty float array, as before.
    """
    logits = np.asarray(logits, dtype=float)
    if logits.size == 0:
        # np.max would raise on an empty array; mirror the old empty result.
        return logits
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)
def _load_example_into_state(mlm_state, row):
    """Store one dataset example in ``mlm_state`` with one caption token masked.

    Args:
        mlm_state: the session-state object used by ``app``.
        row: one row (``pd.Series``) of ``cc12m_data/vqa_val.tsv``; the
            ``image_file``, ``caption`` and ``lang_id`` columns are read.
    """
    mlm_state.mlm_image_file = row["image_file"]
    caption = row["caption"].strip("- ")
    mlm_state.unmasked_caption = caption
    ids = bert_tokenizer.encode(caption)
    # Never mask the first or last id: the decode of ids[1:-1] below treats
    # them as special tokens (presumably [CLS]/[SEP]).
    mask_index = np.random.randint(1, len(ids) - 1)
    mlm_state.currently_masked_token = ids[mask_index]
    ids[mask_index] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
    mlm_state.caption_lang_id = row["lang_id"]
    image_path = os.path.join("cc12m_data/resized_images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def _to_english(text, lang_id):
    """Return ``text`` unchanged when ``lang_id`` is "en", else its English translation."""
    return text if lang_id == "en" else translate(text, "en")


def app(state):
    """Render the masked-language-modelling demo page.

    Shows an image and a caption containing one [MASK] token, then plots the
    model's top-5 predictions for the masked position.

    Args:
        state: session-state object. This function reads and writes its
            ``mlm_image_file``, ``mlm_image``, ``mlm_model``, ``caption``,
            ``unmasked_caption``, ``currently_masked_token`` and
            ``caption_lang_id`` attributes.
    """
    mlm_state = state
    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.write(read_markdown("mlm_intro.md"))

    # TODO: cache with @st.cache(persist=False) once mlm_state is supported.
    def predict(transformed_image, caption_inputs):
        """Return the model's vocabulary scores at the first [MASK] position."""
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        # Index [1] of np.where takes the column (sequence) position of each
        # match; assumes input_ids is (batch, seq) -- TODO confirm.
        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
        preds = outputs.logits[0][indices]
        scores = np.array(preds)
        return scores

    def load_model(ckpt):
        """Load the FlaxCLIPVisionBertForMaskedLM checkpoint ``ckpt``."""
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
    first_index = 15

    # Initialise session state with a fixed example on the first run.
    if mlm_state.mlm_image_file is None:
        _load_example_into_state(mlm_state, dummy_data.loc[first_index])

    # Load the model only once per session.
    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example_into_state(mlm_state, dummy_data.sample(1).iloc[0])

    st.write("OR")

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    if st.button("Use this URL"):
        try:
            # A timeout keeps an unresponsive host from hanging the page.
            response = requests.get(query1, stream=True, timeout=30)
            response.raise_for_status()
            # Convert to RGB so greyscale/palette/RGBA files give an HxWx3
            # array like the bundled examples.
            mlm_state.mlm_image = np.asarray(Image.open(response.raw).convert("RGB"))
        except (requests.RequestException, OSError) as err:
            # OSError covers PIL's UnidentifiedImageError for non-image bodies.
            st.error(f"Could not load an image from the URL: {err}")

    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="auto")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )
    if caption == mlm_state.caption:
        # currently_masked_token is a token id (int); convert it to its token
        # string before building the markdown text.
        masked_token = bert_tokenizer.convert_ids_to_tokens(mlm_state.currently_masked_token)
        new_col2.markdown(f"**Masked Token**: {masked_token}")
        new_col2.markdown(
            f"**English Translation**: "
            f"{_to_english(mlm_state.unmasked_caption, mlm_state.caption_lang_id)}"
        )
    else:
        new_col2.markdown(
            f"**English Translation**: {_to_english(caption, mlm_state.caption_lang_id)}"
        )

    caption_inputs = get_text_attributes(caption)

    # predict() needs exactly one [MASK]; without one it would raise IndexError.
    input_ids = np.asarray(caption_inputs["input_ids"])
    if np.count_nonzero(input_ids == bert_tokenizer.mask_token_id) != 1:
        st.error("Please write your text with exactly one [MASK] token.")
        return

    # Display top-5 predictions
    with st.spinner("Predicting..."):
        scores = predict(transformed_image, dict(caption_inputs))
        scores = softmax(scores)
        labels, values = get_top_5_predictions(scores)

    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.dataframe(
        pd.DataFrame(
            {
                "Tokens": labels,
                "English Translation": [translate(label, "en") for label in labels],
            }
        ).T
    )
    st.plotly_chart(fig, use_container_width=True)