# Author: stevenkolawole
# Commit d352f8e: "add file uploader & download functions" (file size 7.17 kB)
import streamlit as st
from huggingface_hub import snapshot_download
import os # utility library
# libraries to load the model and serve inference
import tensorflow_text
import tensorflow as tf
def main():
    """Render the demo page: header, sidebar image, model loading and the task dashboard."""
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
    model_dir = load_model_cache()
    # Keep the loaded SavedModel in session state so later reruns of this script
    # (one per widget interaction) reuse it instead of loading it again.
    if "model" not in st.session_state:
        st.session_state["model"] = tf.saved_model.load(model_dir, tags=["serve"])
    dashboard(st.session_state["model"])
@st.cache
def load_model_cache():
    """Download the "stevekola/T5" repository from the Hugging Face Hub and return its local path.

    Wrapped in st.cache so later Streamlit reruns reuse the returned path instead
    of repeating the download.

    returns:
        path of the local directory holding the downloaded repository snapshot,
        which main() passes to tf.saved_model.load.
    """
    cache_dir = "hfhub_cache"  # local directory the Hub snapshot is stored under
    # exist_ok avoids the separate exists-check followed by mkdir.
    os.makedirs(cache_dir, exist_ok=True)
    # snapshot_download returns the folder it placed the repository in. Relying on
    # that is safer than os.listdir(cache_dir)[0]: listdir order is arbitrary, and
    # the cache directory can contain more than one entry (for example bookkeeping
    # folders created by newer huggingface_hub versions).
    saved_model_path = snapshot_download(repo_id="stevekola/T5", cache_dir=cache_dir)
    return saved_model_path
def dashboard(model):
    """Display the task picker, the input widgets and the model output.

    params:
        model   model loaded in main(); passed unchanged to predict()
    """
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])
    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    text_summarization_sentence = ("I don't care about those doing the comparison, but comparing "
                                   "the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians.")
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = ("You could either type in the sentences to run inferences on or use the upload button to "
                "upload text files containing those sentences. The input sentence box, by default, displays sample "
                "texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences.")

    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check",
                                     accept_multiple_files=True)
        # Start from the sample pair and overwrite it with whatever was uploaded.
        # Indexing uploaded_file[1] directly crashed with IndexError when only
        # one file was uploaded.
        documents = [doc_similarity_sentence1, doc_similarity_sentence2]
        uploaded = list(uploaded_file or [])
        if uploaded and len(uploaded) != 2:
            st.warning("Please upload exactly 2 documents; only the first two are used "
                       "and sample text fills in any missing one.")
        for index, content in enumerate(uploaded[:2]):
            documents[index] = content
        sentence1 = st.text_area("Enter first document/sentence", documents[0], help=help_msg)
        sentence2 = st.text_area("Enter second document/sentence", documents[1], help=help_msg)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
        else:
            sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)

    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while the model runs inference
        output_text = predict(task_type, sentence, model)
    st.write(output_text)
    # Older Streamlit versions have no download_button. Check for it explicitly
    # rather than catching AttributeError around the call, which could also hide
    # an unrelated AttributeError raised from inside the call.
    download_button = getattr(st, "download_button", None)
    if download_button is None:
        st.text("File download not enabled for this Streamlit version \U0001F612")
    else:
        download_button("Download output text", output_text)
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Show a .txt file uploader and return the uploaded text once "Process" is clicked.

    params:
        help_msg                tooltip text shown on the file uploader
        text                    label of the expander that wraps the uploader
        accept_multiple_files   forwarded to st.file_uploader to allow more than one file
    returns:
        None until the "Process" button is clicked with at least one file selected,
        or when a file is not valid UTF-8. Otherwise the decoded file content as a
        string, or a list of strings when multiple files are accepted.
    """
    def upload():
        uploaded_files = st.file_uploader(label="Upload text files only",
                                          type="txt", help=help_msg,
                                          accept_multiple_files=accept_multiple_files)
        if not st.button("Process"):
            return None
        if not uploaded_files:
            st.write("**No file uploaded!**")
            return None
        is_multiple = isinstance(uploaded_files, list)
        files = uploaded_files if is_multiple else [uploaded_files]
        try:
            # getvalue() returns the whole buffer regardless of the current read position.
            contents = [f.getvalue().decode("utf-8") for f in files]
        except UnicodeDecodeError:
            # Report the problem instead of crashing the whole app on a non-UTF-8 file.
            st.write("**Upload failed: files must be UTF-8 encoded text.**")
            return None
        st.write("**Upload successful!**")
        return contents if is_multiple else contents[0]

    # Older Streamlit releases provide the expander as st.beta_expander, or not at all.
    # Resolve it up front so an AttributeError raised inside upload() is not mistaken
    # for a missing expander, which would render the widgets a second time.
    expander = getattr(st, "expander", None) or getattr(st, "beta_expander", None)
    if expander is None:
        return upload()
    with expander(text):
        return upload()
def predict(task_type, sentence, model):
    """Turn the user's task and input into a T5 prompt, run it and decode the answer.

    params:
        task_type   one of the task labels shown in the sidebar radio
        sentence    text to run inference on. For "Document Similarity Score" this is
                    the two documents joined by "---", as dashboard() builds it.
        model       model to get inferences from (forwarded to predict_fn)
    returns:
        the model's first output decoded from UTF-8 into a string.
    raises:
        KeyError if task_type is not one of the known task labels.
    """
    # T5 task prefixes. The translation prefixes follow T5's lowercase
    # "translate English to X:" form. T5's vocabulary is cased, so the previous
    # capitalized "Translate" did not match the prefix the model was trained on.
    task_dict = {
        "Translate English to French": "translate English to French",
        "Translate English to German": "translate English to German",
        "Translate English to Romanian": "translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = task_dict[task_type]
    if task_type.startswith("Document Similarity"):
        # Document Similarity takes two sentences, so it uses the stsb two-field format.
        # partition() splits on the first "---" only: the second document is never
        # truncated, and a missing separator gives an empty second sentence instead
        # of an IndexError.
        first, _, second = sentence.partition("---")
        question = f"{prefix} sentence1: {first} sentence2: {second}"
    else:
        question = f"{prefix}: {sentence}"  # "<task prefix>: <input>" format recognized by T5
    return predict_fn([question], model)[0].decode("utf-8")
def predict_fn(x, model):
    """Run the model's 'serving_default' signature on a batch of input strings.

    params:
        x       list of input strings to run inference on
        model   loaded model exposing a 'serving_default' signature
    returns:
        numpy array from the signature's 'outputs' entry. predict() calls
        .decode('utf-8') on its elements, so presumably an array of byte strings.
    """
    serving_fn = model.signatures['serving_default']
    batch = tf.constant(x)
    result = serving_fn(batch)
    return result['outputs'].numpy()
if __name__ == "__main__":
    # Build the app only when this file is executed as the main script, not when imported.
    main()