Niansuh commited on
Commit
c4810ed
1 Parent(s): 4cc01d5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +671 -0
app.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build txtai workflows.
3
+ Based on this example: https://github.com/neuml/txtai/blob/master/examples/workflows.py
4
+ """
5
+
6
+ import os
7
+
8
+ import nltk
9
+ import yaml
10
+
11
+ import pandas as pd
12
+ import streamlit as st
13
+
14
+ from txtai.embeddings import Documents, Embeddings
15
+ from txtai.pipeline import Segmentation, Summary, Tabular, Textractor, Translation
16
+ from txtai.workflow import ServiceTask, Task, UrlTask, Workflow
17
+
18
+
19
+ class Process:
20
+ """
21
+ Container for an active Workflow process instance.
22
+ """
23
+
24
+ @staticmethod
25
+ @st.cache_resource(ttl=60 * 60, max_entries=3, show_spinner=False)
26
+ def get(components, data):
27
+ """
28
+ Lookup or creates a new workflow process instance.
29
+ Args:
30
+ components: input components
31
+ data: initial data, only passed when indexing
32
+ Returns:
33
+ Process
34
+ """
35
+
36
+ process = Process(data)
37
+
38
+ # Build workflow
39
+ with st.spinner("Building workflow...."):
40
+ process.build(components)
41
+
42
+ return process
43
+
44
+ def __init__(self, data):
45
+ """
46
+ Creates a new Process.
47
+ Args:
48
+ data: initial data, only passed when indexing
49
+ """
50
+
51
+ # Component options
52
+ self.components = {}
53
+
54
+ # Defined pipelines
55
+ self.pipelines = {}
56
+
57
+ # Current workflow
58
+ self.workflow = []
59
+
60
+ # Embeddings index params
61
+ self.embeddings = None
62
+ self.documents = None
63
+ self.data = data
64
+
65
+ def build(self, components):
66
+ """
67
+ Builds a workflow using components.
68
+ Args:
69
+ components: list of components to add to workflow
70
+ """
71
+
72
+ # pylint: disable=W0108
73
+ tasks = []
74
+ for component in components:
75
+ component = dict(component)
76
+ wtype = component.pop("type")
77
+ self.components[wtype] = component
78
+
79
+ if wtype == "embeddings":
80
+ self.embeddings = Embeddings({**component})
81
+ self.documents = Documents()
82
+ tasks.append(Task(self.documents.add, unpack=False))
83
+
84
+ elif wtype == "segmentation":
85
+ self.pipelines[wtype] = Segmentation(**self.components[wtype])
86
+ tasks.append(Task(self.pipelines[wtype]))
87
+
88
+ elif wtype == "service":
89
+ tasks.append(ServiceTask(**self.components[wtype]))
90
+
91
+ elif wtype == "summary":
92
+ self.pipelines[wtype] = Summary(component.pop("path"))
93
+ tasks.append(Task(lambda x: self.pipelines["summary"](x, **self.components["summary"])))
94
+
95
+ elif wtype == "tabular":
96
+ self.pipelines[wtype] = Tabular(**self.components[wtype])
97
+ tasks.append(Task(self.pipelines[wtype]))
98
+
99
+ elif wtype == "textractor":
100
+ self.pipelines[wtype] = Textractor(**self.components[wtype])
101
+ tasks.append(UrlTask(self.pipelines[wtype]))
102
+
103
+ elif wtype == "translation":
104
+ self.pipelines[wtype] = Translation()
105
+ tasks.append(Task(lambda x: self.pipelines["translation"](x, **self.components["translation"])))
106
+
107
+ self.workflow = Workflow(tasks)
108
+
109
+ def run(self, data):
110
+ """
111
+ Runs a workflow using data as input.
112
+ Args:
113
+ data: input data
114
+ """
115
+
116
+ if data and self.workflow:
117
+ # Build tuples for embedding index
118
+ if self.documents:
119
+ data = [(x, element, None) for x, element in enumerate(data)]
120
+
121
+ # Process workflow
122
+ for result in self.workflow(data):
123
+ if not self.documents:
124
+ st.write(result)
125
+
126
+ # Build embeddings index
127
+ if self.documents:
128
+ # Cache data
129
+ self.data = list(self.documents)
130
+
131
+ with st.spinner("Building embedding index...."):
132
+ self.embeddings.index(self.documents)
133
+ self.documents.close()
134
+
135
+ # Clear workflow
136
+ self.documents, self.pipelines, self.workflow = None, None, None
137
+
138
+ def search(self, query):
139
+ """
140
+ Runs a search.
141
+ Args:
142
+ query: input query
143
+ """
144
+
145
+ if self.embeddings and query:
146
+ st.markdown(
147
+ """
148
+ <style>
149
+ table td:nth-child(1) {
150
+ display: none
151
+ }
152
+ table th:nth-child(1) {
153
+ display: none
154
+ }
155
+ table {text-align: left !important}
156
+ </style>
157
+ """,
158
+ unsafe_allow_html=True,
159
+ )
160
+
161
+ limit = min(5, len(self.data))
162
+
163
+ results = []
164
+ for result in self.embeddings.search(query, limit):
165
+ # Tuples are returned when an index doesn't have stored content
166
+ if isinstance(result, tuple):
167
+ uid, score = result
168
+ results.append({"text": self.find(uid), "score": f"{score:.2}"})
169
+ else:
170
+ if "id" in result and "text" in result:
171
+ result["text"] = self.content(result.pop("id"), result["text"])
172
+ if "score" in result and result["score"]:
173
+ result["score"] = f'{result["score"]:.2}'
174
+
175
+ results.append(result)
176
+
177
+ df = pd.DataFrame(results)
178
+ st.write(df.to_html(escape=False), unsafe_allow_html=True)
179
+
180
+ def find(self, key):
181
+ """
182
+ Lookup record from cached data by uid key.
183
+ Args:
184
+ key: id to search for
185
+ Returns:
186
+ text for matching id
187
+ """
188
+
189
+ # Lookup text by id
190
+ text = [text for uid, text, _ in self.data if uid == key][0]
191
+ return self.content(key, text)
192
+
193
+ def content(self, uid, text):
194
+ """
195
+ Builds a content reference for uid and text.
196
+ Args:
197
+ uid: record id
198
+ text: record text
199
+ Returns:
200
+ content
201
+ """
202
+
203
+ if uid and uid.lower().startswith("http"):
204
+ return f"<a href='{uid}' rel='noopener noreferrer' target='blank'>{text}</a>"
205
+
206
+ return text
207
+
208
+
209
+ class Application:
210
+ """
211
+ Main application.
212
+ """
213
+
214
+ def __init__(self, directory):
215
+ """
216
+ Creates a new application.
217
+ """
218
+
219
+ # Workflow configuration directory
220
+ self.directory = directory
221
+
222
+ def default(self, names):
223
+ """
224
+ Gets default workflow index.
225
+ Args:
226
+ names: list of workflow names
227
+ Returns:
228
+ default workflow index
229
+ """
230
+
231
+ # Get names as lowercase to match case-insensitive
232
+ lnames = [name.lower() for name in names]
233
+
234
+ # Get default workflow param
235
+ params = st.experimental_get_query_params()
236
+ index = params.get("default")
237
+ index = index[0].lower() if index else 0
238
+
239
+ # Lookup index of workflow name, add 1 to account for "--"
240
+ if index and index in lnames:
241
+ return lnames.index(index) + 1
242
+
243
+ # Workflow not found, default to index 0
244
+ return 0
245
+
246
+ def load(self, components):
247
+ """
248
+ Load an existing workflow file.
249
+ Args:
250
+ components: list of components to load
251
+ Returns:
252
+ (names of components loaded, workflow config)
253
+ """
254
+
255
+ with open(os.path.join(self.directory, "config.yml"), encoding="utf-8") as f:
256
+ config = yaml.safe_load(f)
257
+
258
+ names = [row["name"] for row in config]
259
+ files = [row["file"] for row in config]
260
+
261
+ selected = st.selectbox("Load workflow", ["--"] + names, self.default(names))
262
+ if selected != "--":
263
+ index = [x for x, name in enumerate(names) if name == selected][0]
264
+ with open(os.path.join(self.directory, files[index]), encoding="utf-8") as f:
265
+ workflow = yaml.safe_load(f)
266
+
267
+ st.markdown("---")
268
+
269
+ # Get tasks for first workflow
270
+ tasks = list(workflow["workflow"].values())[0]["tasks"]
271
+ selected = []
272
+
273
+ for task in tasks:
274
+ name = task.get("action", task.get("task"))
275
+ if name in components:
276
+ selected.append(name)
277
+ elif name in ["index", "upsert"]:
278
+ selected.append("embeddings")
279
+
280
+ return (selected, workflow)
281
+
282
+ return (None, None)
283
+
284
+ def state(self, key):
285
+ """
286
+ Lookup a session state variable.
287
+ Args:
288
+ key: variable key
289
+ Returns:
290
+ variable value
291
+ """
292
+
293
+ if key in st.session_state:
294
+ return st.session_state[key]
295
+
296
+ return None
297
+
298
+ def appsetting(self, workflow, name):
299
+ """
300
+ Looks up an application configuration setting.
301
+ Args:
302
+ workflow: workflow configuration
303
+ name: setting name
304
+ Returns:
305
+ app setting value
306
+ """
307
+
308
+ if workflow:
309
+ config = workflow.get("app")
310
+ if config:
311
+ return config.get(name)
312
+
313
+ return None
314
+
315
+ def setting(self, config, name, default=None):
316
+ """
317
+ Looks up a component configuration setting.
318
+ Args:
319
+ config: component configuration
320
+ name: setting name
321
+ default: default setting value
322
+ Returns:
323
+ setting value
324
+ """
325
+
326
+ return config.get(name, default) if config else default
327
+
328
+ def text(self, label, component, config, name, default=None):
329
+ """
330
+ Create a new text input field.
331
+ Args:
332
+ label: field label
333
+ component: component name
334
+ config: component configuration
335
+ name: setting name
336
+ default: default setting value
337
+ Returns:
338
+ text input field value
339
+ """
340
+
341
+ default = self.setting(config, name, default)
342
+ if not default:
343
+ default = ""
344
+ elif isinstance(default, list):
345
+ default = ",".join(default)
346
+ elif isinstance(default, dict):
347
+ default = ",".join(default.keys())
348
+
349
+ st.caption(label)
350
+ st.code(default, language="yaml")
351
+ return default
352
+
353
+ def number(self, label, component, config, name, default=None):
354
+ """
355
+ Creates a new numeric input field.
356
+ Args:
357
+ label: field label
358
+ component: component name
359
+ config: component configuration
360
+ name: setting name
361
+ default: default setting value
362
+ Returns:
363
+ numeric value
364
+ """
365
+
366
+ value = self.text(label, component, config, name, default)
367
+ return int(value) if value else None
368
+
369
+ def boolean(self, label, component, config, name, default=False):
370
+ """
371
+ Creates a new checkbox field.
372
+ Args:
373
+ label: field label
374
+ component: component name
375
+ config: component configuration
376
+ name: setting name
377
+ default: default setting value
378
+ Returns:
379
+ boolean value
380
+ """
381
+
382
+ default = self.setting(config, name, default)
383
+
384
+ st.caption(label)
385
+ st.markdown(":white_check_mark:" if default else ":white_large_square:")
386
+ return default
387
+
388
+ def select(self, label, component, config, name, options, default=0):
389
+ """
390
+ Creates a new select box field.
391
+ Args:
392
+ label: field label
393
+ component: component name
394
+ config: component configuration
395
+ name: setting name
396
+ options: list of dropdown options
397
+ default: default setting value
398
+ Returns:
399
+ boolean value
400
+ """
401
+
402
+ index = self.setting(config, name)
403
+ index = [x for x, option in enumerate(options) if option == default]
404
+
405
+ # Derive default index
406
+ default = index[0] if index else default
407
+
408
+ st.caption(label)
409
+ st.code(options[default], language="yaml")
410
+ return options[default]
411
+
412
+ def split(self, text):
413
+ """
414
+ Splits text on commas and returns a list.
415
+ Args:
416
+ text: input text
417
+ Returns:
418
+ list
419
+ """
420
+
421
+ return [x.strip() for x in text.split(",")]
422
+
423
+ def options(self, component, workflow, index):
424
+ """
425
+ Extracts component settings into a component configuration dict.
426
+ Args:
427
+ component: component type
428
+ workflow: existing workflow, can be None
429
+ index: task index
430
+ Returns:
431
+ dict with component settings
432
+ """
433
+
434
+ # pylint: disable=R0912, R0915
435
+ options = {"type": component}
436
+
437
+ # Lookup component configuration
438
+ # - Runtime components have config defined within tasks
439
+ # - Pipeline components have config defined at workflow root
440
+ config = None
441
+ if workflow:
442
+ if component in ["service", "translation"]:
443
+ # Service config is found in tasks section
444
+ tasks = list(workflow["workflow"].values())[0]["tasks"]
445
+ tasks = [task for task in tasks if task.get("task") == component or task.get("action") == component]
446
+ if tasks:
447
+ config = tasks[0]
448
+ else:
449
+ config = workflow.get(component)
450
+
451
+ if component == "embeddings":
452
+ st.markdown(f"** {index + 1}.) Embeddings Index** \n*Index workflow output*")
453
+ options["path"] = self.text("Embeddings model path", component, config, "path", "sentence-transformers/nli-mpnet-base-v2")
454
+ options["upsert"] = self.boolean("Upsert", component, config, "upsert")
455
+ options["content"] = self.boolean("Content", component, config, "content")
456
+
457
+ elif component in ("segmentation", "textractor"):
458
+ if component == "segmentation":
459
+ st.markdown(f"** {index + 1}.) Segment** \n*Split text into semantic units*")
460
+ else:
461
+ st.markdown(f"** {index + 1}.) Textract** \n*Extract text from documents*")
462
+
463
+ options["sentences"] = self.boolean("Split sentences", component, config, "sentences")
464
+ options["lines"] = self.boolean("Split lines", component, config, "lines")
465
+ options["paragraphs"] = self.boolean("Split paragraphs", component, config, "paragraphs")
466
+ options["join"] = self.boolean("Join tokenized", component, config, "join")
467
+ options["minlength"] = self.number("Min section length", component, config, "minlength")
468
+
469
+ elif component == "service":
470
+ st.markdown(f"** {index + 1}.) Service** \n*Extract data from an API*")
471
+ options["url"] = self.text("URL", component, config, "url")
472
+ options["method"] = self.select("Method", component, config, "method", ["get", "post"], 0)
473
+ options["params"] = self.text("URL parameters", component, config, "params")
474
+ options["batch"] = self.boolean("Run as batch", component, config, "batch", True)
475
+ options["extract"] = self.text("Subsection(s) to extract", component, config, "extract")
476
+
477
+ if options["params"]:
478
+ options["params"] = {key: None for key in self.split(options["params"])}
479
+ if options["extract"]:
480
+ options["extract"] = self.split(options["extract"])
481
+
482
+ elif component == "summary":
483
+ st.markdown(f"** {index + 1}.) Summary** \n*Abstractive text summarization*")
484
+ options["path"] = self.text("Model", component, config, "path", "sshleifer/distilbart-cnn-12-6")
485
+ options["minlength"] = self.number("Min length", component, config, "minlength")
486
+ options["maxlength"] = self.number("Max length", component, config, "maxlength")
487
+
488
+ elif component == "tabular":
489
+ st.markdown(f"** {index + 1}.) Tabular** \n*Split tabular data into rows and columns*")
490
+ options["idcolumn"] = self.text("Id columns", component, config, "idcolumn")
491
+ options["textcolumns"] = self.text("Text columns", component, config, "textcolumns")
492
+ options["content"] = self.text("Content", component, config, "content")
493
+
494
+ if options["textcolumns"]:
495
+ options["textcolumns"] = self.split(options["textcolumns"])
496
+
497
+ if options["content"]:
498
+ options["content"] = self.split(options["content"])
499
+ if len(options["content"]) == 1 and options["content"][0] == "1":
500
+ options["content"] = options["content"][0]
501
+
502
+ elif component == "translation":
503
+ st.markdown(f"** {index + 1}.) Translate** \n*Machine translation*")
504
+ options["target"] = self.text("Target language code", component, config, "args", "en")
505
+
506
+ st.markdown("---")
507
+
508
+ return options
509
+
510
+ def yaml(self, components):
511
+ """
512
+ Builds a yaml string for components.
513
+ Args:
514
+ components: list of components to export to YAML
515
+ Returns:
516
+ (workflow name, YAML string)
517
+ """
518
+
519
+ data = {"app": {"data": self.state("data"), "query": self.state("query")}}
520
+ tasks = []
521
+ name = None
522
+
523
+ for component in components:
524
+ component = dict(component)
525
+ name = wtype = component.pop("type")
526
+
527
+ if wtype == "embeddings":
528
+ upsert = component.pop("upsert")
529
+
530
+ data[wtype] = component
531
+ data["writable"] = True
532
+
533
+ name = "index"
534
+ tasks.append({"action": "upsert" if upsert else "index"})
535
+
536
+ elif wtype == "segmentation":
537
+ data[wtype] = component
538
+ tasks.append({"action": wtype})
539
+
540
+ elif wtype == "service":
541
+ config = dict(**component)
542
+ config["task"] = wtype
543
+ tasks.append(config)
544
+
545
+ elif wtype == "summary":
546
+ data[wtype] = {"path": component.pop("path")}
547
+ tasks.append({"action": wtype})
548
+
549
+ elif wtype == "tabular":
550
+ data[wtype] = component
551
+ tasks.append({"action": wtype})
552
+
553
+ elif wtype == "textractor":
554
+ data[wtype] = component
555
+ tasks.append({"action": wtype, "task": "url"})
556
+
557
+ elif wtype == "translation":
558
+ data[wtype] = {}
559
+ tasks.append({"action": wtype, "args": list(component.values())})
560
+
561
+ # Add in workflow
562
+ data["workflow"] = {name: {"tasks": tasks}}
563
+
564
+ return (name, yaml.dump(data))
565
+
566
+ def data(self, workflow):
567
+ """
568
+ Gets input data.
569
+ Args:
570
+ workflow: workflow configuration
571
+ Returns:
572
+ input data
573
+ """
574
+
575
+ # Get default data setting
576
+ data = self.appsetting(workflow, "data")
577
+ if not self.appsetting(workflow, "query"):
578
+ data = st.text_input("Input", value=data)
579
+
580
+ # Save data state
581
+ st.session_state["data"] = data
582
+
583
+ # Wrap data as list for workflow processing
584
+ return [data]
585
+
586
+ def query(self, workflow, index):
587
+ """
588
+ Gets input query.
589
+ Args:
590
+ workflow: workflow configuration
591
+ index: True if this is an indexing workflow
592
+ Returns:
593
+ input query
594
+ """
595
+
596
+ default = self.appsetting(workflow, "query")
597
+ default = default if default else ""
598
+
599
+ # Get query if this is an indexing workflow
600
+ query = st.text_input("Query", value=default) if index else None
601
+
602
+ # Save query state
603
+ st.session_state["query"] = query
604
+
605
+ return query
606
+
607
+ def process(self, workflow, components, index):
608
+ """
609
+ Processes the current application action.
610
+ Args:
611
+ workflow: workflow configuration
612
+ components: workflow components
613
+ index: True if this is an indexing workflow
614
+ """
615
+
616
+ # Get input data and initialize query
617
+ data = self.data(workflow)
618
+ query = self.query(workflow, index)
619
+
620
+ # Get workflow process
621
+ process = Process.get(components, data if index else None)
622
+
623
+ # Run workflow process
624
+ process.run(data)
625
+
626
+ # Run search
627
+ if index:
628
+ process.search(query)
629
+
630
+ def run(self):
631
+ """
632
+ Runs Streamlit application.
633
+ """
634
+
635
+ with st.sidebar:
636
+ st.image("https://github.com/neuml/txtai/raw/master/logo.png", width=256)
637
+ st.markdown("# txtai \n*Build and apply workflows to data* ")
638
+ st.markdown("---")
639
+
640
+ # Component configuration
641
+ components = ["embeddings", "segmentation", "service", "summary", "tabular", "textractor", "translation"]
642
+
643
+ selected, workflow = self.load(components)
644
+ if selected:
645
+ # Get selected options
646
+ components = [self.options(component, workflow, x) for x, component in enumerate(selected)]
647
+
648
+ if selected:
649
+ # Process current action
650
+ self.process(workflow, components, "embeddings" in selected)
651
+
652
+ with st.sidebar:
653
+ # Generate export button after workflow is complete
654
+ _, config = self.yaml(components)
655
+ st.download_button("Export", config, file_name="workflow.yml", help="Export the API workflow as YAML")
656
+ else:
657
+ st.info("Select a workflow from the sidebar")
658
+
659
+
660
+ if __name__ == "__main__":
661
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
662
+
663
+ # pylint: disable=W0702
664
+ try:
665
+ nltk.sent_tokenize("This is a test. Split")
666
+ except:
667
+ nltk.download("punkt")
668
+
669
+ # Create and run application
670
+ app = Application("workflows")
671
+ app.run()