rungalileo committed on
Commit
cce3221
·
1 Parent(s): 8a59595

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +364 -0
app.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import sys
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from dataclasses import dataclass
6
+ from typing import List
7
+ from uuid import uuid4
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ import streamlit as st
13
+ from numerize.numerize import numerize
14
+ from sentence_transformers import SentenceTransformer
15
+ from streamlit.elements.file_uploader import UploadedFile
16
+ from streamlit_plotly_events import plotly_events
17
+ from umap import UMAP
18
+
19
+ sys.path.append(".")
20
+
21
+
22
@dataclass
class SessionKey:
    """Names of the entries this app keeps in ``st.session_state``.

    Each field's default value is the literal key string, so the keys can be
    read straight off the class (``SessionKey.df``) without instantiating it.
    """

    # Flag set to True once the sentence-transformer has been pre-downloaded
    model: str = "model"
    # Current `key` handed to plotly_events; replaced to drop the chart's state
    figure_state: str = "figure_state"
    # The uploaded CSV file
    file: str = "file"
    # The loaded dataframe
    df: str = "df"
    # Ids of the rows that survive the current search / lasso filters
    active_ids: str = "active_ids"
    # The plotly figure to render
    fig: str = "fig"
    # Points returned by plotly_events for the current selection
    selected_points: str = "selected_points"
    # Set once the UMAP x/y columns have been added to the dataframe
    has_xy: str = "has_xy"
    # Marker size chosen with the sidebar slider
    marker_size: str = "marker_size"
    # Column chosen in the "Color By" selectbox (None when unset)
    color: str = "color"
    # Labels registered by the user
    labels: str = "labels"
    # Mapping of row id -> assigned label
    label_assignments: str = "label_assignments"
    # Whether the "Set label for selection" expander starts expanded
    is_expanded: str = "is_expanded"
    default_index: str = "default_index"
    chosen_label: str = "chosen_label"
    # Holds the (rotating) widget key of the "Choose Label" selectbox
    label_select_key: str = "label_select_key"
40
+
41
+
42
@dataclass
class InternalCol:
    """Names of the dataframe columns the app adds for its own use."""

    # HTML-wrapped copy of the text shown when hovering a point
    hovertext: str = "hovertext"
    # 2-D UMAP coordinates
    x: str = "x"
    y: str = "y"


# Columns that are stripped from the dataframe before it is exported
INTERNAL_COLS = [
    InternalCol.hovertext,
    InternalCol.x,
    InternalCol.y,
]
50
+
51
+
52
def get_export_df() -> pd.DataFrame:
    """Build the dataframe offered for download.

    Works on a copy of the session dataframe: every row gets a ``label``
    column looked up from the stored label assignments by its ``id``, rows
    without an assignment are discarded, and the app-internal columns are
    removed.

    Returns:
        The labeled rows, without the columns listed in ``INTERNAL_COLS``.
    """
    unlabeled = -1  # sentinel for rows that have no assignment
    assignments = st.session_state[SessionKey.label_assignments]

    export = st.session_state[SessionKey.df].copy()
    export["label"] = [assignments.get(row_id, unlabeled) for row_id in export["id"]]
    export = export.loc[export["label"] != unlabeled]

    return export.drop(columns=INTERNAL_COLS, errors="ignore")
60
+
61
+
62
# Single UMAP instance used by `get_umap_embeddings`
umap_model = UMAP(n_neighbors=15, random_state=42, verbose=True)
st.set_page_config(layout="wide")
st.title("Laboratory 🧪")
# col1 hosts the dataframe view, col2 the embeddings scatterplot
col1, col2 = st.columns([3, 4])

# The labels registered by the user
if SessionKey.labels not in st.session_state:
    st.session_state[SessionKey.labels] = []

# The assigned labels {row id: label}
if SessionKey.label_assignments not in st.session_state:
    st.session_state[SessionKey.label_assignments] = {}

# Widget key of the "Choose Label" selectbox. Stored as a string, like the
# figure key, because Streamlit widget keys are expected to be str.
if SessionKey.label_select_key not in st.session_state:
    st.session_state[SessionKey.label_select_key] = str(uuid4())


# Pre download the model once per session
if SessionKey.model not in st.session_state:
    SentenceTransformer("paraphrase-MiniLM-L3-v2")
    st.session_state[SessionKey.model] = True
84
+
85
+
86
def reset_plotly_figure(force: bool = False) -> None:
    """Give the plotly chart a fresh identity, discarding its selection state.

    streamlit_plotly_events keeps the lasso-selected points for a chart and
    offers no direct way to forget them. It does accept a `key` kwarg that
    identifies the chart, so storing a brand-new key makes the next render a
    new chart with nothing selected.

    Args:
        force: When False, a key is only generated if none is stored yet (the
            very first call). When True, the stored key is always replaced.
    """
    has_key = SessionKey.figure_state in st.session_state
    if has_key and not force:
        # A chart identity already exists and the caller did not ask to drop it
        return
    st.session_state[SessionKey.figure_state] = str(uuid4())
103
+
104
+
105
def clear_state() -> None:
    """Clear the global state.

    Used when a new file is uploaded, or when the user wants to "start over"
    on their work. Every session-state entry is removed except the model
    flag: the model is already downloaded and does not change.
    """
    # Iterate over a snapshot of the keys: entries are deleted inside the
    # loop, and mutating a mapping while walking its live key view is unsafe.
    for key in list(st.session_state.keys()):
        if key != SessionKey.model:
            del st.session_state[key]
115
+
116
+
117
def reset_embeddings() -> None:
    """Return the embeddings view to the full dataframe and rerun the app.

    Drops the stored selection, figure and active-id filter (whichever of
    them exist), forces a new chart key so the lasso selection is forgotten,
    and then triggers a rerun.
    """
    filter_state = (
        SessionKey.selected_points,
        SessionKey.fig,
        SessionKey.active_ids,
    )
    for name in filter_state:
        if name not in st.session_state:
            continue
        del st.session_state[name]

    reset_plotly_figure(force=True)
    st.experimental_rerun()
127
+
128
+
129
def get_dataframe_file() -> UploadedFile:
    """Render the sidebar CSV uploader and return the file for this session.

    The uploader is always drawn. If session state already holds a file entry
    that entry is returned; otherwise the uploader's current value is stored
    and returned. Changing the upload runs `clear_state`.

    Returns:
        The file remembered in session state (the uploader's value on the
        first call).
    """
    uploaded = st.sidebar.file_uploader(
        "Upload your CSV text file", type="csv", on_change=clear_state
    )
    if SessionKey.file not in st.session_state.keys():
        st.session_state[SessionKey.file] = uploaded
        return uploaded
    return st.session_state[SessionKey.file]
137
+
138
+
139
def apply_emb_model(text_chunk: List[str]) -> np.ndarray:
    """Encode the given texts with the paraphrase-MiniLM-L3-v2 model.

    Args:
        text_chunk: The texts to embed.

    Returns:
        The result of the model's `encode` call for `text_chunk`.
    """
    return SentenceTransformer("paraphrase-MiniLM-L3-v2").encode(text_chunk)
142
+
143
+
144
@st.cache(allow_output_mutation=True)
def get_text_embeddings(texts: List[str]) -> np.ndarray:
    """Embed `texts` with the sentence-transformer; wrapped in `st.cache`.

    Args:
        texts: The texts to embed.

    Returns:
        Whatever `apply_emb_model` returns for the full list of texts.
    """
    embeddings = apply_emb_model(texts)
    return embeddings
157
+
158
+
159
@st.cache(allow_output_mutation=True)
def get_umap_embeddings(embs: np.ndarray) -> np.ndarray:
    """Fit the module-level UMAP model on `embs` and return the projection.

    Args:
        embs: The text embeddings to project.

    Returns:
        The output of `umap_model.fit_transform(embs)`.
    """
    projected = umap_model.fit_transform(embs)
    return projected
162
+
163
+
164
def add_umap_embeddings(df: pd.DataFrame, emb_xy: np.ndarray) -> pd.DataFrame:
    """Write the two embedding coordinates into `df` as columns "x" and "y".

    The dataframe is modified in place and also returned.

    Args:
        df: Dataframe that receives the coordinate columns.
        emb_xy: 2-D array; column 0 becomes "x" and column 1 becomes "y".

    Returns:
        The same `df` object, now carrying "x" and "y".
    """
    for axis, column in enumerate(("x", "y")):
        df[column] = emb_xy[:, axis]
    return df
168
+
169
+
170
def clear_state_after_export() -> None:
    """Report how many label assignments were exported, then clear the state.

    Registered as the `on_click` callback of the download button.
    """
    exported = st.session_state[SessionKey.label_assignments]
    st.info(f"Exported {len(exported)} labeled samples!", icon="ℹ️")
    clear_state()
174
+
175
+
176
def export_label_assignments() -> None:
    """Offer the labeled rows as a CSV download in the sidebar.

    Nothing is rendered when no dataframe is stored in the session or the
    stored dataframe is empty. Clicking the download button runs
    `clear_state_after_export`.
    """
    if SessionKey.df not in st.session_state:
        return
    if not len(st.session_state[SessionKey.df]):
        return

    export = get_export_df()
    button_text = f"Download {numerize(len(export))} samples"
    csv_bytes = export.to_csv(index=False).encode("utf-8")
    st.sidebar.download_button(
        button_text,
        csv_bytes,
        file_name="export.csv",
        mime="text/csv",
        on_click=clear_state_after_export,
    )
186
+
187
+
188
def assign_label() -> None:
    """Save the label picked in the "Choose Label" selectbox for the active rows.

    Runs as the selectbox's `on_change` callback. The chosen value is read
    from session state under the selectbox's current widget key; when it is a
    registered label, every currently active row id is mapped to it in the
    label assignments. Afterwards the expander is collapsed, the chart key is
    replaced so the lasso selection is dropped, and the selectbox is given a
    new widget key so it renders back at its first option.
    """
    select_key = st.session_state[SessionKey.label_select_key]
    chosen_label = st.session_state[select_key]
    active_ids = st.session_state[SessionKey.active_ids]
    if chosen_label in st.session_state[SessionKey.labels]:
        print(f"Setting {len(active_ids)} label to {chosen_label}")
        assignments = st.session_state[SessionKey.label_assignments]
        for row_id in active_ids:
            assignments[row_id] = chosen_label

    st.session_state[SessionKey.default_index] = 0
    st.session_state[SessionKey.is_expanded] = False
    st.info(f"{len(active_ids)} samples labeled {chosen_label}", icon="ℹ️")
    reset_plotly_figure(force=True)
    # Store the new widget key as a string (as done for the figure key):
    # Streamlit widget keys are expected to be str, not UUID objects.
    st.session_state[SessionKey.label_select_key] = str(uuid4())
203
+
204
+
205
class Laboratory:
    """The labeling page: sidebar controls, dataframe view and embeddings plot.

    Creating an instance renders the whole page. State that has to survive a
    Streamlit rerun (file, dataframe, figure, selection, labels) lives in
    ``st.session_state`` and is reloaded in ``__init__``.
    """

    def __init__(self) -> None:
        """Reload session-backed attributes and render every page section."""
        reset_plotly_figure()
        # On page refresh, we need to reload our stateful attributes via session state
        self.file = st.session_state.get(SessionKey.file)
        self.df = st.session_state.get(SessionKey.df)
        # Empty placeholders; `embeddings()` fills these in when it runs
        self.embs: np.ndarray = np.empty(0)
        self.umap_xy: np.ndarray = np.empty(0)
        self.selected_points: List[int] = []
        self.ids = st.session_state.get(SessionKey.active_ids)
        self.force_new_fig = False
        self.search_term = ""

        self.sidebar()

        # We create the scatterplot and then refresh the app, so that it's the
        # first thing rendered. We need to do this because of the way that plotly_events
        # works. It stores the selected samples from the lasso, and we need to first
        # get those points and then filter the dataframe/embeddings based on them.
        if SessionKey.fig in st.session_state:
            with col2:
                self.plot_figure()

        if self.file:
            with col1:
                self.dataframe()
            with col2:
                self.embeddings()
                self.create_figure()

    def sidebar(self) -> None:
        """Render the sidebar: upload, label registration, export, search,
        selection reset, marker size and the "Color By" column picker."""
        self.file = get_dataframe_file()
        new_label = st.sidebar.text_input("Register Label")
        if new_label and new_label not in st.session_state[SessionKey.labels]:
            st.session_state[SessionKey.labels].append(new_label)
        with st.sidebar.expander("Current Labels"):
            for label in st.session_state[SessionKey.labels]:
                st.write(label)
        assigned = st.session_state.get(SessionKey.label_assignments) or {}
        if st.sidebar.button(
            f"Export {len(assigned)} Assigned labels", disabled=not assigned
        ):
            export_label_assignments()

        st.sidebar.markdown("---")
        # We don't want to be able to filter the dataframe until its fully processed
        self.search_term = st.sidebar.text_input(
            "Text Search", disabled=not st.session_state.get(SessionKey.has_xy, False)
        )
        st.sidebar.markdown("---")
        if st.sidebar.button("Reset Selection"):
            # We want to clear all selected points as well as the figure, and rerun
            # the app. This will cause all lasso selections to go away and give us
            # a fresh embedding scatterplot
            print("Resetting selection")
            reset_embeddings()

        default = st.session_state.get(SessionKey.marker_size, 2)
        st.session_state[SessionKey.marker_size] = st.sidebar.slider(
            "point size", min_value=1, max_value=20, value=default
        )
        color_by = ["<select>"]
        if SessionKey.df in st.session_state:
            df = st.session_state[SessionKey.df]
            color_by += [c for c in df.columns if c not in ("id", "text", "hovertext")]
        default_color = st.session_state.get(SessionKey.color) or "<select>"
        # Fall back to the placeholder when the remembered column is not among
        # the offered options, instead of raising ValueError from `.index`.
        default_index = (
            color_by.index(default_color) if default_color in color_by else 0
        )
        color = st.sidebar.selectbox("Color By", color_by, index=default_index)
        st.session_state[SessionKey.color] = None if color == "<select>" else color

    def dataframe(self) -> None:
        """Load the CSV on first use, apply the search and lasso filters,
        and render the filtered rows together with the label picker."""
        st.subheader("DataFrame")

        if SessionKey.df not in st.session_state and self.file is not None:
            self.df = pd.read_csv(self.file)
            self.df["id"] = self.df.index
            self.df["text_length"] = self.df["text"].str.len()
            self.df["hovertext"] = self.df.text.str.wrap(30).str.replace("\n", "<br>")
            # Checkpoint the df
            st.session_state[SessionKey.df] = self.df

        assert self.df is not None
        # Apply search: case-sensitive literal substring match on the text
        # column. An empty term matches everything, so filtering is skipped;
        # rows whose text is missing never match a non-empty term.
        if self.search_term:
            matches = self.df["text"].str.contains(
                self.search_term, regex=False, na=False
            )
            self.df = self.df[matches]
        if st.session_state.get(SessionKey.selected_points):
            # NOTE(review): `pointIndex` is compared directly against `id`.
            # That presumably only lines up while the figure is drawn from the
            # full dataframe as a single trace — confirm the behaviour when a
            # "Color By" column or a search filter is active.
            filter_ids = [
                i["pointIndex"] for i in st.session_state[SessionKey.selected_points]
            ]
            self.df = self.df[self.df["id"].isin(filter_ids)]

        self.ids = self.df["id"].tolist()
        # Checkpoint the filtered ids
        st.session_state[SessionKey.active_ids] = self.ids

        showcols = [c for c in self.df.columns if c not in ("hovertext", "Unnamed: 0")]
        st.write(f"({len(self.df)}) active rows")
        label_assigner = st.expander(
            "Set label for selection",
            expanded=st.session_state.get(SessionKey.is_expanded, False),
        )
        with label_assigner:
            avl_labels = ["<select>"] + st.session_state[SessionKey.labels]
            st.selectbox(
                "Choose Label",
                avl_labels,
                key=st.session_state[SessionKey.label_select_key],
                on_change=assign_label,
            )

        st.dataframe(self.df[showcols], height=800)

    def create_figure(self) -> None:
        """Build the scatterplot for the current rows; if it differs from the
        stored figure (or none is stored), store it and rerun the app."""
        p = px.scatter(
            self.df,
            x="x",
            y="y",
            color=st.session_state[SessionKey.color],
            hover_data=["hovertext"],
        )
        p.update_traces(marker_size=st.session_state[SessionKey.marker_size])
        # If there's no figure yet or it's changed, refresh and replot it
        if (
            SessionKey.fig not in st.session_state
            or st.session_state[SessionKey.fig] != p
        ):
            st.session_state[SessionKey.fig] = p
            print("Forcing refresh")
            st.experimental_rerun()

    def plot_figure(self) -> None:
        """Render the stored figure through plotly_events and store the
        points it returns as the current selection."""
        st.subheader("Embeddings")
        st.session_state[SessionKey.selected_points] = plotly_events(
            st.session_state[SessionKey.fig],
            select_event=True,
            override_height=800,
            key=st.session_state[SessionKey.figure_state],
        )

    def embeddings(self) -> None:
        """Compute text embeddings and their UMAP projection once per
        dataframe and add the resulting x/y columns to it."""
        # Only calculate the UMAP embeddings once for a given dataframe. If we've
        # already done it, save the `has_xy` state and don't recalculate
        if SessionKey.has_xy not in st.session_state and self.df is not None:
            progress = st.empty()
            progress.text("Getting embeddings for text")
            self.embs = get_text_embeddings(self.df.text.tolist())
            progress.text("Applying UMAP")
            self.umap_xy = get_umap_embeddings(self.embs)
            progress.text("")
            self.df = add_umap_embeddings(self.df, self.umap_xy)
            st.session_state[SessionKey.df] = self.df
            # Set so we don't have to recalculate this on every interaction with the app
            st.session_state[SessionKey.has_xy] = True
362
+
363
# Render the page when this file is run as a script
if __name__ == "__main__":
    Laboratory()