inclusive-ml committed on
Commit
97c311c
1 Parent(s): 4ee896d

initial commit

Browse files
Files changed (2) hide show
  1. app.py +132 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from huggingface_hub import snapshot_download
3
+ import os # utility library
4
+ # libraries to load the model and serve inference
5
+ import tensorflow_text
6
+ import tensorflow as tf
7
def main():
    """Entry point: render the page chrome, make sure the T5 model is
    loaded once per session, then hand control to the dashboard."""
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
    model_dir = load_model_cache()
    # Keep the loaded model in st.session_state so Streamlit's script
    # re-runs do not reload it from disk on every widget interaction.
    if 'model' not in st.session_state:
        st.session_state.model = tf.saved_model.load(model_dir, ["serve"])
    dashboard(st.session_state.model)
15
@st.cache
def load_model_cache():
    """Retrieve the T5 SavedModel from the HuggingFace Hub (once, thanks to
    the st.cache wrapper) and return the local directory it was stored in.

    returns:
        path to the downloaded SavedModel snapshot directory
    """
    CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded
    # makedirs(..., exist_ok=True) avoids the check-then-create race that
    # `os.path.exists` followed by `os.mkdir` had when two sessions start
    # at the same time.
    os.makedirs(CACHE_DIR, exist_ok=True)
    # snapshot_download returns the snapshot folder path directly, which is
    # more reliable than `os.listdir(CACHE_DIR)[0]`: listdir order is
    # arbitrary and the cache dir may accumulate several entries over time.
    return snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
26
def dashboard(model):
    """Render the task selector and input widgets, run inference and
    display the output.

    params:
        model stateless model to run inference from
    """
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])
    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    text_summarization_sentence = ("I don't care about those doing the comparison, but comparing "
                                   "the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians.")
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = ("You could either type in the sentences to run inferences on or use the upload button to "
                "upload text files containing those sentences. The input sentence box, by default, displays sample "
                "texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences.")
    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check", accept_multiple_files=True)
        # Guard on the number of uploads: indexing uploaded_file[1]
        # unconditionally raised IndexError when only one file was
        # processed. Fall back to the sample sentences in that case.
        if uploaded_file and len(uploaded_file) >= 2:
            sentence1 = st.text_area("Enter first document/sentence", uploaded_file[0], help=help_msg)
            sentence2 = st.text_area("Enter second document/sentence", uploaded_file[1], help=help_msg)
        else:
            if uploaded_file:  # exactly one file made it through
                st.warning("Two documents are needed for a similarity score; using sample sentences instead.")
            sentence1 = st.text_area("Enter first document/sentence", doc_similarity_sentence1)
            sentence2 = st.text_area("Enter second document/sentence", doc_similarity_sentence2)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
        else:
            sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)
    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while model is running inferences
        output_text = predict(task_type, sentence, model)
    st.write(output_text)
    try:  # to workaround the environment's Streamlit version
        st.download_button("Download output text", output_text)
    except AttributeError:
        st.text("File download not enabled for this Streamlit version \U0001F612")
73
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Render an upload widget (inside an expander when available) and
    return the uploaded content as decoded text.

    params:
        help_msg              tooltip shown on the uploader
        text                  display label for the upload section
        accept_multiple_files params for the file_uploader function to accept more than a file
    returns:
        a string, a list of strings (when several files are uploaded),
        or None when nothing was processed
    """
    def read_uploads():
        files = st.file_uploader(label="Upload text files only",
                                 type="txt", help=help_msg,
                                 accept_multiple_files=accept_multiple_files)
        # st.button must be called either way so the widget renders.
        if not st.button("Process"):
            return None
        if not files:
            st.write("**No file uploaded!**")
            return None
        st.write("**Upload successful!**")
        if type(files) is list:
            return [uploaded.read().decode("utf-8") for uploaded in files]
        return files.read().decode("utf-8")

    # st.expander is missing on older Streamlit builds; fall back to a
    # bare uploader when the attribute lookup fails.
    try:
        with st.expander(text):
            return read_uploads()
    except AttributeError:
        return read_uploads()
98
def predict(task_type, sentence, model):
    """Build the T5 prompt for the selected task, run it through the model
    and return the decoded output.

    params:
        task_type sentence representing the type of task to run on T5 model
        sentence  sentence to get inference on (two documents joined by
                  '---' for the similarity task)
        model     model to get inferences from
    returns:
        text decoded into a human-readable format.
    """
    prefixes = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = prefixes[task_type]
    # Document Similarity takes in two sentences so it has to be parsed
    # in a separate manner from the single-sentence tasks.
    if task_type.startswith("Document Similarity"):
        parts = sentence.split('---')
        prompt = f"{prefix} sentence1: {parts[0]} sentence2: {parts[1]}"
    else:
        prompt = f"{prefix}: {sentence}"
    return predict_fn([prompt], model)[0].decode('utf-8')
122
def predict_fn(x, model):
    """Run the model's serving signature on a batch of input strings.

    params:
        x     input text to run get output on
        model model to run inferences from
    returns:
        a numpy array representing the output
    """
    serve = model.signatures['serving_default']
    outputs = serve(tf.constant(x))
    return outputs['outputs'].numpy()
131
# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ t5
2
+ huggingface_hub
3
+ streamlit==1.0.0