stevenkolawole commited on
Commit
f1984f7
1 Parent(s): 0a91998

create application file

Browse files
Files changed (2) hide show
  1. app.py +97 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from huggingface_hub import snapshot_download
3
+
4
+ import os # utility library
5
+
6
+ # libraries to load the model and serve inference
7
+ import tensorflow_text
8
+ import tensorflow as tf
9
+
10
+
11
+ def main():
12
+ st.title("Interactive demo: T5 Multitasking Demo")
13
+ st.write("**Demo for T5's different tasks including machine translation, \
14
+ text summarization, document similarity, and grammatical correctness of sentences.**")
15
+ load_model_cache()
16
+ dashboard()
17
+
18
+
19
+ @st.cache
20
+ def load_model_cache():
21
+ """Function to retrieve the model from HuggingFace Hub, load it and cache it using st.cache wrapper
22
+ """
23
+ CACHE_DIR = "hfhub_cache" # where the library's fork would be stored once downloaded
24
+ if not os.path.exists(CACHE_DIR):
25
+ os.mkdir(CACHE_DIR)
26
+
27
+ # download the files from huggingface repo and load the model with tensorflow
28
+ snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
29
+ saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
30
+ global model
31
+ model = tf.saved_model.load(saved_model_path, ["serve"])
32
+
33
+
34
+ def dashboard():
35
+ task_type = st.sidebar.radio("Task Type",
36
+ [
37
+ "Translate English to French",
38
+ "Translate English to German",
39
+ "Translate English to Romanian",
40
+ "Grammatical Correctness of Sentence",
41
+ "Text Summarization",
42
+ "Document Similarity Score"
43
+ ])
44
+ if task_type.startswith("Document Similarity"):
45
+ sentence1 = st.text("The first document/sentence.",
46
+ "I reside in the commercial capital city of Nigeria, which is Lagos")
47
+ sentence2 = st.text("The second document/sentence.",
48
+ "I live in Lagos")
49
+ sentence = sentence1 + "---" + sentence2
50
+ elif task_type.startswith("Text Summarization"):
51
+ sentence = st.text_area("Input sentence.",
52
+ "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice to \
53
+ Nigerian Jollof Rice is an insult to Nigerians")
54
+ else:
55
+ sentence = st.text_area("Input sentence.",
56
+ "I am Steven and I live in Lagos, Nigeria")
57
+ st.write("**Output Text**")
58
+ st.write(predict(task_type, sentence))
59
+
60
+
61
+ def predict(task_type, sentence):
62
+ """Function to parse the user inputs, run the parsed text through the
63
+ model and return output in a readable format.
64
+ params:
65
+ task_type sentence representing the type of task to run on T5 model
66
+ sentence sentence to get inference on
67
+ returns:
68
+ text decoded into a human-readable format.
69
+ """
70
+ task_dict = {
71
+ "Translate English to French": "Translate English to French",
72
+ "Translate English to German": "Translate English to German",
73
+ "Translate English to Romanian": "Translate English to Romanian",
74
+ "Grammatical Correctness of Sentence": "cola sentence",
75
+ "Text Summarization": "summarize",
76
+ "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)": "stsb",
77
+ }
78
+ question = f"{task_dict[task_type]}: {sentence}" # parsing the user inputs into a format recognized by T5
79
+ # Document Similarity takes in two sentences so it has to be parsed in a separate manner
80
+ if task_type.startswith("Document Similarity"):
81
+ sentences = sentence.split('---')
82
+ question = f"{task_dict[task_type]} sentence1: {sentences[0]} sentence2: {sentences[1]}"
83
+ return predict_fn([question])[0].decode('utf-8')
84
+
85
+
86
+ def predict_fn(x):
87
+ """Function to get inferences from model on live data points.
88
+ params:
89
+ x input text to run get output on
90
+ returns:
91
+ a numpy array representing the output
92
+ """
93
+ return model.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ t5
2
+ huggingface_hub
3
+ streamlit