kasand committed on
Commit
3d6098f
1 Parent(s): 8003b0e

i did something but i need to make workflows now

Browse files
Files changed (1) hide show
  1. app.py +532 -1
app.py CHANGED
@@ -7,6 +7,537 @@ import pandas as pd
7
  import streamlit as st
8
 
9
  from txtai.embeddings import Documents, Embeddings
10
- from txtai.pipeline import Summary, Tabular, Textractor, Translation
11
  from txtai.workflow import ServiceTask, Task, UrlTask, Workflow
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import streamlit as st
8
 
9
  from txtai.embeddings import Documents, Embeddings
10
+ from txtai.pipeline import Segmentation, Summary, Tabular, Textractor, Translation
11
  from txtai.workflow import ServiceTask, Task, UrlTask, Workflow
12
 
13
class Process:
    """
    Container for an active workflow process instance.

    Holds the component options, the pipelines built from them, the workflow
    that chains the pipelines and, for indexing workflows, the embeddings
    index together with the data that was indexed.
    """

    @staticmethod
    @st.cache(ttl=60 * 60, max_entries=3, allow_output_mutation=True, show_spinner=False)
    def get(components, data):
        """
        Looks up or creates a new workflow process instance.

        The call is wrapped with st.cache, so an instance is only built when
        no cached result exists for the given arguments.

        Args:
            components: list of component configuration dicts, each with a "type" key
            data: initial data stored on the new process

        Returns:
            Process with its workflow already built
        """

        process = Process(data)

        # Build workflow
        with st.spinner("Building workflow...."):
            process.build(components)

        return process

    def __init__(self, data):
        """
        Creates a new Process.

        Args:
            data: initial data, kept on the instance as self.data
        """

        # Component options, keyed by component type
        self.components = {}

        # Pipelines created by build(), keyed by component type
        self.pipelines = {}

        # Current workflow, replaced with a Workflow instance by build()
        self.workflow = []

        # Embeddings index state
        self.embeddings = None
        self.documents = None
        self.data = data

    def build(self, components):
        """
        Builds a workflow using components.

        Args:
            components: list of component configuration dicts, each with a "type" key
        """

        tasks = []

        for component in components:
            # Work on a copy, the caller's configuration is left untouched
            component = dict(component)

            # The component type is stored under the "type" key
            wtype = component.pop("type")
            self.components[wtype] = component

            if wtype == "embeddings":
                self.embeddings = Embeddings({**component})
                self.documents = Documents()
                tasks.append(Task(self.documents.add, unpack=False))

            elif wtype == "segmentation":
                self.pipelines[wtype] = Segmentation(**self.components[wtype])
                tasks.append(Task(self.pipelines[wtype]))

            elif wtype == "service":
                tasks.append(ServiceTask(**self.components[wtype]))

            elif wtype == "summary":
                # pop also removes "path" from self.components["summary"] (same dict),
                # the remaining options are passed as keyword arguments on each call
                self.pipelines[wtype] = Summary(component.pop("path"))
                tasks.append(Task(lambda x: self.pipelines["summary"](x, **self.components["summary"])))

            elif wtype == "tabular":
                self.pipelines[wtype] = Tabular(**self.components[wtype])
                tasks.append(Task(self.pipelines[wtype]))

            elif wtype == "textractor":
                self.pipelines[wtype] = Textractor(**self.components[wtype])
                tasks.append(UrlTask(self.pipelines[wtype]))

            elif wtype == "translation":
                self.pipelines[wtype] = Translation()
                tasks.append(Task(lambda x: self.pipelines["translation"](x, **self.components["translation"])))

        self.workflow = Workflow(tasks)

    def run(self, data):
        """
        Runs a workflow using data as input.

        Without an embeddings component each workflow result is written to the
        page. With one, the workflow output is indexed and the workflow state
        is cleared afterwards.

        Args:
            data: list of input elements
        """

        if data and self.workflow:
            # Build (id, element, None) tuples for the embeddings index
            if self.documents:
                data = [(x, element, None) for x, element in enumerate(data)]

            # Process workflow
            for result in self.workflow(data):
                if not self.documents:
                    st.write(result)

            # Build embeddings index
            if self.documents:
                # Cache data, used by search() and find()
                self.data = list(self.documents)

                with st.spinner("Building embedding index...."):
                    self.embeddings.index(self.documents)
                    self.documents.close()

                # Clear workflow
                self.documents, self.pipelines, self.workflow = None, None, None

    def search(self, query):
        """
        Runs a search for query and writes the results to the page as an HTML table.

        Args:
            query: input query
        """

        if self.embeddings and query:
            # Hide the first table column and left-align the table
            st.markdown(
                """
            <style>
            table td:nth-child(1) {
                display: none
            }
            table th:nth-child(1) {
                display: none
            }
            table {text-align: left !important}
            </style>
            """,
                unsafe_allow_html=True,
            )

            limit = min(5, len(self.data))

            results = []
            for result in self.embeddings.search(query, limit):
                # Tuples are returned when an index doesn't have stored content
                if isinstance(result, tuple):
                    uid, score = result
                    results.append({"text": self.find(uid), "score": f"{score:.2}"})
                else:
                    if "id" in result and "text" in result:
                        result["text"] = self.content(result.pop("id"), result["text"])
                    if "score" in result and result["score"]:
                        result["score"] = f'{result["score"]:.2}'

                    results.append(result)

            df = pd.DataFrame(results)

            # NOTE(review): cell content is rendered as raw HTML (escape=False) so that
            # links built by content() work - confirm indexed text is trusted
            st.write(df.to_html(escape=False), unsafe_allow_html=True)

    def find(self, key):
        """
        Looks up a record from cached data by uid key.

        Args:
            key: uid to search for

        Returns:
            content reference for the first record with a matching uid

        Raises:
            IndexError: if no cached record has the given uid
        """

        # Stop at the first matching id
        for uid, text, _ in self.data:
            if uid == key:
                return self.content(key, text)

        raise IndexError(f"No cached record found for id {key!r}")

    def content(self, uid, text):
        """
        Builds a content reference for uid and text.

        Args:
            uid: record id
            text: record text

        Returns:
            HTML link wrapping text when uid is a http(s) url, text otherwise
        """

        # Ids generated by run() are integers, only string ids can be urls
        if isinstance(uid, str) and uid.lower().startswith("http"):
            return f"<a href='{uid}' rel='noopener noreferrer' target='blank'>{text}</a>"

        return text
170
+
171
class Application:
    """
    Main application.

    Loads a workflow definition from the workflow configuration directory,
    displays the settings of each workflow component in the sidebar, runs the
    workflow and offers the resulting configuration as a YAML export.
    """

    def __init__(self, directory):
        """
        Creates a new application.

        Args:
            directory: workflow configuration directory
        """

        # Workflow configuration directory
        self.directory = directory

    def default(self, names):
        """
        Gets the default workflow index from the "default" query parameter.

        Args:
            names: list of workflow names

        Returns:
            select box index of the requested workflow, 0 if not set or not found
        """

        # Compare names as lowercase to match case-insensitively
        lnames = [name.lower() for name in names]

        # Get default workflow param
        params = st.experimental_get_query_params()
        index = params.get("default")
        index = index[0].lower() if index else 0

        # Lookup index of workflow name, add 1 to account for "--"
        if index and index in lnames:
            return lnames.index(index) + 1

        # Workflow not found, default to index 0
        return 0

    def load(self, components):
        """
        Loads an existing workflow file selected by the user.

        Args:
            components: list of supported component names

        Returns:
            (names of selected components, workflow configuration) or
            (None, None) when no workflow is selected
        """

        with open(os.path.join(self.directory, "config.yml"), encoding="utf-8") as f:
            # An empty file parses to None, treat it as no workflows
            config = yaml.safe_load(f) or []

        names = [row["name"] for row in config]
        files = [row["file"] for row in config]

        selected = st.selectbox("Load workflow", ["--"] + names, self.default(names))
        if selected == "--":
            return (None, None)

        with open(os.path.join(self.directory, files[names.index(selected)]), encoding="utf-8") as f:
            workflow = yaml.safe_load(f)

        st.markdown("---")

        # Get tasks for first workflow
        tasks = next(iter(workflow["workflow"].values()))["tasks"]
        selected = []

        for task in tasks:
            name = task.get("action", task.get("task"))
            if name in components:
                selected.append(name)
            elif name in ["index", "upsert"]:
                # Index actions map to the embeddings component
                selected.append("embeddings")

        return (selected, workflow)

    def state(self, key):
        """
        Looks up a session state variable.

        Args:
            key: variable name

        Returns:
            variable value, None if not set
        """

        if key in st.session_state:
            return st.session_state[key]

        return None

    def appsetting(self, workflow, name):
        """
        Looks up an application configuration setting.

        Args:
            workflow: workflow configuration
            name: setting name

        Returns:
            value stored under the workflow "app" section, None if not found
        """

        if workflow:
            config = workflow.get("app")
            if config:
                return config.get(name)

        return None

    def setting(self, config, name, default=None):
        """
        Looks up a component configuration setting.

        Args:
            config: component configuration
            name: setting name
            default: value returned when config is empty or the setting is missing

        Returns:
            setting value
        """

        return config.get(name, default) if config else default

    def text(self, label, component, config, name, default=None):
        """
        Displays a text setting as a caption plus a code block.

        Args:
            label: field label
            component: component name, not used by this method
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            setting value, lists and dict keys are joined into a comma separated string
        """

        default = self.setting(config, name, default)
        if not default:
            default = ""
        elif isinstance(default, list):
            # str() allows lists with non-string items
            default = ",".join(str(item) for item in default)
        elif isinstance(default, dict):
            default = ",".join(str(key) for key in default)

        st.caption(label)
        st.code(default, language="yaml")
        return default

    def number(self, label, component, config, name, default=None):
        """
        Displays a numeric setting.

        Args:
            label: field label
            component: component name
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            setting value as an int, None if empty
        """

        value = self.text(label, component, config, name, default)
        return int(value) if value else None

    def boolean(self, label, component, config, name, default=None):
        """
        Displays a boolean setting as a caption plus a check mark emoji.

        Args:
            label: field label
            component: component name, not used by this method
            config: component configuration
            name: setting name
            default: default setting value

        Returns:
            setting value
        """

        default = self.setting(config, name, default)

        st.caption(label)
        st.markdown(":white_check_mark:" if default else ":white_large_square:")
        return default

    def select(self, label, component, config, name, options, default=0):
        """
        Displays a select setting as a caption plus a code block.

        Args:
            label: field label
            component: component name, not used by this method
            config: component configuration
            name: setting name
            options: list of valid option values
            default: default option, either an option value or an index into options

        Returns:
            configured option when it is a valid option, default option otherwise
        """

        # Prefer the configured value, then a default given as an option value
        index = default
        for candidate in (self.setting(config, name), default):
            matches = [x for x, option in enumerate(options) if option == candidate]
            if matches:
                index = matches[0]
                break

        st.caption(label)
        st.code(options[index], language="yaml")
        return options[index]

    def split(self, text):
        """
        Splits text on commas and returns a list.

        Args:
            text: input text

        Returns:
            list of elements with surrounding whitespace removed
        """

        return [x.strip() for x in text.split(",")]

    @staticmethod
    def _heading(index, title, description):
        """
        Builds the markdown heading for a workflow component.

        Args:
            index: zero-based component position
            title: component title
            description: component description

        Returns:
            markdown with a bold numbered title and an italic description line
        """

        # No space after the opening ** (otherwise it is not rendered as bold),
        # two trailing spaces before the newline force a line break
        return f"**{index + 1}.) {title}**  \n*{description}*"

    def options(self, component, workflow, index):
        """
        Extracts component settings into a component configuration dict.

        Args:
            component: component type
            workflow: workflow configuration
            index: zero-based component position

        Returns:
            dict with a "type" key plus the settings of the component
        """

        options = {"type": component}

        config = None
        if workflow:
            if component in ["service", "translation"]:
                # Configuration for these components is read from the workflow tasks
                tasks = next(iter(workflow["workflow"].values()))["tasks"]
                tasks = [task for task in tasks if task.get("task") == component or task.get("action") == component]
                if tasks:
                    config = tasks[0]
            else:
                config = workflow.get(component)

        if component == "embeddings":
            st.markdown(self._heading(index, "Embeddings Index", "Index workflow output"))
            options["path"] = self.text("Embeddings model path", component, config, "path", "sentence-transformers/nli-mpnet-base-v2")
            options["upsert"] = self.boolean("Upsert", component, config, "upsert")
            options["content"] = self.boolean("Content", component, config, "content")

        elif component in ("segmentation", "textractor"):
            if component == "segmentation":
                st.markdown(self._heading(index, "Segment", "Split text into semantic units"))
            else:
                st.markdown(self._heading(index, "Textract", "Extract text from documents"))

            options["sentences"] = self.boolean("Split sentences", component, config, "sentences")
            options["lines"] = self.boolean("Split lines", component, config, "lines")
            options["paragraphs"] = self.boolean("Split paragraphs", component, config, "paragraphs")

            # Key matches the "join" setting name, options are passed on as pipeline keyword arguments
            options["join"] = self.boolean("Join tokenized", component, config, "join")
            options["minlength"] = self.number("Min section length", component, config, "minlength")

        elif component == "service":
            st.markdown(self._heading(index, "Service", "Extract data from an API"))
            options["url"] = self.text("URL", component, config, "url")
            options["method"] = self.select("Method", component, config, "method", ["get", "post"], 0)
            options["params"] = self.text("URL parameters", component, config, "params")
            options["batch"] = self.boolean("Run as batch", component, config, "batch", True)
            options["extract"] = self.text("Subsection(s) to extract", component, config, "extract")

            if options["params"]:
                options["params"] = {key: None for key in self.split(options["params"])}
            if options["extract"]:
                options["extract"] = self.split(options["extract"])

        elif component == "summary":
            st.markdown(self._heading(index, "Summary", "Abstractive text summarization"))
            options["path"] = self.text("Model", component, config, "path", "sshleifer/distilbart-cnn-12-6")
            options["minlength"] = self.number("Min length", component, config, "minlength")
            options["maxlength"] = self.number("Max length", component, config, "maxlength")

        elif component == "tabular":
            st.markdown(self._heading(index, "Tabular", "Split tabular data into rows and columns"))
            options["idcolumn"] = self.text("Id columns", component, config, "idcolumn")
            options["textcolumns"] = self.text("Text columns", component, config, "textcolumns")
            options["content"] = self.text("Content", component, config, "content")

            if options["textcolumns"]:
                options["textcolumns"] = self.split(options["textcolumns"])

            if options["content"]:
                options["content"] = self.split(options["content"])

                # A lone "1" is kept as a single value instead of a column list
                if len(options["content"]) == 1 and options["content"][0] == "1":
                    options["content"] = options["content"][0]

        elif component == "translation":
            st.markdown(self._heading(index, "Translate", "Machine translation"))
            options["target"] = self.text("Target language code", component, config, "args", "en")

        st.markdown("---")

        return options

    def yaml(self, components):
        """
        Builds a yaml string for components.

        Args:
            components: list of component configuration dicts, each with a "type" key

        Returns:
            (workflow name, yaml string)
        """

        data = {"app": {"data": self.state("data"), "query": self.state("query")}}
        tasks = []
        name = None

        for component in components:
            # Work on a copy, the caller's configuration is left untouched
            component = dict(component)
            name = wtype = component.pop("type")

            if wtype == "embeddings":
                upsert = component.pop("upsert", None)

                data[wtype] = component
                data["writable"] = True

                name = "index"
                tasks.append({"action": "upsert" if upsert else "index"})

            elif wtype == "segmentation":
                data[wtype] = component
                tasks.append({"action": wtype})

            elif wtype == "service":
                config = dict(component)
                config["task"] = wtype
                tasks.append(config)

            elif wtype == "summary":
                data[wtype] = {"path": component.pop("path")}
                tasks.append({"action": wtype})

            elif wtype == "tabular":
                data[wtype] = component
                tasks.append({"action": wtype})

            elif wtype == "textractor":
                data[wtype] = component

                # Singular "task" is the key load() and options() read back
                tasks.append({"action": wtype, "task": "url"})

            elif wtype == "translation":
                data[wtype] = component
                tasks.append({"action": wtype, "args": list(component.values())})

        # Add in workflow, named after the last component
        data["workflow"] = {name: {"tasks": tasks}}

        return (name, yaml.dump(data))

    def data(self, workflow):
        """
        Gets input data.

        Args:
            workflow: workflow configuration

        Returns:
            input data wrapped in a list
        """

        # Get default data setting
        data = self.appsetting(workflow, "data")

        # The input field is only shown when the workflow has no default query
        if not self.appsetting(workflow, "query"):
            data = st.text_input("Input", value=data if data else "")

        # Save data state
        st.session_state["data"] = data

        # Wrap data as list for workflow processing
        return [data]

    def query(self, workflow, index):
        """
        Gets input query.

        Args:
            workflow: workflow configuration
            index: True if this is an indexing workflow

        Returns:
            query text, None if this is not an indexing workflow
        """

        default = self.appsetting(workflow, "query")
        default = default if default else ""

        # Get query if this is an indexing workflow
        query = st.text_input("Query", value=default) if index else None

        # Save query state
        st.session_state["query"] = query

        return query

    def process(self, workflow, components, index):
        """
        Processes the current application action.

        Args:
            workflow: workflow configuration
            components: list of component configuration dicts
            index: True if this is an indexing workflow
        """

        # Get input data and initialize query
        data = self.data(workflow)
        query = self.query(workflow, index)

        # Get workflow process
        process = Process.get(components, data if index else None)

        # Run workflow process
        process.run(data)

        # Run search
        if index:
            process.search(query)

    def run(self):
        """
        Runs the Streamlit application.
        """

        with st.sidebar:
            st.markdown("# Workflow builder for Station \n*Build and apply workflows to data about articles* ")
            st.markdown("This is a demo for Station and the data used is from [Hugging Face](https://huggingface.co/datasets/ag_news/viewer/default/train).")
            st.markdown("---")

            # Component configuration
            components = ["embeddings", "segmentation", "service", "summary", "tabular", "textractor", "translation"]

            selected, workflow = self.load(components)
            if selected:
                # Get selected options
                components = [self.options(component, workflow, x) for x, component in enumerate(selected)]

        if selected:
            # Process current action
            self.process(workflow, components, "embeddings" in selected)

            with st.sidebar:
                # Generate export button after workflow is complete
                _, config = self.yaml(components)
                st.download_button("Export", config, file_name="workflow.yaml", help="Export the API workflow as YAML")
        else:
            st.info("Select a workflow from the sidebar")
532
+
533
if __name__ == "__main__":
    # Set before the application, and with it any model, is created
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Download the NLTK sentence tokenizer data on first use.
    # NLTK raises LookupError when a required resource is not installed,
    # any other error is left to propagate instead of being swallowed.
    try:
        nltk.sent_tokenize("This is a test. Split")
    except LookupError:
        nltk.download("punkt")

    # Create and run application
    app = Application("workflows")
    app.run()