stevenkolawole commited on
Commit
d352f8e
1 Parent(s): d9b4b87

add file uploader & download functions

Browse files
Files changed (2) hide show
  1. app.py +70 -26
  2. requirements.txt +1 -1
app.py CHANGED
@@ -10,8 +10,7 @@ import tensorflow as tf
10
 
11
  def main():
12
  st.title("Interactive demo: T5 Multitasking Demo")
13
- st.write("**Demo for T5's different tasks including machine translation, \
14
- text summarization, document similarity, and grammatical correctness of sentences.**")
15
  saved_model_path = load_model_cache()
16
 
17
  # Model is loaded in st.session_state to remain stateless across reloading
@@ -33,43 +32,88 @@ def load_model_cache():
33
  snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
34
  saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
35
  return saved_model_path
36
-
37
 
38
  def dashboard(model):
39
- """"Function to display the inputs and results
40
  params:
41
  model stateless model to run inference from
42
  """
43
- st.sidebar.write("**Select the Task Type over here**")
44
  task_type = st.sidebar.radio("Task Type",
45
  [
46
- "Translate English to French",
47
- "Translate English to German",
48
- "Translate English to Romanian",
49
- "Grammatical Correctness of Sentence",
50
- "Text Summarization",
51
- "Document Similarity Score"
52
  ])
53
- if task_type.startswith("Document Similarity"): # document similarity requires two documents
54
- sentence1 = st.text_area("The first document/sentence",
55
- "I reside in the commercial capital city of Nigeria, which is Lagos.")
56
- sentence2 = st.text_area("The second document/sentence",
57
- "I live in Lagos.")
58
- sentence = sentence1 + "---" + sentence2
59
- elif task_type.startswith("Text Summarization"): # text summarization's default input should be longer
60
- sentence = st.text_area("Input sentence",
61
- "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice \
62
- to Nigerian Jollof Rice is an insult to Nigerians.")
 
 
 
 
 
 
 
 
 
63
  else:
64
- sentence = st.text_area("Input sentence",
65
- "I am Steven and I live in Lagos, Nigeria.")
 
 
 
 
 
66
 
67
  st.write("**Output Text**")
68
- with st.spinner("Please wait..."): # spinner while model is running inferences
69
  output_text = predict(task_type, sentence, model)
70
  st.write(output_text)
71
- # st.download_button("Download output text", output_text) # download_button is yet to be production-ready
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  def predict(task_type, sentence, model):
75
  """Function to parse the user inputs, run the parsed text through the
10
 
11
  def main():
12
  st.title("Interactive demo: T5 Multitasking Demo")
13
+ st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
 
14
  saved_model_path = load_model_cache()
15
 
16
  # Model is loaded in st.session_state to remain stateless across reloading
32
  snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
33
  saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
34
  return saved_model_path
35
+
36
 
37
  def dashboard(model):
38
+ """Function to display the inputs and results
39
  params:
40
  model stateless model to run inference from
41
  """
 
42
  task_type = st.sidebar.radio("Task Type",
43
  [
44
+ "Translate English to French",
45
+ "Translate English to German",
46
+ "Translate English to Romanian",
47
+ "Grammatical Correctness of Sentence",
48
+ "Text Summarization",
49
+ "Document Similarity Score"
50
  ])
51
+
52
+ default_sentence = "I am Steven and I live in Lagos, Nigeria."
53
+ text_summarization_sentence = "I don't care about those doing the comparison, but comparing \
54
+ the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
55
+ doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
56
+ doc_similarity_sentence2 = "I live in Lagos."
57
+ help_msg = "You could either type in the sentences to run inferences on or use the upload button to \
58
+ upload text files containing those sentences. The input sentence box, by default, displays sample \
59
+ texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."
60
+
61
+ if task_type.startswith("Document Similarity"): # document similarity requires two documents
62
+ uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check", accept_multiple_files=True)
63
+ if uploaded_file:
64
+ sentence1 = st.text_area("Enter first document/sentence", uploaded_file[0], help=help_msg)
65
+ sentence2 = st.text_area("Enter second document/sentence", uploaded_file[1], help=help_msg)
66
+ else:
67
+ sentence1 = st.text_area("Enter first document/sentence", doc_similarity_sentence1)
68
+ sentence2 = st.text_area("Enter second document/sentence", doc_similarity_sentence2)
69
+ sentence = sentence1 + "---" + sentence2 # to be processed like other tasks' single sentences
70
  else:
71
+ uploaded_file = upload_files(help_msg)
72
+ if uploaded_file:
73
+ sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
74
+ elif task_type.startswith("Text Summarization"): # text summarization's default input should be longer
75
+ sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
76
+ else:
77
+ sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)
78
 
79
  st.write("**Output Text**")
80
+ with st.spinner("Waiting for prediction..."): # spinner while model is running inferences
81
  output_text = predict(task_type, sentence, model)
82
  st.write(output_text)
83
+ try: # to workaround the environment's Streamlit version
84
+ st.download_button("Download output text", output_text)
85
+ except AttributeError:
86
+ st.text("File download not enabled for this Streamlit version \U0001F612")
87
+
88
+
89
+ def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
90
+ """Function to upload text files and return as string text
91
+ params:
92
+ text Display label for the upload button
93
+ accept_multiple_files params for the file_uploader function to accept more than a file
94
+ returns:
95
+ a string or a list of strings (in case of multiple files being uploaded)
96
+ """
97
+
98
+ def upload():
99
+ uploaded_files = st.file_uploader(label="Upload text files only",
100
+ type="txt", help=help_msg,
101
+ accept_multiple_files=accept_multiple_files)
102
+ if st.button("Process"):
103
+ if not uploaded_files:
104
+ st.write("**No file uploaded!**")
105
+ return None
106
+ st.write("**Upload successful!**")
107
+ if type(uploaded_files) == list:
108
+ return [f.read().decode("utf-8") for f in uploaded_files]
109
+ return uploaded_files.read().decode("utf-8")
110
+
111
+ try: # to workaround the environment's Streamlit version
112
+ with st.expander(text):
113
+ return upload()
114
+ except AttributeError:
115
+ return upload()
116
+
117
 
118
  def predict(task_type, sentence, model):
119
  """Function to parse the user inputs, run the parsed text through the
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  t5
2
  huggingface_hub
3
- streamlit
1
  t5
2
  huggingface_hub
3
+ streamlit==1.0.0