bhavitvyamalik's picture
change static image
e6c6e8b
raw
history blame
8.53 kB
from apps import article, mic
import streamlit as st
from session import _get_state
from multiapp import MultiApp
# from io import BytesIO
# from apps.utils import read_markdown
# from apps import article
# import streamlit as st
# import pandas as pd
# import os
# import numpy as np
# from streamlit import caching
# from PIL import Image
# from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
# FlaxCLIPVisionMBartForConditionalGeneration,
# )
# import matplotlib.pyplot as plt
# from mtranslate import translate
# from session import _get_state
# state = _get_state()
# @st.cache
# def load_model(ckpt):
# return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
# tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
# language_mapping = {
# "en": "en_XX",
# "de": "de_DE",
# "fr": "fr_XX",
# "es": "es_XX"
# }
# code_to_name = {
# "en": "English",
# "fr": "French",
# "de": "German",
# "es": "Spanish",
# }
# @st.cache
# def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
# lang_code = language_mapping[lang_code]
# output_ids = state.model.generate(input_ids=pixel_values, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code], max_length=max_length, num_beams=num_beams, temperature=temperature, top_p = top_p, top_k=top_k, do_sample=do_sample)
# print(output_ids)
# output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
# return output_sequence
# checkpoints = ["./ckpt/ckpt-51999"] # TODO: Maybe add more checkpoints?
# dummy_data = pd.read_csv("reference.tsv", sep="\t")
# st.sidebar.title("Generation Parameters")
# # max_length = st.sidebar.number_input("Max Length", min_value=16, max_value=128, value=64, step=1, help="The maximum length of sequence to be generated.")
# max_length = 64
# do_sample = st.sidebar.checkbox("Sample", value=False, help="Sample from the model instead of using beam search.")
# top_k = st.sidebar.number_input("Top K", min_value=10, max_value=200, value=50, step=1, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
# num_beams = st.sidebar.number_input(label="Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
# temperature = st.sidebar.select_slider(label="Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
# top_p = st.sidebar.select_slider(label = "Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
# if st.sidebar.button("Clear All Cache"):
# caching.clear_cache()
# image_col, intro_col = st.beta_columns([3, 8])
# image_col.image("./misc/mic-logo.png", use_column_width="always")
# intro_col.write(read_markdown("intro.md"))
# with st.beta_expander("Usage"):
# st.markdown(read_markdown("usage.md"))
# with st.beta_expander("Article"):
# st.write(read_markdown("abstract.md"))
# st.write("## Methodology")
# st.image(
# "./misc/Multilingual-IC.png"
# )
# st.markdown(read_markdown("pretraining.md"))
# st.write(read_markdown("challenges.md"))
# st.write(read_markdown("social_impact.md"))
# st.write(read_markdown("bias.md"))
# col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
# with col2:
# st.image("./misc/examples/female_dev_1.jpg", width=350, caption = 'German Caption: <PERSON> arbeitet an einem Computer.', use_column_width='always')
# with col3:
# st.image("./misc/examples/female_doctor.jpg", width=350, caption = 'English Caption: A portrait of <PERSON>, a doctor who specializes in health care.', use_column_width='always')
# col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
# with col2:
# st.image("./misc/examples/female_doctor_1.jpg", width=350, caption = 'Spanish Caption: El Dr. <PERSON> es un estudiante de posgrado.', use_column_width='always')
# with col3:
# st.image("./misc/examples/women_cricket.jpg", width=350, caption = 'English Caption: <PERSON> of India bats against <PERSON> of Australia during the first Twenty20 match between India and Australia at Indian Bowl Stadium in New Delhi on Friday. - PTI', use_column_width='always')
# col1, col2, col3, col4 = st.beta_columns([0.5,2.5,2.5,0.5])
# with col2:
# st.image("./misc/examples/female_dev_2.jpg", width=350, caption = "French Caption: Un écran d'ordinateur avec un écran d'ordinateur ouvert.", use_column_width='always')
# with col3:
# st.image("./misc/examples/female_biker_resized.jpg", width=350, caption = 'German Caption: <PERSON> auf dem Motorrad von <PERSON>.', use_column_width='always')
# st.write(read_markdown("future_scope.md"))
# st.write(read_markdown("references.md"))
# # st.write(read_markdown("checkpoints.md"))
# st.write(read_markdown("acknowledgements.md"))
# if state.model is None:
# with st.spinner("Loading model..."):
# state.model = load_model(checkpoints[0])
# first_index = 25
# # Init Session State
# if state.image_file is None:
# state.image_file = dummy_data.loc[first_index, "image_file"]
# state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
# state.lang_id = dummy_data.loc[first_index, "lang_id"]
# image_path = os.path.join("images", state.image_file)
# image = plt.imread(image_path)
# state.image = image
# if st.button("Get a random example", help="Get a random example from one of the seeded examples."):
# sample = dummy_data.sample(1).reset_index()
# state.image_file = sample.loc[0, "image_file"]
# state.caption = sample.loc[0, "caption"].strip("- ")
# state.lang_id = sample.loc[0, "lang_id"]
# image_path = os.path.join("images", state.image_file)
# image = plt.imread(image_path)
# state.image = image
# transformed_image = get_transformed_image(state.image)
# new_col1, new_col2 = st.beta_columns([5,5])
# # Display Image
# new_col1.image(state.image, use_column_width="always")
# # Display Reference Caption
# with new_col1.beta_expander("Reference Caption"):
# st.write("**Reference Caption**: " + state.caption)
# st.markdown(
# f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
# )
# # Select Language
# options = list(code_to_name.keys())
# lang_id = new_col2.selectbox(
# "Language",
# index=options.index(state.lang_id),
# options=options,
# format_func=lambda x: code_to_name[x],
# help="The language in which caption is to be generated."
# )
# sequence = ['']
# if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
# with st.spinner("Generating Sequence..."):
# sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)
# # print(sequence)
# if sequence!=['']:
# new_col2.write(
# "**Generated Caption**: "+sequence[0]
# )
# new_col2.write(
# "**English Translation**: "+ sequence[0] if lang_id=="en" else translate(sequence[0])
# )
def main():
    """Render the Multilingual Image Captioning Streamlit app.

    Configures the page, draws the shared title/author line and the sidebar
    description, then hands control to a ``MultiApp`` registered with the
    "Article" and "Multilingual Image Captioning" pages and runs it. After
    the selected page has run, ``state.sync()`` is called on the per-session
    state object.
    """
    # Streamlit requires ``set_page_config`` to be the first Streamlit command
    # executed in a run. Call it before anything else, including the
    # session-state helper, whose implementation is not visible from here.
    st.set_page_config(
        page_title="Multilingual Image Captioning",
        layout="wide",
        initial_sidebar_state="auto",
        page_icon="./misc/mic-logo.png",
    )
    state = _get_state()

    st.title("Multilingual Image Captioning")
    st.write(
        "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
    )

    st.sidebar.title("Multilingual Image Captioning")
    st.sidebar.image("./misc/mic-logo.png")
    st.sidebar.write("Multilingual Image Captioning addresses the challenge of caption generation for an image in a multilingual setting. Here, we fuse CLIP Vision transformer into mBART50 and perform training on translated version of Conceptual-12M dataset. Please use the radio buttons below to navigate.")

    # NOTE(review): the sidebar text says navigation is via radio buttons, so
    # MultiApp presumably renders that selector — confirm in multiapp.py.
    app = MultiApp(state)
    app.add_app("Article", article.app)
    app.add_app("Multilingual Image Captioning", mic.app)
    app.run()

    # Presumably persists/reconciles session state changed during this run;
    # see the session module for the exact semantics.
    state.sync()


if __name__ == "__main__":
    main()