gchhablani committed
Commit 324f080
1 Parent(s): c2067d8

Init basic app

Files changed (1)
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
+import streamlit as st
+import pandas as pd
+import os
+import numpy as np
+from PIL import Image
+from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
+    FlaxCLIPVisionMBartForConditionalGeneration,
+)
+from transformers import MBart50TokenizerFast
+from utils import get_transformed_image
+import matplotlib.pyplot as plt
+from mtranslate import translate
+
+from session import _get_state
+
+state = _get_state()
+
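+
+# Streamlit reruns this script top to bottom on every widget interaction, so
+# the checkpoint load is cached: from_pretrained only hits the disk once.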
+@st.cache
+def load_model(ckpt):
+    return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)
+
+
+tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
+
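+# mBART-50 uses region-suffixed language codes (e.g. "en_XX", "de_DE"); map
+# the plain ISO codes shown in the UI to the tokenizer's codes.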
+language_mapping = {
+    "en": "en_XX",
+    "de": "de_DE",
+    "fr": "fr_XX",
+    "es": "es_XX",
+}
+
+code_to_name = {
+    "en": "English",
+    "fr": "French",
+    "de": "German",
+    "es": "Spanish",
+}
+
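+
+# Beam-search decode a caption. forced_bos_token_id fixes the first generated
+# token to the target-language code, which is how mBART-50 selects the output
+# language; persist=True keeps cached generations on disk across sessions.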
+@st.cache(persist=True)
+def generate_sequence(pixel_values, lang_code, num_beams):
+    lang_code = language_mapping[lang_code]
+    output_ids = model.generate(
+        input_ids=pixel_values,
+        forced_bos_token_id=tokenizer.lang_code_to_id[lang_code],
+        max_length=64,
+        num_beams=num_beams,
+    )
+    # The first field of the generate output holds the generated token ids
+    output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True)
+    return output_sequence
+
+
+def read_markdown(path, parent="./sections/"):
+    with open(os.path.join(parent, path)) as f:
+        return f.read()
+
+
+checkpoints = ["./ckpt/ckpt-22499"]  # TODO: Maybe add more checkpoints?
+dummy_data = pd.read_csv("reference.tsv", sep="\t")
+
+st.set_page_config(
+    page_title="Multilingual Image Captioning",
+    layout="wide",
+    initial_sidebar_state="collapsed",
+)
+
+st.title("Multilingual Image Captioning")
+st.write(
+    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
+)
+
+st.sidebar.title("Settings")
+num_beams = st.sidebar.number_input(
+    label="Number of Beams",
+    min_value=2,
+    max_value=10,
+    value=4,
+    step=1,
+    help="Number of beams to be used in beam search.",
+)
+
+with st.beta_expander("Usage"):
+    st.markdown(read_markdown("usage.md"))
+
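+# state (from session._get_state) is meant to persist values across Streamlit
+# reruns, so the chosen example survives widget interactions.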
+first_index = 20
+# Init session state with a default example on the first run
+if state.image_file is None:
+    state.image_file = dummy_data.loc[first_index, "image_file"]
+    state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
+    state.lang_id = dummy_data.loc[first_index, "lang_id"]
+
+    image_path = os.path.join("images", state.image_file)
+    image = plt.imread(image_path)
+    state.image = image
+
+col1, col2 = st.beta_columns([6, 4])
+
+if col2.button("Get a random example"):
+    sample = dummy_data.sample(1).reset_index()
+    state.image_file = sample.loc[0, "image_file"]
+    state.caption = sample.loc[0, "caption"].strip("- ")
+    state.lang_id = sample.loc[0, "lang_id"]
+
+    image_path = os.path.join("images", state.image_file)
+    image = plt.imread(image_path)
+    state.image = image
+
+col2.write("OR")
+
+uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
+if uploaded_file is not None:
+    state.image_file = os.path.join("images", uploaded_file.name)
+    state.image = np.array(Image.open(uploaded_file))
+
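+# get_transformed_image (defined in utils.py) presumably resizes and
+# normalizes the image into the pixel values the CLIP vision encoder expects.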
+transformed_image = get_transformed_image(state.image)
+
+# Display Image
+col1.image(state.image, use_column_width="auto")
+
+# Display Reference Caption
+col2.write("**Reference Caption**: " + state.caption)
+col2.markdown(
+    f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
+)
+
+# Select Language
+options = list(code_to_name.keys())
+lang_id = col2.selectbox(
+    "Language",
+    index=options.index(state.lang_id),
+    options=options,
+    format_func=lambda x: code_to_name[x],
+)
+
+# Load model (cached by st.cache after the first run)
+with st.spinner("Loading model..."):
+    model = load_model(checkpoints[0])
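+
+# Generate only on button click; repeated (image, language, beams) inputs are
+# served from generate_sequence's cache instead of rerunning beam search.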
+sequence = [""]
+if col2.button("Generate Caption"):
+    with st.spinner("Generating Sequence..."):
+        sequence = generate_sequence(transformed_image, lang_id, num_beams)
+
+if sequence != [""]:
+    st.write("**Generated Caption**: " + sequence[0])
+
+    st.write(
+        "**English Translation**: "
+        + (sequence[0] if lang_id == "en" else translate(sequence[0], "en"))
+    )
+
+st.write(read_markdown("abstract.md"))
+st.write(read_markdown("caveats.md"))
+# st.write("# Methodology")
+# st.image(
+#     "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning."
+# )
+st.markdown(read_markdown("pretraining.md"))
+st.write(read_markdown("challenges.md"))
+st.write(read_markdown("social_impact.md"))
+st.write(read_markdown("references.md"))
+# st.write(read_markdown("checkpoints.md"))
+st.write(read_markdown("acknowledgements.md"))