stevenkolawole commited on
Commit
580e072
1 Parent(s): f1984f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -12,8 +12,8 @@ def main():
12
  st.title("Interactive demo: T5 Multitasking Demo")
13
  st.write("**Demo for T5's different tasks including machine translation, \
14
  text summarization, document similarity, and grammatical correctness of sentences.**")
15
- load_model_cache()
16
- dashboard()
17
 
18
 
19
  @st.cache
@@ -27,11 +27,11 @@ def load_model_cache():
27
  # download the files from huggingface repo and load the model with tensorflow
28
  snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
29
  saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
30
- global model
31
  model = tf.saved_model.load(saved_model_path, ["serve"])
 
 
32
 
33
-
34
- def dashboard():
35
  task_type = st.sidebar.radio("Task Type",
36
  [
37
  "Translate English to French",
@@ -55,10 +55,10 @@ def dashboard():
55
  sentence = st.text_area("Input sentence.",
56
  "I am Steven and I live in Lagos, Nigeria")
57
  st.write("**Output Text**")
58
- st.write(predict(task_type, sentence))
59
 
60
 
61
- def predict(task_type, sentence):
62
  """Function to parse the user inputs, run the parsed text through the
63
  model and return output in a readable format.
64
  params:
@@ -73,7 +73,7 @@ def predict(task_type, sentence):
73
  "Translate English to Romanian": "Translate English to Romanian",
74
  "Grammatical Correctness of Sentence": "cola sentence",
75
  "Text Summarization": "summarize",
76
- "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)": "stsb",
77
  }
78
  question = f"{task_dict[task_type]}: {sentence}" # parsing the user inputs into a format recognized by T5
79
  # Document Similarity takes in two sentences so it has to be parsed in a separate manner
@@ -83,7 +83,7 @@ def predict(task_type, sentence):
83
  return predict_fn([question])[0].decode('utf-8')
84
 
85
 
86
- def predict_fn(x):
87
  """Function to get inferences from model on live data points.
88
  params:
89
  x input text to run get output on
 
12
  st.title("Interactive demo: T5 Multitasking Demo")
13
  st.write("**Demo for T5's different tasks including machine translation, \
14
  text summarization, document similarity, and grammatical correctness of sentences.**")
15
+ model = load_model_cache()
16
+ dashboard(model)
17
 
18
 
19
  @st.cache
 
27
  # download the files from huggingface repo and load the model with tensorflow
28
  snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
29
  saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
 
30
  model = tf.saved_model.load(saved_model_path, ["serve"])
31
+ return model
32
+
33
 
34
+ def dashboard(model):
 
35
  task_type = st.sidebar.radio("Task Type",
36
  [
37
  "Translate English to French",
 
55
  sentence = st.text_area("Input sentence.",
56
  "I am Steven and I live in Lagos, Nigeria")
57
  st.write("**Output Text**")
58
+ st.write(predict(task_type, sentence, model))
59
 
60
 
61
+ def predict(task_type, sentence, model):
62
  """Function to parse the user inputs, run the parsed text through the
63
  model and return output in a readable format.
64
  params:
 
73
  "Translate English to Romanian": "Translate English to Romanian",
74
  "Grammatical Correctness of Sentence": "cola sentence",
75
  "Text Summarization": "summarize",
76
+ "Document Similarity Score": "stsb",
77
  }
78
  question = f"{task_dict[task_type]}: {sentence}" # parsing the user inputs into a format recognized by T5
79
  # Document Similarity takes in two sentences so it has to be parsed in a separate manner
 
83
  return predict_fn([question])[0].decode('utf-8')
84
 
85
 
86
+ def predict_fn(x, model):
87
  """Function to get inferences from model on live data points.
88
  params:
89
  x input text to run get output on