stevenkolawole committed
Commit d9b4b87
1 Parent(s): 580e072

add comments to app file

Files changed (1): app.py +36 -21
app.py CHANGED
@@ -12,13 +12,18 @@ def main():
     st.title("Interactive demo: T5 Multitasking Demo")
     st.write("**Demo for T5's different tasks including machine translation, \
         text summarization, document similarity, and grammatical correctness of sentences.**")
-    model = load_model_cache()
-    dashboard(model)
+    saved_model_path = load_model_cache()
+
+    # Model is loaded in st.session_state to remain stateless across reloading
+    if 'model' not in st.session_state:
+        st.session_state.model = tf.saved_model.load(saved_model_path, ["serve"])
+
+    dashboard(st.session_state.model)
 
 
 @st.cache
 def load_model_cache():
-    """Function to retrieve the model from HuggingFace Hub, load it and cache it using st.cache wrapper
+    """Function to retrieve the model from HuggingFace Hub and cache it using st.cache wrapper
     """
     CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded
     if not os.path.exists(CACHE_DIR):
@@ -27,11 +32,15 @@ def load_model_cache():
     # download the files from huggingface repo and load the model with tensorflow
     snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
     saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
-    model = tf.saved_model.load(saved_model_path, ["serve"])
-    return model
-
+    return saved_model_path
+
 
 def dashboard(model):
+    """Function to display the inputs and results
+    params:
+        model stateless model to run inference from
+    """
+    st.sidebar.write("**Select the Task Type over here**")
     task_type = st.sidebar.radio("Task Type",
                                  [
                                      "Translate English to French",
@@ -41,22 +50,26 @@ def dashboard(model):
                                      "Text Summarization",
                                      "Document Similarity Score"
                                  ])
-    if task_type.startswith("Document Similarity"):
-        sentence1 = st.text("The first document/sentence.",
-                            "I reside in the commercial capital city of Nigeria, which is Lagos")
-        sentence2 = st.text("The second document/sentence.",
-                            "I live in Lagos")
+    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
+        sentence1 = st.text_area("The first document/sentence",
+                                 "I reside in the commercial capital city of Nigeria, which is Lagos.")
+        sentence2 = st.text_area("The second document/sentence",
+                                 "I live in Lagos.")
         sentence = sentence1 + "---" + sentence2
-    elif task_type.startswith("Text Summarization"):
-        sentence = st.text_area("Input sentence.",
-                                "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice to \
-                                Nigerian Jollof Rice is an insult to Nigerians")
+    elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
+        sentence = st.text_area("Input sentence",
+                                "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice \
+                                to Nigerian Jollof Rice is an insult to Nigerians.")
     else:
-        sentence = st.text_area("Input sentence.",
-                                "I am Steven and I live in Lagos, Nigeria")
-    st.write("**Output Text**")
-    st.write(predict(task_type, sentence, model))
+        sentence = st.text_area("Input sentence",
+                                "I am Steven and I live in Lagos, Nigeria.")
 
+    st.write("**Output Text**")
+    with st.spinner("Please wait..."):  # spinner while model is running inferences
+        output_text = predict(task_type, sentence, model)
+    st.write(output_text)
+    # st.download_button("Download output text", output_text)  # download_button is yet to be production-ready
+
 
 def predict(task_type, sentence, model):
     """Function to parse the user inputs, run the parsed text through the
@@ -64,6 +77,7 @@ def predict(task_type, sentence, model):
     params:
         task_type sentence representing the type of task to run on T5 model
         sentence sentence to get inference on
+        model model to get inferences from
     returns:
         text decoded into a human-readable format.
     """
@@ -80,13 +94,14 @@ def predict(task_type, sentence, model):
     if task_type.startswith("Document Similarity"):
         sentences = sentence.split('---')
         question = f"{task_dict[task_type]} sentence1: {sentences[0]} sentence2: {sentences[1]}"
-    return predict_fn([question])[0].decode('utf-8')
+    return predict_fn([question], model)[0].decode('utf-8')
 
 
 def predict_fn(x, model):
     """Function to get inferences from model on live data points.
     params:
-        x input text to run get output on
+        x input text to get output on
+        model model to run inferences from
     returns:
         a numpy array representing the output
     """