"""Streamlit page for multilingual image captioning with a CLIP-ViT encoder and an mBART-50 decoder."""

import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import streamlit as st
from mtranslate import translate
from PIL import Image
from streamlit import caching

from .model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)
from .utils import (
    code_to_name,
    get_transformed_image,
    language_mapping,
    read_markdown,
    tokenizer,
    voicerss_tts,
)


def app(state):
    mic_state = state
    with st.beta_expander("Usage"):
        st.write(read_markdown("usage.md"))

    st.write("\n")
    st.write(read_markdown("intro.md"))

    # st.sidebar.title("Generation Parameters")
    max_length = 64

    with st.sidebar.beta_expander("Generation Parameters"):
        do_sample = st.checkbox(
            "Sample",
            value=False,
            help="Sample from the model instead of using beam search.",
        )
        top_k = st.number_input(
            "Top K",
            min_value=10,
            max_value=200,
            value=50,
            step=1,
            help="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
        )
        num_beams = st.number_input(
            label="Number of Beams",
            min_value=2,
            max_value=10,
            value=4,
            step=1,
            help="Number of beams to be used in beam search.",
        )
        temperature = st.select_slider(
            label="Temperature",
            options=list(np.arange(0.0, 1.1, step=0.1)),
            value=1.0,
            help="The value used to modulate the next-token probabilities.",
            format_func=lambda x: f"{x:.2f}",
        )
        top_p = st.select_slider(
            label="Top-P",
            options=list(np.arange(0.0, 1.1, step=0.1)),
            value=1.0,
            help="Nucleus sampling: if set to a float < 1, only the most probable tokens whose probabilities add up to `top_p` or higher are kept for generation.",
            format_func=lambda x: f"{x:.2f}",
        )
        if st.button("Clear All Cache"):
            caching.clear_cache()

    @st.cache
    def load_model(ckpt):
        return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)

    @st.cache
    def generate_sequence(
        pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length
    ):
        # Force the decoder to start with the target language's BOS token so that
        # mBART-50 generates the caption in the requested language.
        lang_code = language_mapping[lang_code]
        output_ids = mic_state.model.generate(
            input_ids=pixel_values,
            forced_bos_token_id=tokenizer.lang_code_to_id[lang_code],
            max_length=max_length,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
        )
        print(output_ids)
        output_sequence = tokenizer.batch_decode(
            output_ids[0], skip_special_tokens=True, max_length=max_length
        )
        return output_sequence

    mic_checkpoints = [
        "flax-community/clip-vit-base-patch32_mbart-large-50"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("reference.tsv", sep="\t")

    first_index = 25
    # Initialise session state with a fixed example on the first run.
    if mic_state.image_file is None:
        mic_state.image_file = dummy_data.loc[first_index, "image_file"]
        mic_state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
        mic_state.lang_id = dummy_data.loc[first_index, "lang_id"]

        image_path = os.path.join("images", mic_state.image_file)
        image = plt.imread(image_path)
        mic_state.image = image

    if mic_state.model is None:
        # Load the model once and keep it in session state.
        with st.spinner("Loading model..."):
            mic_state.model = load_model(mic_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000397133.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
        mic_state.image_file = sample.loc[0, "image_file"]
        mic_state.caption = sample.loc[0, "caption"].strip("- ")
        mic_state.lang_id = sample.loc[0, "lang_id"]

        image_path = os.path.join("images", mic_state.image_file)
        image = plt.imread(image_path)
        mic_state.image = image

    col2.write("OR")

    if col3.button("Use above URL"):
        image_data = requests.get(query1, stream=True).raw
        image = np.asarray(Image.open(image_data))
        mic_state.image = image

    transformed_image = get_transformed_image(mic_state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display the image.
    new_col1.image(mic_state.image, use_column_width="always")

    # Display the reference caption and its English translation.
    with new_col1.beta_expander("Reference Caption"):
        st.write("**Reference Caption**: " + mic_state.caption)
        st.markdown(
            f"""**English Translation**: {mic_state.caption if mic_state.lang_id == "en" else translate(mic_state.caption, 'en')}"""
        )

    # Select the target language.
    options = list(code_to_name.keys())
    lang_id = new_col2.selectbox(
        "Language",
        index=options.index(mic_state.lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language in which the caption is to be generated.",
    )

    sequence = [""]
    if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
        with st.spinner("Generating Sequence... This might take some time, you can read our Article meanwhile!"):
            sequence = generate_sequence(
                transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length
            )
    # print(sequence)

    if sequence != [""]:
        new_col2.write("**Generated Caption**: " + sequence[0])

        new_col2.write(
            "**English Translation**: "
            + (sequence[0] if lang_id == "en" else translate(sequence[0], "en"))
        )

        with new_col2:
            try:
                clean_text = re.sub(r"[^A-Za-z0-9 ]+", "", sequence[0])
                # st.write("**Cleaned Text**: ", clean_text)
                audio_bytes = voicerss_tts(clean_text, lang_id)
                st.markdown("**Audio for the generated caption**")
                st.audio(audio_bytes)
            except Exception:
                st.info("Unable to generate audio. Please try again in some time.")
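

# ---------------------------------------------------------------------------
# Example wiring (a minimal sketch, not part of this module): `app(state)` is
# meant to be called from a top-level multipage script with a session-state
# object exposing the attributes used above. The `SimpleNamespace` stand-in
# and the `apps.mic` import path below are assumptions for illustration only.
#
#     from types import SimpleNamespace
#     from apps import mic
#
#     state = SimpleNamespace(
#         image_file=None, caption=None, lang_id=None, image=None, model=None
#     )
#     mic.app(state)
# ---------------------------------------------------------------------------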