Spaces:
Runtime error
Runtime error
import streamlit as st | |
from huggingface_hub import snapshot_download | |
import os # utility library | |
# libraries to load the model and serve inference | |
import tensorflow_text | |
import tensorflow as tf | |
def main():
    """Entry point: draw the page header, load the model, then render the task dashboard."""
    st.title("Interactive demo: T5 Multitasking Demo")
    description = (
        "**Demo for T5's different tasks including machine translation, "
        "text summarization, document similarity, and grammatical correctness of sentences.**"
    )
    st.write(description)
    # The model has to be available before dashboard() asks for a prediction.
    load_model_cache()
    dashboard()
def _cache_resource(func):
    """Wrap func so its return value is cached across Streamlit reruns.

    Prefers st.cache_resource (newer Streamlit releases) and falls back to the
    legacy st.cache(allow_output_mutation=True) when cache_resource is absent.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    # allow_output_mutation stops legacy st.cache from hashing the returned model
    # on every rerun to detect mutation.
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _download_and_load_model():
    """Download the T5 repo from the HuggingFace Hub and load it as a TF SavedModel."""
    cache_dir = "hfhub_cache"  # local directory where the downloaded repo files are stored
    os.makedirs(cache_dir, exist_ok=True)
    # snapshot_download returns the local folder holding the repo files. Using that
    # path instead of os.listdir(cache_dir)[0] avoids picking an arbitrary entry:
    # newer huggingface_hub versions nest files under models--<org>--<name>/snapshots/<rev>
    # and may also create housekeeping folders such as .locks in cache_dir.
    saved_model_path = snapshot_download(repo_id="stevekola/T5", cache_dir=cache_dir)
    # NOTE(review): assumes the repo root is a SavedModel directory with the "serve" tag.
    return tf.saved_model.load(saved_model_path, ["serve"])


def load_model_cache():
    """Retrieve the model from the HuggingFace Hub, load it and cache it.

    The download/load happens only on the first call; later calls reuse the
    cached model. The result is published as the module-level ``model`` that
    predict_fn() reads.
    """
    global model
    # Streamlit re-executes the script on every interaction, so the global is
    # re-assigned each run from the cached loader's return value.
    model = _download_and_load_model()
def dashboard():
    """Render the task picker and input widgets, then display the model's output.

    The selected sidebar label is passed unchanged to predict(), so these labels
    must stay in sync with the keys of predict()'s task table.
    """
    task_type = st.sidebar.radio(
        "Task Type",
        [
            "Translate English to French",
            "Translate English to German",
            "Translate English to Romanian",
            "Grammatical Correctness of Sentence",
            "Text Summarization",
            "Document Similarity Score",
        ],
    )
    if task_type.startswith("Document Similarity"):
        # st.text only renders static text and returns no user input; text_area is
        # an input widget whose second argument is the pre-filled value.
        sentence1 = st.text_area(
            "The first document/sentence.",
            "I reside in the commercial capital city of Nigeria, which is Lagos",
        )
        sentence2 = st.text_area("The second document/sentence.", "I live in Lagos")
        # predict() splits the pair back apart on the '---' separator.
        sentence = sentence1 + "---" + sentence2
    elif task_type.startswith("Text Summarization"):
        sentence = st.text_area(
            "Input sentence.",
            "I don't care about those doing the comparison, but comparing the "
            "Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians",
        )
    else:
        sentence = st.text_area("Input sentence.", "I am Steven and I live in Lagos, Nigeria")
    st.write("**Output Text**")
    st.write(predict(task_type, sentence))
def predict(task_type, sentence):
    """Parse the user inputs into a T5 prompt, run it through the model and
    return the output in a readable format.

    params:
        task_type   label of the task to run; one of the sidebar options shown by dashboard()
        sentence    text to get inference on. For "Document Similarity Score" it must
                    contain the two sentences separated by three dashes (`---`).
    returns:
        the model output decoded from UTF-8 into a human-readable string.
    raises:
        KeyError    if task_type is not a known task label.
        ValueError  if a document-similarity input has no `---` separator.
    """
    # Map each UI label to the task prefix T5 was trained with.
    task_dict = {
        "Translate English to French": "translate English to French",
        "Translate English to German": "translate English to German",
        "Translate English to Romanian": "translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        # Must match the sidebar label exactly, otherwise the lookup raises KeyError.
        "Document Similarity Score": "stsb",
    }
    prefix = task_dict[task_type]
    if task_type.startswith("Document Similarity"):
        # Document Similarity takes two sentences, so it is formatted as
        # "stsb sentence1: ... sentence2: ..." instead of "<prefix>: <text>".
        first, separator, second = sentence.partition("---")
        if not separator:
            raise ValueError(
                "Document similarity input must contain two sentences separated by '---'."
            )
        question = f"{prefix} sentence1: {first.strip()} sentence2: {second.strip()}"
    else:
        question = f"{prefix}: {sentence}"  # parsing the user inputs into a format recognized by T5
    return predict_fn([question])[0].decode('utf-8')
def predict_fn(x):
    """Run the loaded model's 'serving_default' signature on a batch of inputs.

    params:
        x   list of input strings, one per example
    returns:
        a numpy array built from the signature's 'outputs' tensor
        (predict() decodes its elements as UTF-8 bytes)
    """
    serve = model.signatures['serving_default']
    batch = tf.constant(x)
    outputs = serve(batch)['outputs']
    return outputs.numpy()
# Launch the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()