stevenkolawole commited on
Commit
876e501
β€’
1 Parent(s): 37a9248

Add application files

Browse files
Files changed (2) hide show
  1. app.py +89 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import snapshot_download
3
+
4
+ import os # utility library
5
+
6
+ # libraries to load the model and serve inference
7
+ import tensorflow_text
8
+ import tensorflow as tf
9
+
10
+ CACHE_DIR = "hfhub_cache" # where the library's fork would be stored once downloaded
11
+ if not os.path.exists(CACHE_DIR):
12
+ os.mkdir(CACHE_DIR)
13
+
14
+ # download the files from huggingface repo and load the model with tensorflow
15
+ snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
16
+ saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
17
+ model = tf.saved_model.load(saved_model_path, ["serve"])
18
+
19
+ title = "Interactive demo: T5 Multitasking Demo"
20
+ description = "Demo for T5's different tasks including machine translation, \
21
+ text summarization, document similarity, grammatical correctness of sentences"
22
+
23
+
24
+ def predict_fn(x):
25
+ """Function to get inferences from model on live data points.
26
+ params:
27
+ x input text to run get output on
28
+ returns:
29
+ a numpy array representing the output
30
+ """
31
+ return model.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
32
+
33
+
34
+ def predict(task_type, sentence):
35
+ """Function to parse the user inputs, run the parsed text through through the
36
+ model and return output in a readable format.
37
+ params:
38
+ task_type sentence representing the type of task to run on T5 model
39
+ sentence sentence to get inference on
40
+ returns:
41
+ text decoded into a human-readable format.
42
+ """
43
+ task_dict = {
44
+ "Translate English to French": "Translate English to French",
45
+ "Translate English to German": "Translate English to German",
46
+ "Translate English to Romanian": "Translate English to Romanian",
47
+ "Grammatical Correctness of Sentence": "cola sentence",
48
+ "Text Summarization": "summarize",
49
+ "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)": "stsb",
50
+ }
51
+ question = f"{task_dict[task_type]}: {sentence}" # parsing the user inputs into a format recognized by T5
52
+
53
+ # Document Similarity takes in two sentences so it has to be parsed in a separate manner
54
+ if task_type.startswith("Document Similarity"):
55
+ sentences = sentence.split('---')
56
+ question = f"{task_dict[task_type]} sentence1: {sentences[0]} sentence2: {sentences[1]}"
57
+
58
+ return predict_fn([question])[0].decode('utf-8')
59
+
60
+
61
+ iface = gr.Interface(fn=predict,
62
+ inputs=[gr.inputs.Radio(
63
+ choices=["Text Summarization",
64
+ "Translate English to French",
65
+ "Translate English to German",
66
+ "Translate English to Romanian",
67
+ "Grammatical correctness of sentence",
68
+ "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)"],
69
+ label="Task Type"
70
+ ),
71
+ gr.inputs.Textbox(label="Sentence")],
72
+ outputs="text",
73
+ title=title,
74
+ description=description,
75
+ examples=[
76
+ ["Translate English to French", "I am Steven and I live in Lagos, Nigeria"],
77
+ ["Translate English to German", "I am Steven and I live in Lagos, Nigeria"],
78
+ ["Translate English to Romanian", "I am Steven and I live in Lagos, Nigeria"],
79
+ ["Grammatical correctness of sentence", "I am Steven and I live in Lagos, Nigeria"],
80
+ ["Text Summarization",
81
+ "I don't care about those doing the comparison, but comparing the Ghanaian Jollof Rice to \
82
+ Nigerian Jollof Rice is an insult to Nigerians"],
83
+ ["Document Similarity Score (separate the 2 sentences with 3 dashes `---`)",
84
+ "I reside in the commercial capital city of Nigeria, which is Lagos---I live in Lagos"],
85
+ ],
86
+ theme="huggingface",
87
+ enable_queue=True)
88
+
89
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ t5
2
+ huggingface_hub
3
+ gradio