davidmezzetti committed on
Commit
f37320f
1 Parent(s): c50b07e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +394 -2
app.py CHANGED
@@ -1,4 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build txtai workflows.
3
+
4
+ Based on this example: https://github.com/neuml/txtai/blob/master/examples/workflows.py
5
+ """
6
+
7
+ import os
8
+ import re
9
+
10
+ import yaml
11
+
12
+ import pandas as pd
13
  import streamlit as st
14
 
15
+ from txtai.embeddings import Documents, Embeddings
16
+ from txtai.pipeline import Segmentation, Summary, Tabular, Textractor, Transcription, Translation
17
+ from txtai.workflow import ServiceTask, Task, UrlTask, Workflow
18
+
19
+
20
class Application:
    """
    Streamlit application.

    Reads component settings from sidebar widgets, builds a txtai workflow from them,
    exports an equivalent YAML configuration and runs input data through the workflow.
    """

    def __init__(self):
        """
        Creates a new Streamlit application.
        """

        # Component options
        self.components = {}

        # Defined pipelines
        self.pipelines = {}

        # Current workflow
        self.workflow = []

        # Embeddings index params
        self.embeddings = None
        self.documents = None
        self.data = None

    def number(self, label):
        """
        Extracts a number from a text input field.

        Args:
            label: label to use for text input field

        Returns:
            numeric input, None when the field is empty
        """

        value = st.sidebar.text_input(label)
        return int(value) if value else None

    def split(self, text):
        """
        Splits text on commas and returns a list.

        Args:
            text: input text

        Returns:
            list of whitespace-stripped tokens
        """

        return [x.strip() for x in text.split(",")]

    def options(self, component):
        """
        Extracts component settings into a component configuration dict.

        Renders the sidebar widgets for the component and stores each widget value
        under the option name consumed later by build() and yaml().

        Args:
            component: component type

        Returns:
            dict with component settings, always includes a "type" key
        """

        options = {"type": component}

        st.sidebar.markdown("---")

        if component == "embeddings":
            st.sidebar.markdown("**Embeddings Index** \n*Index workflow output*")
            options["path"] = st.sidebar.text_area("Embeddings model path", value="sentence-transformers/nli-mpnet-base-v2")
            options["upsert"] = st.sidebar.checkbox("Upsert")

        elif component == "summary":
            st.sidebar.markdown("**Summary** \n*Abstractive text summarization*")
            options["path"] = st.sidebar.text_input("Model", value="sshleifer/distilbart-cnn-12-6")
            options["minlength"] = self.number("Min length")
            options["maxlength"] = self.number("Max length")

        elif component in ("segment", "textract"):
            if component == "segment":
                st.sidebar.markdown("**Segment** \n*Split text into semantic units*")
            else:
                st.sidebar.markdown("**Textractor** \n*Extract text from documents*")

            options["sentences"] = st.sidebar.checkbox("Split sentences")
            options["lines"] = st.sidebar.checkbox("Split lines")
            options["paragraphs"] = st.sidebar.checkbox("Split paragraphs")
            options["join"] = st.sidebar.checkbox("Join tokenized")
            options["minlength"] = self.number("Min section length")

        elif component == "service":
            options["url"] = st.sidebar.text_input("URL")
            options["method"] = st.sidebar.selectbox("Method", ["get", "post"], index=0)
            options["params"] = st.sidebar.text_input("URL parameters")
            options["batch"] = st.sidebar.checkbox("Run as batch", value=True)
            options["extract"] = st.sidebar.text_input("Subsection(s) to extract")

            # Comma separated inputs are converted to their structured form only when set
            if options["params"]:
                options["params"] = {key: None for key in self.split(options["params"])}
            if options["extract"]:
                options["extract"] = self.split(options["extract"])

        elif component == "tabular":
            options["idcolumn"] = st.sidebar.text_input("Id columns")
            options["textcolumns"] = st.sidebar.text_input("Text columns")
            if options["textcolumns"]:
                options["textcolumns"] = self.split(options["textcolumns"])

        elif component == "transcribe":
            st.sidebar.markdown("**Transcribe** \n*Transcribe audio to text*")
            options["path"] = st.sidebar.text_input("Model", value="facebook/wav2vec2-base-960h")

        elif component == "translate":
            st.sidebar.markdown("**Translate** \n*Machine translation*")
            options["target"] = st.sidebar.text_input("Target language code", value="en")

        return options

    def build(self, components):
        """
        Builds a workflow using components.

        Resets the application state, creates one task per component and stores the
        resulting Workflow on self.workflow.

        Args:
            components: list of components to add to workflow
        """

        # Clear application
        self.__init__()

        # pylint: disable=W0108
        tasks = []
        for component in components:
            # Copy so the caller's dict is not modified by the pops below
            component = dict(component)
            wtype = component.pop("type")
            self.components[wtype] = component

            if wtype == "embeddings":
                self.embeddings = Embeddings({**component})
                self.documents = Documents()
                tasks.append(Task(self.documents.add, unpack=False))

            elif wtype == "segment":
                self.pipelines[wtype] = Segmentation(**self.components["segment"])
                tasks.append(Task(self.pipelines["segment"]))

            elif wtype == "service":
                tasks.append(ServiceTask(**self.components["service"]))

            elif wtype == "summary":
                # component is the same dict as self.components["summary"], popping "path"
                # leaves only the call-time arguments for the lambda below
                self.pipelines[wtype] = Summary(component.pop("path"))
                tasks.append(Task(lambda x: self.pipelines["summary"](x, **self.components["summary"])))

            elif wtype == "tabular":
                self.pipelines[wtype] = Tabular(**self.components["tabular"])
                tasks.append(Task(self.pipelines["tabular"]))

            elif wtype == "textract":
                self.pipelines[wtype] = Textractor(**self.components["textract"])
                tasks.append(UrlTask(self.pipelines["textract"]))

            elif wtype == "transcribe":
                self.pipelines[wtype] = Transcription(component.pop("path"))
                tasks.append(UrlTask(self.pipelines["transcribe"], r".\.wav$"))

            elif wtype == "translate":
                self.pipelines[wtype] = Translation()
                tasks.append(Task(lambda x: self.pipelines["translate"](x, **self.components["translate"])))

        self.workflow = Workflow(tasks)

    def yaml(self, components):
        """
        Builds a yaml string for components.

        Args:
            components: list of components to export to YAML

        Returns:
            (workflow name, YAML string)
        """

        # pylint: disable=W0108
        data = {}
        tasks = []
        name = None

        for component in components:
            component = dict(component)
            name = wtype = component.pop("type")

            if wtype == "summary":
                data["summary"] = {"path": component.pop("path")}
                tasks.append({"action": "summary"})

            elif wtype == "segment":
                data["segmentation"] = component
                tasks.append({"action": "segmentation"})

            elif wtype == "service":
                config = dict(**component)
                config["task"] = "service"
                tasks.append(config)

            elif wtype == "tabular":
                data["tabular"] = component
                tasks.append({"action": "tabular"})

            elif wtype == "textract":
                data["textractor"] = component
                tasks.append({"action": "textractor", "task": "url"})

            elif wtype == "transcribe":
                data["transcription"] = {"path": component.pop("path")}
                tasks.append({"action": "transcription", "task": "url"})

            elif wtype == "translate":
                data["translation"] = {}
                tasks.append({"action": "translation", "args": list(component.values())})

            elif wtype == "embeddings":
                # options() does not set an "index" entry, pop with defaults so a missing
                # key does not raise KeyError on every rerun
                index = component.pop("index", None)
                upsert = component.pop("upsert", False)

                data["embeddings"] = component
                data["writable"] = True

                if index:
                    data["path"] = index

                name = "index"
                tasks.append({"action": "upsert" if upsert else "index"})

        # Add in workflow
        data["workflow"] = {name: {"tasks": tasks}}

        return (name, yaml.dump(data))

    def find(self, key):
        """
        Lookup record from cached data by uid key.

        Args:
            key: uid to search for

        Returns:
            text for first matching uid, None when no cached record matches
        """

        # Stop at the first match instead of building a list of all matches
        return next((text for uid, text, _ in self.data if uid == key), None)

    def process(self, data):
        """
        Processes the current application action.

        Runs data through the current workflow. When an embeddings component is
        configured, workflow output is indexed and a query box is shown.

        Args:
            data: input data
        """

        if data and self.workflow:
            # Build tuples for embedding index
            if self.documents:
                data = [(x, element, None) for x, element in enumerate(data)]

            # Process workflow
            for result in self.workflow(data):
                if not self.documents:
                    st.write(result)

            # Build embeddings index
            if self.documents:
                # Cache data
                self.data = list(self.documents)

                with st.spinner("Building embedding index...."):
                    self.embeddings.index(self.documents)
                    self.documents.close()

                # Clear workflow
                self.documents, self.pipelines, self.workflow = None, None, None

        if self.embeddings and self.data:
            # Set query and limit
            query = st.text_input("Query")
            limit = min(5, len(self.data))

            # Hides the first table column and left-aligns table text
            st.markdown(
                """
<style>
table td:nth-child(1) {
display: none
}
table th:nth-child(1) {
display: none
}
table {text-align: left !important}
</style>
""",
                unsafe_allow_html=True,
            )

            if query:
                df = pd.DataFrame([{"content": self.find(uid), "score": score} for uid, score in self.embeddings.search(query, limit)])
                st.table(df)

    def parse(self, data):
        """
        Parse input data, splits on new lines depending on type of tasks and format of input.

        Args:
            data: input data

        Returns:
            parsed data as a list, empty when there is no input
        """

        # No input, return an empty list so process() does not run the workflow on ""
        if not data:
            return []

        # Guard against a workflow built without any tasks
        tasks = self.workflow.tasks if self.workflow else None

        if re.match(r"^(http|https|file):\/\/", data) or (tasks and isinstance(tasks[0], ServiceTask)):
            return [x for x in data.split("\n") if x]

        return [data]

    def run(self):
        """
        Runs Streamlit application.
        """

        st.sidebar.image("https://github.com/neuml/txtai/raw/master/logo.png", width=256)
        st.sidebar.markdown("# Workflow builder \n*Build and apply workflows to data* ")

        # Get selected components
        components = ["embeddings", "segment", "service", "summary", "tabular", "textract", "transcribe", "translate"]
        selected = st.sidebar.multiselect("Select components", components)

        # Get selected options
        components = [self.options(component) for component in selected]
        st.sidebar.markdown("---")

        with st.sidebar:
            col1, col2 = st.columns(2)

            # Build or re-build workflow when build button clicked
            build = col1.button("Build", help="Build the workflow and run within this application")
            if build:
                with st.spinner("Building workflow...."):
                    self.build(components)

            # Generate API configuration
            _, config = self.yaml(components)

            col2.download_button("Export", config, file_name="workflow.yml", mime="text/yaml", help="Export the API workflow as YAML")

        with st.expander("Data", expanded=not self.data):
            data = st.text_area("Input", height=10)

        # Parse text items
        data = self.parse(data)

        # Process current action
        self.process(data)
377
+
378
+
379
@st.cache(allow_output_mutation=True)
def create():
    """
    Builds the Application instance, wrapped by st.cache so the result is cached.

    Returns:
        Application
    """

    application = Application()
    return application
389
+
390
+
391
if __name__ == "__main__":
    # Set TOKENIZERS_PARALLELISM to "false" before the application is created
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Create (or reuse the cached) application and run it
    create().run()