"""Streamlit demo that serves several T5 text-to-text tasks from one SavedModel."""

import os  # utility library

import streamlit as st
from huggingface_hub import snapshot_download

# libraries to load the model and serve inference
# NOTE(review): tensorflow_text is never referenced by name; presumably it is
# imported so its custom TF ops are registered before the SavedModel is
# loaded -- confirm before removing.
import tensorflow_text  # noqa: F401
import tensorflow as tf

MODEL_REPO_ID = "stevekola/T5"
CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded
DOC_SIMILARITY_TASK = "Document Similarity Score"
SENTENCE_SEPARATOR = "---"

# Single source of truth: sidebar label -> task prefix sent to the model.
# The sidebar options are generated from these keys so the two cannot drift.
TASK_PREFIXES = {
    "Translate English to French": "Translate English to French",
    "Translate English to German": "Translate English to German",
    "Translate English to Romanian": "Translate English to Romanian",
    "Grammatical Correctness of Sentence": "cola sentence",
    "Text Summarization": "summarize",
    DOC_SIMILARITY_TASK: "stsb",
}

# Loaded SavedModel; bound by main() and read by predict_fn().
model = None

# Prefer the resource cache when this Streamlit version provides it; otherwise
# fall back to the legacy decorator without output hashing, since the cached
# value is a TensorFlow object rather than plain data.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


def main():
    """Render the page header, make sure the model is loaded, then draw the UI."""
    global model

    st.title("Interactive demo: T5 Multitasking Demo")
    st.write(
        "**Demo for T5's different tasks including machine translation, "
        "text summarization, document similarity, and grammatical "
        "correctness of sentences.**"
    )
    # The cached loader may return without executing its body, so the model
    # has to travel through the return value rather than a side effect.
    model = load_model_cache()
    dashboard()


@_cache_model
def load_model_cache():
    """Download the model from the HuggingFace Hub, load it and cache it.

    Returns:
        The object returned by ``tf.saved_model.load`` for the downloaded
        snapshot, loaded with the ``serve`` tag.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    # snapshot_download returns the local folder holding the repo files, so
    # there is no need to guess it from a directory listing.
    saved_model_path = snapshot_download(repo_id=MODEL_REPO_ID, cache_dir=CACHE_DIR)
    return tf.saved_model.load(saved_model_path, tags=["serve"])


def dashboard():
    """Draw the task selector and input widgets, then show the model output."""
    task_type = st.sidebar.radio("Task Type", list(TASK_PREFIXES))

    if task_type.startswith("Document Similarity"):
        sentence1 = st.text_input(
            "The first document/sentence.",
            "I reside in the commercial capital city of Nigeria, which is Lagos",
        )
        sentence2 = st.text_input(
            "The second document/sentence.",
            "I live in Lagos",
        )
        # predict() takes one string; the two inputs are joined with the
        # separator it splits on.
        sentence = sentence1 + SENTENCE_SEPARATOR + sentence2
    elif task_type.startswith("Text Summarization"):
        sentence = st.text_area(
            "Input sentence.",
            "I don't care about those doing the comparison, but comparing "
            "the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult "
            "to Nigerians",
        )
    else:
        sentence = st.text_area(
            "Input sentence.",
            "I am Steven and I live in Lagos, Nigeria",
        )

    st.write("**Output Text**")
    st.write(predict(task_type, sentence))


def _task_prefix(task_type):
    """Return the model prefix for a task label (``KeyError`` if unknown)."""
    # Prefix match keeps the older, longer similarity label working too.
    if task_type.startswith("Document Similarity"):
        return TASK_PREFIXES[DOC_SIMILARITY_TASK]
    return TASK_PREFIXES[task_type]


def _build_question(task_type, sentence):
    """Format the user input as the single prompt string sent to the model."""
    prefix = _task_prefix(task_type)
    if task_type.startswith("Document Similarity"):
        # Document similarity takes two sentences, so it is parsed separately.
        # partition never raises: a missing separator leaves the second
        # sentence empty, and text after a second separator is not dropped.
        first, _, second = sentence.partition(SENTENCE_SEPARATOR)
        return f"{prefix} sentence1: {first} sentence2: {second}"
    return f"{prefix}: {sentence}"


def predict(task_type, sentence):
    """Parse the user inputs, run them through the model and decode the output.

    Args:
        task_type: Label of the task to run; one of the sidebar options.
        sentence: Text to get inference on. For document similarity, the two
            sentences joined by ``---``.

    Returns:
        The first model output, decoded from UTF-8 into a string.

    Raises:
        KeyError: If ``task_type`` is not a known task label.
    """
    question = _build_question(task_type, sentence)
    return predict_fn([question])[0].decode("utf-8")


def predict_fn(x):
    """Get inferences from the model on live data points.

    Args:
        x: List of input strings to run through the model.

    Returns:
        The ``outputs`` entry of the ``serving_default`` signature result,
        converted with ``.numpy()``.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; call load_model_cache() first.")
    serving_fn = model.signatures["serving_default"]
    return serving_fn(tf.constant(x))["outputs"].numpy()


if __name__ == "__main__":
    main()