File size: 7,168 Bytes
f1984f7
 
 
 
 
 
 
 
 
 
 
 
d352f8e
d9b4b87
 
 
 
 
 
 
f1984f7
 
 
 
d9b4b87
f1984f7
 
 
 
 
 
 
 
d9b4b87
d352f8e
f1984f7
580e072
d352f8e
d9b4b87
 
 
f1984f7
 
d352f8e
 
 
 
 
 
f1984f7
d352f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1984f7
d352f8e
 
 
 
 
 
 
f1984f7
d9b4b87
d352f8e
d9b4b87
 
d352f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1984f7
580e072
f1984f7
 
 
 
 
d9b4b87
f1984f7
 
 
 
 
 
 
 
 
580e072
f1984f7
 
 
 
 
 
d9b4b87
f1984f7
 
580e072
f1984f7
 
d9b4b87
 
f1984f7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
from huggingface_hub import snapshot_download

import os  # utility library

# libraries to load the model and serve inference
import tensorflow_text
import tensorflow as tf


def main():
    """Entry point: render the page chrome, load the model once, and show the UI."""
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")

    model_dir = load_model_cache()

    # Stash the loaded model in session_state so Streamlit reruns (which
    # re-execute this script top to bottom) don't reload it from disk.
    if 'model' not in st.session_state:
        st.session_state.model = tf.saved_model.load(model_dir, ["serve"])

    dashboard(st.session_state.model)


@st.cache
def load_model_cache():
    """Download the T5 model from the HuggingFace Hub and return its local path.

    Wrapped in st.cache so the (slow) download runs only once per Streamlit
    process; later reruns reuse the cached return value.

    returns:
        path to the downloaded SavedModel directory
    """
    CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # `if not exists: mkdir` and is a no-op when the directory is present.
    os.makedirs(CACHE_DIR, exist_ok=True)

    # download the files from huggingface repo and load the model with tensorflow
    snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
    # NOTE(review): assumes the cache dir contains exactly one snapshot
    # subdirectory for this repo — confirm if other repos share CACHE_DIR.
    saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
    return saved_model_path


def dashboard(model):
    """Function to display the inputs and results
    params:
        model   stateless model to run inference from
    """
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])

    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    text_summarization_sentence = "I don't care about those doing the comparison, but comparing \
        the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = "You could either type in the sentences to run inferences on or use the upload button to \
        upload text files containing those sentences. The input sentence box, by default, displays sample \
        texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."

    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check", accept_multiple_files=True)
        # Guard on length: with a single uploaded file, indexing [1] would
        # raise IndexError; fall back to the sample sentences instead.
        if uploaded_file and len(uploaded_file) >= 2:
            sentence1 = st.text_area("Enter first document/sentence", uploaded_file[0], help=help_msg)
            sentence2 = st.text_area("Enter second document/sentence", uploaded_file[1], help=help_msg)
        else:
            sentence1 = st.text_area("Enter first document/sentence", doc_similarity_sentence1)
            sentence2 = st.text_area("Enter second document/sentence", doc_similarity_sentence2)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
        else:
            sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)

    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while model is running inferences
        output_text = predict(task_type, sentence, model)
        st.write(output_text)
        try:  # to workaround the environment's Streamlit version
            st.download_button("Download output text", output_text)
        except AttributeError:  # older Streamlit has no download_button
            st.text("File download not enabled for this Streamlit version \U0001F612")


def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Function to upload text files and return as string text
    params:
        help_msg                tooltip text shown on the file-uploader widget
        text                    Display label for the upload button
        accept_multiple_files   params for the file_uploader function to accept more than a file
    returns:
        a string or a list of strings (in case of multiple files being uploaded),
        or None when no file was uploaded or "Process" wasn't clicked
    """

    def upload():
        # Inner helper so the widget can be rendered either inside an
        # expander (newer Streamlit) or bare (older versions) — see below.
        uploaded_files = st.file_uploader(label="Upload text files only", 
                                          type="txt", help=help_msg,
                                          accept_multiple_files=accept_multiple_files)
        if st.button("Process"):
            if not uploaded_files:
                st.write("**No file uploaded!**")
                return None
            st.write("**Upload successful!**")
            # isinstance is the idiomatic type check (type(...) == list is fragile)
            if isinstance(uploaded_files, list):
                return [f.read().decode("utf-8") for f in uploaded_files]
            return uploaded_files.read().decode("utf-8")

    try:  # to workaround the environment's Streamlit version
        with st.expander(text):
            return upload()
    except AttributeError:  # st.expander missing on older Streamlit
        return upload()


def predict(task_type, sentence, model):
    """Function to parse the user inputs, run the parsed text through the
    model and return output in a readable format.
    params:
        task_type   sentence representing the type of task to run on T5 model
        sentence    sentence to get inference on
        model       model to get inferences from
    returns:
        text decoded into a human-readable format.
    """
    # Maps the UI label to the prompt prefix T5 was trained with.
    task_dict = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = task_dict[task_type]

    if task_type.startswith("Document Similarity"):
        # The two documents arrive joined by '---' (see dashboard); split them
        # back out and format them the way the stsb task expects.
        parts = sentence.split('---')
        question = f"{prefix} sentence1: {parts[0]} sentence2: {parts[1]}"
    else:
        question = f"{prefix}: {sentence}"  # parsing the user inputs into a format recognized by T5

    return predict_fn([question], model)[0].decode('utf-8')


def predict_fn(x, model):
    """Function to get inferences from model on live data points.
    params:
        x       input text to run get output on
        model   model to run inferences from
    returns:
        a numpy array representing the output
    """
    serving = model.signatures['serving_default']
    outputs = serving(tf.constant(x))['outputs']
    return outputs.numpy()


# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()