from io import BytesIO
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from streamlit.elements import markdown
from PIL import Image
from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)
from transformers import MBart50TokenizerFast
from utils import (
    get_transformed_image,
)
import matplotlib.pyplot as plt
from mtranslate import translate
from session import _get_state

state = _get_state()


# allow_output_mutation=True: without it Streamlit re-hashes the (large)
# returned model on every cache hit to detect mutation; nothing in this
# script mutates the model, so that check is pure overhead.
@st.cache(allow_output_mutation=True)
def load_model(ckpt):
    """Load the CLIP-Vision + mBART captioning model from checkpoint ``ckpt``.

    Cached by Streamlit so the checkpoint is loaded once rather than on
    every script rerun.
    """
    return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)


tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

# Short UI language code -> mBART-50 language code.
language_mapping = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX",
    "es": "es_XX",
}

# Short UI language code -> human-readable name shown in the selectbox.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}


@st.cache(persist=True)
def generate_sequence(pixel_values, lang_code, num_beams):
    """Generate a caption for an image in the requested language.

    Args:
        pixel_values: the preprocessed image (as produced by
            ``get_transformed_image``); passed to the model as ``input_ids``.
        lang_code: short language code; must be a key of ``language_mapping``.
        num_beams: beam width used for beam search.

    Returns:
        The list of decoded strings from ``tokenizer.batch_decode``
        (presumably one caption per image in the batch).

    Note:
        Uses the module-level ``model`` (assigned later in the script via
        ``load_model``) and ``tokenizer``.
    """
    mbart_lang = language_mapping[lang_code]
    output = model.generate(
        input_ids=pixel_values,
        # Force the first generated token to be the target-language code so
        # the decoder produces text in the requested language.
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_lang],
        max_length=64,
        num_beams=num_beams,
    )
    # output[0] is the first field of the generate() result; assumed to be
    # the generated token-id sequences — TODO confirm against the model's
    # generate() output type.
    return tokenizer.batch_decode(output[0], skip_special_tokens=True)


def read_markdown(path, parent="./sections/"):
    """Return the contents of the markdown file ``path`` inside ``parent``.

    Read as UTF-8 explicitly so the result does not depend on the host's
    default locale encoding.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()


checkpoints = ["./ckpt/ckpt-22499"]  # TODO: Maybe add more checkpoints?
dummy_data = pd.read_csv("reference.tsv", sep="\t")

st.set_page_config(
    page_title="Multilingual Image Captioning",
    layout="wide",
    initial_sidebar_state="collapsed",
)

st.title("Multilingual Image Captioning")
st.write(
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
)

st.sidebar.title("Settings")
num_beams = st.sidebar.number_input(
    label="Number of Beams",
    min_value=2,
    max_value=10,
    value=4,
    step=1,
    help="Number of beams to be used in beam search.",
)

with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))

first_index = 20

# Init Session State: on the first run, show a fixed example from the dataset.
if state.image_file is None:
    state.image_file = dummy_data.loc[first_index, "image_file"]
    state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
    state.lang_id = dummy_data.loc[first_index, "lang_id"]
    image_path = os.path.join("images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

col1, col2 = st.beta_columns([6, 4])

if col2.button("Get a random example"):
    sample = dummy_data.sample(1).reset_index()
    state.image_file = sample.loc[0, "image_file"]
    state.caption = sample.loc[0, "caption"].strip("- ")
    state.lang_id = sample.loc[0, "lang_id"]
    image_path = os.path.join("images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

col2.write("OR")

uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    state.image_file = os.path.join("images", uploaded_file.name)
    # Normalise to 3-channel RGB: PNGs may be RGBA or palette-based and JPEGs
    # may be grayscale, which would otherwise yield arrays with 4, 1 or no
    # channel axis.
    state.image = np.array(Image.open(uploaded_file).convert("RGB"))

transformed_image = get_transformed_image(state.image)

# Display Image
col1.image(state.image, use_column_width="auto")

# Display Reference Caption. Only dataset examples have one; for an uploaded
# image, state.caption still holds the previous example's caption, so skip it.
if uploaded_file is None:
    col2.write("**Reference Caption**: " + state.caption)
    col2.markdown(
        f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
    )

# Select Language
options = list(code_to_name.keys())
lang_id = col2.selectbox(
    "Language",
    index=options.index(state.lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
)

# Load the (cached) captioning model; generate_sequence reads this global.
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

sequence = None
if col2.button("Generate Caption"):
    with st.spinner("Generating Sequence..."):
        sequence = generate_sequence(transformed_image, lang_id, num_beams)

# Only display non-empty generations.
if sequence and sequence[0]:
    generated = sequence[0]
    st.write("**Generated Caption**: " + generated)
    # Parenthesise the conditional so the label is always shown, and translate
    # into English explicitly (matching the reference-caption translation).
    english = generated if lang_id == "en" else translate(generated, "en")
    st.write("**English Translation**: " + english)

st.write(read_markdown("abstract.md"))
st.write(read_markdown("caveats.md"))
# st.write("# Methodology")
# st.image(
#     "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
# )
st.markdown(read_markdown("pretraining.md"))
st.write(read_markdown("challenges.md"))
st.write(read_markdown("social_impact.md"))
st.write(read_markdown("references.md"))
# st.write(read_markdown("checkpoints.md"))
st.write(read_markdown("acknowledgements.md"))