from io import BytesIO
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from streamlit import caching
from PIL import Image
from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)
from transformers import MBart50TokenizerFast
from utils import (
    get_transformed_image,
)
import matplotlib.pyplot as plt
from mtranslate import translate
from session import _get_state

# Per-session state object (holds the model, current image, caption, language).
state = _get_state()


@st.cache
def load_model(ckpt):
    """Load the CLIP-Vision + mBART captioning model from checkpoint ``ckpt``.

    Memoized with ``st.cache`` so repeated calls with the same checkpoint path
    reuse the already-loaded model.
    """
    return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)


# mBART-50 tokenizer: decodes generated ids and supplies the per-language
# token id used to force the output language.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

# Short language codes used in the UI / dataset -> mBART-50 language codes.
language_mapping = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX",
    "es": "es_XX",
}

# Short language codes -> human-readable names shown in the language picker.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}


@st.cache
def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p):
    """Generate a caption for an image in the requested language.

    Args:
        pixel_values: Preprocessed image (as produced by
            ``get_transformed_image`` in this app), fed to the model as
            ``input_ids``.
        lang_code: Short language code; must be a key of ``language_mapping``.
        num_beams: Beam width used by ``generate``.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.

    Returns:
        List of decoded caption strings with special tokens removed.

    Raises:
        KeyError: If ``lang_code`` is not a key of ``language_mapping``.
    """
    mbart_lang = language_mapping[lang_code]
    # Forcing the first generated token to the target-language token makes
    # mBART decode in that language.
    # NOTE(review): ``do_sample`` is not set, so Flax ``generate`` most likely
    # runs plain beam search and ignores ``temperature``/``top_p`` — confirm.
    output = state.model.generate(
        input_ids=pixel_values,
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_lang],
        max_length=64,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )
    # Element 0 of the generate output is presumably the ``sequences`` field
    # (generated token ids, one row per sample) — decode each row to text.
    return tokenizer.batch_decode(output[0], skip_special_tokens=True)


def read_markdown(path, parent="./sections/"):
    """Return the text of the markdown file ``path`` located in ``parent``."""
    # Explicit UTF-8 so the section files decode identically on every platform
    # instead of depending on the locale's default encoding.
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()


# Checkpoints available to the app; only the first one is loaded.
checkpoints = ["./ckpt/ckpt-49499"]  # TODO: Maybe add more checkpoints?
# Seeded examples shown in the demo (columns: image_file, caption, lang_id).
dummy_data = pd.read_csv("reference.tsv", sep="\t")

st.set_page_config(
    page_title="Multilingual Image Captioning",
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon="./misc/mic-logo.png",
)

st.title("Multilingual Image Captioning")
st.write(
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
)

# ---- Sidebar: generation parameters -------------------------------------
# 0.0, 0.1, ..., 1.0 built from exact divisions; np.arange accumulates float
# error (e.g. 0.30000000000000004), which would then be passed to generation.
slider_options = [step / 10 for step in range(11)]

st.sidebar.title("Generation Parameters")
num_beams = st.sidebar.number_input(
    label="Number of Beams",
    min_value=2,
    max_value=10,
    value=4,
    step=1,
    help="Number of beams to be used in beam search.",
)
temperature = st.sidebar.select_slider(
    label="Temperature",
    options=slider_options,
    value=1.0,
    help="The value used to modulate the next token probabilities.",
    format_func=lambda x: f"{x:.2f}",
)
top_p = st.sidebar.select_slider(
    label="Top-P",
    options=slider_options,
    value=1.0,
    help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.",
    format_func=lambda x: f"{x:.2f}",
)
if st.sidebar.button("Clear All Cache"):
    caching.clear_cache()

# ---- Header and article ---------------------------------------------------
image_col, intro_col = st.beta_columns([3, 8])
image_col.image("./misc/mic-logo.png", use_column_width="always")
intro_col.write(read_markdown("intro.md"))

with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))

with st.beta_expander("Article"):
    st.write(read_markdown("abstract.md"))
    st.write(read_markdown("caveats.md"))
    st.write("## Methodology")
    st.image("./misc/Multilingual-IC.png")
    st.markdown(read_markdown("pretraining.md"))
    st.write(read_markdown("challenges.md"))
    st.write(read_markdown("social_impact.md"))
    st.write(read_markdown("future_scope.md"))
    st.write(read_markdown("references.md"))
    # st.write(read_markdown("checkpoints.md"))
    st.write(read_markdown("acknowledgements.md"))

# Load the model only once per session.
if state.model is None:
    with st.spinner("Loading model..."):
        state.model = load_model(checkpoints[0])


def _load_example(frame, index):
    """Store row ``index`` of ``frame`` (image, caption, language) in session state."""
    state.image_file = frame.loc[index, "image_file"]
    # Reference captions may carry leading/trailing "-" and spaces; strip them.
    state.caption = frame.loc[index, "caption"].strip("- ")
    state.lang_id = frame.loc[index, "lang_id"]
    image_path = os.path.join("images", state.image_file)
    state.image = plt.imread(image_path)


first_index = 25

# Init Session State: seed the page with a fixed example on the first run.
if state.image_file is None:
    _load_example(dummy_data, first_index)

new_col1, new_col2 = st.beta_columns([5, 5])

if new_col2.button("Get a random example", help="Get a random example from one of the seeded examples."):
    # reset_index() makes the single sampled row addressable as index 0.
    _load_example(dummy_data.sample(1).reset_index(), 0)

transformed_image = get_transformed_image(state.image)

# Display Image
new_col1.image(state.image, use_column_width="always")

# Display Reference Caption
new_col2.write("**Reference Caption**: " + state.caption)
new_col2.markdown(
    f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
)

# Select Language
options = list(code_to_name.keys())
lang_id = new_col2.selectbox(
    "Language",
    index=options.index(state.lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
    help="The language in which caption is to be generated.",
)

sequence = [""]
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
    with st.spinner("Generating Sequence..."):
        sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p)

if sequence != [""]:
    generated_caption = sequence[0]
    st.write("**Generated Caption**: " + generated_caption)
    # Compute the translation separately so the label is kept for every
    # language (the old inline conditional dropped it for non-English), and
    # translate explicitly into English, matching the reference caption above.
    english_caption = (
        generated_caption if lang_id == "en" else translate(generated_caption, "en")
    )
    st.write("**English Translation**: " + english_caption)