from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
    bert_tokenizer,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from session import _get_state
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)

# Sample image-caption pairs shipped with the app.
_DATA_PATH = "cc12m_data/vqa_val.tsv"
_IMAGES_DIR = "cc12m_data/images_vqa"
_MLM_CHECKPOINTS = ["flax-community/clip-vision-bert-cc12m-70k"]
# Row of the sample data shown when the page is first opened.
_FIRST_INDEX = 20


def softmax(logits):
    """Return the softmax of ``logits`` along axis 0.

    The per-column maximum is subtracted before exponentiating so that large
    logits cannot overflow to ``inf`` (which would turn the result into
    ``nan``). The subtraction cancels out, so the probabilities are unchanged.
    """
    logits = np.asarray(logits)
    exps = np.exp(logits - np.max(logits, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0)


def _mask_random_token(caption):
    """Return ``caption`` with one randomly chosen token replaced by [MASK].

    Assumes ``bert_tokenizer`` is a BERT-style tokenizer whose ``encode``
    returns a list of ids wrapped in [CLS] ... [SEP] (TODO confirm). Those two
    positions are never masked and are dropped before decoding, so the text
    box only shows the caption itself. An empty caption is returned unchanged.
    """
    ids = bert_tokenizer.encode(caption)
    if len(ids) <= 2:
        # Only special tokens: there is nothing to mask.
        return caption
    ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
    return bert_tokenizer.decode(ids[1:-1])


def _load_example(state, row):
    """Fill ``state`` with the image and a freshly masked caption from ``row``.

    ``row`` is one record of the sample TSV, read through its ``image_file``,
    ``caption`` and ``lang_id`` fields.
    """
    state.image_file = row["image_file"]
    state.caption = _mask_random_token(row["caption"].strip("- "))
    state.caption_lang_id = row["lang_id"]
    state.image = plt.imread(os.path.join(_IMAGES_DIR, state.image_file))


def app():
    """Render the masked-LM demo page.

    Shows an image with an editable caption containing a [MASK] token. Runs
    ``FlaxCLIPVisionBertForMaskedLM`` on the pair and plots the five most
    likely tokens for the masked position.
    """
    state = _get_state()

    @st.cache(persist=False)
    def predict(transformed_image, caption_inputs):
        """Return the top-5 tokens for the first [MASK] and their probabilities.

        The caller must ensure ``caption_inputs["input_ids"]`` contains at
        least one [MASK] id.
        """
        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
        logits = np.asarray(outputs.logits)
        # Coordinates of every [MASK] in the batch; only the first one is used.
        mask_positions = np.where(
            np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id
        )
        # Softmax over the whole vocabulary, so the plotted scores are real
        # probabilities rather than raw logits.
        probs = softmax(logits[mask_positions][0])
        top_5_indices = np.argsort(probs)[::-1][:5]
        top_5_tokens = bert_tokenizer.convert_ids_to_tokens(top_5_indices.tolist())
        return top_5_tokens, probs[top_5_indices]

    @st.cache(persist=False)
    def load_model(ckpt):
        """Load the masked-LM checkpoint ``ckpt``."""
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    dummy_data = pd.read_csv(_DATA_PATH, sep="\t")

    # Initialise the session with a fixed example on first load.
    if state.image_file is None:
        _load_example(state, dummy_data.loc[_FIRST_INDEX])

    # Load the model once per session.
    if state.model is None:
        with st.spinner("Loading model..."):
            state.model = load_model(_MLM_CHECKPOINTS[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(state.image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    caption_inputs = dict(get_text_attributes(caption))
    has_mask = np.any(
        np.asarray(caption_inputs["input_ids"]) == bert_tokenizer.mask_token_id
    )
    if not has_mask:
        # Without a [MASK] there is no position to predict; indexing would fail.
        st.warning("The text must contain a [MASK] token to get predictions.")
        return

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        labels, values = predict(transformed_image, caption_inputs)
        fig = plotly_express_horizontal_bar_plot(values, labels)
        st.plotly_chart(fig, use_container_width=True)