Spaces:
Runtime error
Runtime error
File size: 5,470 Bytes
405f2d4 2c8f495 405f2d4 2c8f495 fb3c77c 3f280c5 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 fb3c77c 2c8f495 405f2d4 4b29c6a 2c8f495 405f2d4 2c8f495 405f2d4 f15eef4 2c8f495 405f2d4 f384719 2c8f495 f384719 2c8f495 405f2d4 a4ce24c 405f2d4 2c8f495 405f2d4 4b29c6a 0cb8576 4b29c6a 405f2d4 89ea6a7 95fff6e 405f2d4 2c8f495 405f2d4 f384719 2c8f495 f384719 2c8f495 405f2d4 a4ce24c 405f2d4 2c8f495 405f2d4 89ea6a7 3f280c5 2c8f495 405f2d4 f384719 405f2d4 3b5bd24 405f2d4 2c8f495 405f2d4 3b5bd24 f384719 3b5bd24 f384719 3b5bd24 405f2d4 2c8f495 4bb9586 2c8f495 405f2d4 2c8f495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
bert_tokenizer,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from mtranslate import translate
from .utils import read_markdown
import requests
from PIL import Image
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForMaskedLM,
)
def softmax(logits):
    """Return the softmax of ``logits`` normalised along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This is
    mathematically a no-op for softmax but prevents ``np.exp`` from overflowing
    to ``inf`` (and the result turning into NaN) for large logits.

    Args:
        logits: Array-like of scores; normalisation is over axis 0.

    Returns:
        ``np.ndarray`` of the same shape as ``logits`` whose entries along
        axis 0 sum to 1.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)
def _set_masked_example(mlm_state, data, index):
    """Load one dataset row into ``mlm_state`` and mask one random caption token.

    Sets ``mlm_image_file``, ``unmasked_caption``, ``currently_masked_token``
    (the token id that was hidden), ``caption`` (decoded text containing the
    mask token), ``caption_lang_id`` and ``mlm_image`` on ``mlm_state``.

    Args:
        mlm_state: Session state object the page reads and writes.
        data: DataFrame with ``image_file``, ``caption`` and ``lang_id`` columns.
        index: Row label in ``data`` to load.
    """
    mlm_state.mlm_image_file = data.loc[index, "image_file"]
    caption = data.loc[index, "caption"].strip("- ")
    mlm_state.unmasked_caption = caption
    ids = bert_tokenizer.encode(caption)
    # The first and last ids are never masked and are dropped when decoding;
    # presumably they are the special tokens added by ``encode``.
    mask_index = np.random.randint(1, len(ids) - 1)
    mlm_state.currently_masked_token = ids[mask_index]
    ids[mask_index] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
    mlm_state.caption_lang_id = data.loc[index, "lang_id"]
    image_path = os.path.join("cc12m_data/resized_images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def _load_image_from_url(url, timeout=10):
    """Download ``url`` and return the image as an RGB ``numpy`` array.

    Raises:
        requests.RequestException: on network errors, timeouts or HTTP error codes.
        OSError: if the downloaded bytes are not an image PIL can open.
    """
    import io

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # ``response.content`` is already content-decoded (unlike ``.raw``), and
    # RGB conversion drops alpha/palette modes so every image has 3 channels.
    return np.asarray(Image.open(io.BytesIO(response.content)).convert("RGB"))


def app(state):
    """Render the masked-language-modelling demo page.

    Shows an image next to a caption containing one ``[MASK]`` token, runs
    ``FlaxCLIPVisionBertForMaskedLM`` on the pair, and displays the filled
    sentence plus the labels/values returned by ``get_top_5_predictions``.

    Args:
        state: Session state object; the page reads and writes its
            ``mlm_image_file``, ``mlm_image``, ``mlm_model``, ``caption``,
            ``unmasked_caption``, ``currently_masked_token`` and
            ``caption_lang_id`` attributes.
    """
    mlm_state = state
    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.write(read_markdown("mlm_intro.md"))

    # TODO: cache with @st.cache(persist=False) once it works with mlm_state.
    def predict(transformed_image, caption_inputs):
        """Return the model logits at the first mask position of batch item 0."""
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        # np.where on the (presumably batch x seq) id matrix returns (rows, cols);
        # [1][0] is the column of the first mask token.
        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
        preds = outputs.logits[0][indices]
        return np.array(preds)

    def load_model(ckpt):
        """Load the pretrained masked-LM checkpoint ``ckpt``."""
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
    first_index = 15

    # Initialise the session state with a fixed example on first run.
    if mlm_state.mlm_image_file is None:
        _set_masked_example(mlm_state, dummy_data, first_index)

    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(mlm_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
        _set_masked_example(mlm_state, sample, 0)
    col2.write("OR")
    if col3.button("Use above URL"):
        try:
            mlm_state.mlm_image = _load_image_from_url(query1)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load an image from the URL: {err}")

    transformed_image = get_transformed_image(mlm_state.mlm_image)
    new_col1, new_col2 = st.beta_columns([5, 5])
    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="auto")
    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    mlm_state.caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    # currently_masked_token is a token id; decode it before any string use.
    masked_token = None
    if mlm_state.currently_masked_token is not None:
        masked_token = bert_tokenizer.decode([mlm_state.currently_masked_token])
    # If the caption is still the sampled one, reveal the hidden token and
    # translate the original caption; otherwise translate the user's text.
    # NOTE(review): mlm_state.caption comes from tokenizer.decode, which may
    # normalise case/spacing, so this equality can fail even when unedited.
    if masked_token is not None and mlm_state.unmasked_caption == mlm_state.caption.replace(
        "[MASK]", masked_token
    ):
        new_col2.markdown("**Masked Token**: " + masked_token)
        shown_caption = mlm_state.unmasked_caption
    else:
        shown_caption = mlm_state.caption
    english = shown_caption if mlm_state.caption_lang_id == "en" else translate(shown_caption, "en")
    new_col2.markdown(f"**English Translation**: {english}")

    # predict() locates the mask position with np.where(...)[1][0], which
    # raises IndexError when there is no mask; enforce the documented rule.
    if mlm_state.caption.count("[MASK]") != 1:
        st.error("Please write your text with exactly one [MASK] token.")
        return

    caption_inputs = get_text_attributes(mlm_state.caption)
    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        scores = predict(transformed_image, dict(caption_inputs))
    scores = softmax(scores)
    labels, values = get_top_5_predictions(scores)
    filled_sentence = mlm_state.caption.replace("[MASK]", labels[0])
    st.write("Filled Sentence: " + filled_sentence)
    st.write(f"**English Translation**: {translate(filled_sentence, 'en')}")
    fig = plotly_express_horizontal_bar_plot(values, labels)
    st.dataframe(
        pd.DataFrame(
            {
                "Tokens": labels,
                "English Translation": [translate(label, "en") for label in labels],
            }
        ).T
    )
    st.plotly_chart(fig, use_container_width=True)
|