# NOTE(review): the lines below are HuggingFace Hub page residue (author,
# commit message/hash, "raw/history/blame" links, file size) accidentally
# captured with the source. Kept here as a comment so the file parses:
#   stevenkolawole's picture | create application file | f1984f7 | raw | history blame | 4.08 kB
import streamlit as st
from huggingface_hub import snapshot_download
import os # utility library
# libraries to load the model and serve inference
import tensorflow_text
import tensorflow as tf
def main():
    """Entry point: render the page header, warm the model cache, and draw
    the interactive dashboard."""
    st.title("Interactive demo: T5 Multitasking Demo")
    description = "**Demo for T5's different tasks including machine translation, \
text summarization, document similarity, and grammatical correctness of sentences.**"
    st.write(description)
    load_model_cache()
    dashboard()
@st.cache
def load_model_cache():
    """Retrieve the T5 SavedModel from the HuggingFace Hub, load it with
    TensorFlow, and cache the work via the ``st.cache`` wrapper.

    Side effects:
        Creates the local download directory if needed and binds the loaded
        model to the module-global ``model`` (read by ``predict_fn``).
    """
    CACHE_DIR = "hfhub_cache"  # where the library's fork is stored once downloaded
    # Bug fix: exists()+mkdir() is a check-then-act race (two sessions starting
    # together can both pass the check and one mkdir then fails); makedirs with
    # exist_ok=True is atomic with respect to that failure mode.
    os.makedirs(CACHE_DIR, exist_ok=True)
    # Download the files from the huggingface repo and load the model with tensorflow.
    snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
    # NOTE(review): assumes the snapshot is the only entry in CACHE_DIR —
    # os.listdir order is arbitrary if anything else lands in this directory.
    saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
    global model
    model = tf.saved_model.load(saved_model_path, ["serve"])
def dashboard():
    """Draw the sidebar task selector and the matching input widgets, then
    display the model's prediction for the chosen task."""
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])
    if task_type.startswith("Document Similarity"):
        # Bug fix: st.text renders static text and accepts only one argument,
        # so st.text(label, default) raises a TypeError; the editable
        # (label, default) widget intended here is st.text_input.
        sentence1 = st.text_input("The first document/sentence.",
                                  "I reside in the commercial capital city of Nigeria, which is Lagos")
        sentence2 = st.text_input("The second document/sentence.",
                                  "I live in Lagos")
        # predict() expects the two sentences joined by three dashes.
        sentence = sentence1 + "---" + sentence2
    elif task_type.startswith("Text Summarization"):
        sentence = st.text_area("Input sentence.",
                                "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice to \
Nigerian Jollof Rice is an insult to Nigerians")
    else:
        sentence = st.text_area("Input sentence.",
                                "I am Steven and I live in Lagos, Nigeria")
    st.write("**Output Text**")
    st.write(predict(task_type, sentence))
def predict(task_type, sentence):
    """Parse the user inputs into a T5 prompt, run it through the model,
    and return the output in a readable format.

    params:
        task_type   the task label chosen in the sidebar radio widget
        sentence    text to get inference on (for Document Similarity the
                    two sentences are joined with ``---``)
    returns:
        the model output decoded into a human-readable UTF-8 string.
    """
    # Maps each sidebar label to the task prefix T5 was trained with.
    task_dict = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        # Bug fix: this key must match the radio label exactly; the original
        # key carried an extra "(separate the 2 sentences ...)" suffix, so
        # task_dict[task_type] raised KeyError for the similarity task.
        "Document Similarity Score": "stsb",
    }
    question = f"{task_dict[task_type]}: {sentence}"  # format recognized by T5
    # Document Similarity takes two sentences, so it is parsed separately.
    if task_type.startswith("Document Similarity"):
        sentences = sentence.split('---')
        question = f"{task_dict[task_type]} sentence1: {sentences[0]} sentence2: {sentences[1]}"
    return predict_fn([question])[0].decode('utf-8')
def predict_fn(x):
    """Run the loaded SavedModel's serving signature on live data points.

    params:
        x   input text (list of strings) to get output on
    returns:
        a numpy array representing the output
    """
    serving = model.signatures['serving_default']
    result = serving(tf.constant(x))['outputs']
    return result.numpy()
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()