Spaces:
Configuration error
Configuration error
rungalileo
commited on
Commit
·
cce3221
1
Parent(s):
8a59595
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from concurrent.futures import ProcessPoolExecutor
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import List
|
7 |
+
from uuid import uuid4
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.express as px
|
12 |
+
import streamlit as st
|
13 |
+
from numerize.numerize import numerize
|
14 |
+
from sentence_transformers import SentenceTransformer
|
15 |
+
from streamlit.elements.file_uploader import UploadedFile
|
16 |
+
from streamlit_plotly_events import plotly_events
|
17 |
+
from umap import UMAP
|
18 |
+
|
19 |
+
sys.path.append(".")
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class SessionKey:
|
24 |
+
model: str = "model"
|
25 |
+
figure_state: str = "figure_state"
|
26 |
+
file: str = "file"
|
27 |
+
df: str = "df"
|
28 |
+
active_ids: str = "active_ids"
|
29 |
+
fig: str = "fig"
|
30 |
+
selected_points: str = "selected_points"
|
31 |
+
has_xy: str = "has_xy"
|
32 |
+
marker_size: str = "marker_size"
|
33 |
+
color: str = "color"
|
34 |
+
labels: str = "labels"
|
35 |
+
label_assignments: str = "label_assignments"
|
36 |
+
is_expanded: str = "is_expanded"
|
37 |
+
default_index: str = "default_index"
|
38 |
+
chosen_label: str = "chosen_label"
|
39 |
+
label_select_key: str = "label_select_key"
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class InternalCol:
|
44 |
+
hovertext: str = "hovertext"
|
45 |
+
x: str = "x"
|
46 |
+
y: str = "y"
|
47 |
+
|
48 |
+
|
49 |
+
INTERNAL_COLS = [InternalCol.hovertext, InternalCol.x, InternalCol.y]
|
50 |
+
|
51 |
+
|
52 |
+
def get_export_df() -> pd.DataFrame:
|
53 |
+
df2 = st.session_state[SessionKey.df].copy()
|
54 |
+
id_label = st.session_state[SessionKey.label_assignments]
|
55 |
+
df2["label"] = df2["id"].apply(lambda id_: id_label.get(id_, -1))
|
56 |
+
df2 = df2[df2["label"] != -1]
|
57 |
+
|
58 |
+
cols = [c for c in df2.columns if c not in INTERNAL_COLS]
|
59 |
+
return df2[cols]
|
60 |
+
|
61 |
+
|
62 |
+
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
63 |
+
umap_model = UMAP(n_neighbors=15, random_state=42, verbose=True)
|
64 |
+
st.set_page_config(layout="wide")
|
65 |
+
st.title("Laboratory 🧪")
|
66 |
+
col1, col2 = st.columns([3, 4])
|
67 |
+
|
68 |
+
# The registered labels by the user
|
69 |
+
if SessionKey.labels not in st.session_state:
|
70 |
+
st.session_state[SessionKey.labels] = []
|
71 |
+
|
72 |
+
# The assigned labels {tuple of ids: label}
|
73 |
+
if SessionKey.label_assignments not in st.session_state:
|
74 |
+
st.session_state[SessionKey.label_assignments] = {}
|
75 |
+
|
76 |
+
if SessionKey.label_select_key not in st.session_state:
|
77 |
+
st.session_state[SessionKey.label_select_key] = uuid4()
|
78 |
+
|
79 |
+
|
80 |
+
# Pre download the model
|
81 |
+
if SessionKey.model not in st.session_state.keys():
|
82 |
+
SentenceTransformer("paraphrase-MiniLM-L3-v2")
|
83 |
+
st.session_state[SessionKey.model] = True
|
84 |
+
|
85 |
+
|
86 |
+
def reset_plotly_figure(force: bool = False) -> None:
|
87 |
+
"""Reload the plotly chart from scratch, remove all state
|
88 |
+
|
89 |
+
We are using a tool called streamlit_plotly_events to capture and maintain selected
|
90 |
+
points from a lasso select. The issue with the package is that there's no way
|
91 |
+
(from what I can tell) to drop the state of the chart (the selected points).
|
92 |
+
|
93 |
+
But sometimes (often) a user wants to refresh and remove the selected points. The
|
94 |
+
package does come with a kwarg `key` which defines the `id` of the chart in case
|
95 |
+
you want to have multiple to keep tabs of. So if we change the `id` of the chart,
|
96 |
+
we can essentially refresh the chart and remove the selected points
|
97 |
+
"""
|
98 |
+
# The first time we call this, there won't be a figure state, but all subsequent
|
99 |
+
# times there will be, so we include the `force` param to opt into "re-clearing"
|
100 |
+
# the chart
|
101 |
+
if SessionKey.figure_state not in st.session_state or force:
|
102 |
+
st.session_state[SessionKey.figure_state] = str(uuid4())
|
103 |
+
|
104 |
+
|
105 |
+
def clear_state() -> None:
|
106 |
+
"""Clear the global state.
|
107 |
+
|
108 |
+
Either when a new file is uploaded, or when the user wants to "start over" on their
|
109 |
+
work
|
110 |
+
"""
|
111 |
+
for key in st.session_state.keys():
|
112 |
+
# No reason to delete the model, we have it downloaded and it doesn't change
|
113 |
+
if key != SessionKey.model:
|
114 |
+
del st.session_state[key]
|
115 |
+
|
116 |
+
|
117 |
+
def reset_embeddings() -> None:
|
118 |
+
"""Reset the embeddings view to full dataframe
|
119 |
+
|
120 |
+
Remove all global state that involves the embeddings or filters on the dataframe
|
121 |
+
"""
|
122 |
+
for key in [SessionKey.selected_points, SessionKey.fig, SessionKey.active_ids]:
|
123 |
+
if key in st.session_state:
|
124 |
+
del st.session_state[key]
|
125 |
+
reset_plotly_figure(force=True)
|
126 |
+
st.experimental_rerun()
|
127 |
+
|
128 |
+
|
129 |
+
def get_dataframe_file() -> UploadedFile:
|
130 |
+
file = st.sidebar.file_uploader(
|
131 |
+
"Upload your CSV text file", type="csv", on_change=clear_state
|
132 |
+
)
|
133 |
+
if SessionKey.file in st.session_state.keys():
|
134 |
+
return st.session_state[SessionKey.file]
|
135 |
+
st.session_state[SessionKey.file] = file
|
136 |
+
return file
|
137 |
+
|
138 |
+
|
139 |
+
def apply_emb_model(text_chunk: List[str]) -> np.ndarray:
|
140 |
+
model = SentenceTransformer("paraphrase-MiniLM-L3-v2")
|
141 |
+
return model.encode(text_chunk)
|
142 |
+
|
143 |
+
|
144 |
+
@st.cache(allow_output_mutation=True)
|
145 |
+
def get_text_embeddings(texts: List[str]) -> np.ndarray:
|
146 |
+
return apply_emb_model(texts)
|
147 |
+
# embs = []
|
148 |
+
# chunk_size = math.ceil(len(texts) / 10)
|
149 |
+
# text_chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]
|
150 |
+
#
|
151 |
+
# with ProcessPoolExecutor(max_workers=10) as pool:
|
152 |
+
# for text_chunk in text_chunks:
|
153 |
+
# embs.append(pool.submit(apply_emb_model, text_chunk))
|
154 |
+
#
|
155 |
+
# embs = [i.result() for i in embs]
|
156 |
+
# return np.concatenate(embs)
|
157 |
+
|
158 |
+
|
159 |
+
@st.cache(allow_output_mutation=True)
|
160 |
+
def get_umap_embeddings(embs: np.ndarray) -> np.ndarray:
|
161 |
+
return umap_model.fit_transform(embs)
|
162 |
+
|
163 |
+
|
164 |
+
def add_umap_embeddings(df: pd.DataFrame, emb_xy: np.ndarray) -> pd.DataFrame:
|
165 |
+
df["x"] = emb_xy[:, 0]
|
166 |
+
df["y"] = emb_xy[:, 1]
|
167 |
+
return df
|
168 |
+
|
169 |
+
|
170 |
+
def clear_state_after_export() -> None:
|
171 |
+
num_samples_exported = len(st.session_state[SessionKey.label_assignments])
|
172 |
+
st.info(f"Exported {num_samples_exported} labeled samples!", icon="ℹ️")
|
173 |
+
clear_state()
|
174 |
+
|
175 |
+
|
176 |
+
def export_label_assignments() -> None:
|
177 |
+
if SessionKey.df in st.session_state and len(st.session_state[SessionKey.df]):
|
178 |
+
df2 = get_export_df()
|
179 |
+
st.sidebar.download_button(
|
180 |
+
f"Download {numerize(len(df2))} samples",
|
181 |
+
df2.to_csv(index=False).encode("utf-8"),
|
182 |
+
file_name="export.csv",
|
183 |
+
mime="text/csv",
|
184 |
+
on_click=clear_state_after_export,
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
def assign_label() -> None:
|
189 |
+
"""Saves a given label with a list of IDs to apply to"""
|
190 |
+
key = st.session_state[SessionKey.label_select_key]
|
191 |
+
chosen_label = st.session_state[key]
|
192 |
+
ids_key = st.session_state[SessionKey.active_ids]
|
193 |
+
if chosen_label in st.session_state[SessionKey.labels]:
|
194 |
+
print(f"Setting {len(ids_key)} label to {chosen_label}")
|
195 |
+
for id_key in ids_key:
|
196 |
+
st.session_state[SessionKey.label_assignments][id_key] = chosen_label
|
197 |
+
|
198 |
+
st.session_state[SessionKey.default_index] = 0
|
199 |
+
st.session_state[SessionKey.is_expanded] = False
|
200 |
+
st.info(f"{len(ids_key)} samples labeled {chosen_label}", icon="ℹ️")
|
201 |
+
reset_plotly_figure(force=True)
|
202 |
+
st.session_state[SessionKey.label_select_key] = uuid4()
|
203 |
+
|
204 |
+
|
205 |
+
class Laboratory:
|
206 |
+
def __init__(self) -> None:
|
207 |
+
reset_plotly_figure()
|
208 |
+
# On page refresh, we need to reload our stateful attributes via session state
|
209 |
+
self.file = st.session_state.get(SessionKey.file)
|
210 |
+
self.df = st.session_state.get(SessionKey.df)
|
211 |
+
self.embs: np.ndarray = np.ndarray([])
|
212 |
+
self.umap_xy: np.ndarray = np.ndarray([])
|
213 |
+
self.selected_points: List[int] = []
|
214 |
+
self.ids = st.session_state.get(SessionKey.active_ids)
|
215 |
+
self.force_new_fig = False
|
216 |
+
|
217 |
+
self.sidebar()
|
218 |
+
|
219 |
+
# We create the scatterplot and then refresh the app, so that it's the
|
220 |
+
# first thing rendered. We need to do this because of the way that plotly_events
|
221 |
+
# works. It stores the selected samples from the lasso, and we need to first
|
222 |
+
# get those points and then filter the dataframe/embeddings based on them.
|
223 |
+
if SessionKey.fig in st.session_state:
|
224 |
+
with col2:
|
225 |
+
self.plot_figure()
|
226 |
+
|
227 |
+
if self.file:
|
228 |
+
with col1:
|
229 |
+
self.dataframe()
|
230 |
+
with col2:
|
231 |
+
self.embeddings()
|
232 |
+
self.create_figure()
|
233 |
+
|
234 |
+
def sidebar(self) -> None:
|
235 |
+
self.file = get_dataframe_file()
|
236 |
+
new_label = st.sidebar.text_input("Register Label")
|
237 |
+
if new_label and new_label not in st.session_state[SessionKey.labels]:
|
238 |
+
st.session_state[SessionKey.labels].append(new_label)
|
239 |
+
# all_labels = st.sidebar.empty()
|
240 |
+
with st.sidebar.expander("Current Labels"):
|
241 |
+
for label in st.session_state[SessionKey.labels]:
|
242 |
+
st.write(label)
|
243 |
+
assigned = st.session_state.get(SessionKey.label_assignments) or {}
|
244 |
+
if st.sidebar.button(
|
245 |
+
f"Export {len(assigned)} Assigned labels", disabled=not assigned
|
246 |
+
):
|
247 |
+
export_label_assignments()
|
248 |
+
|
249 |
+
st.sidebar.markdown("---")
|
250 |
+
# We don't want to be able to filter the dataframe until its fully processed
|
251 |
+
self.search_term = st.sidebar.text_input(
|
252 |
+
"Text Search", disabled=not st.session_state.get(SessionKey.has_xy, False)
|
253 |
+
)
|
254 |
+
st.sidebar.markdown("---")
|
255 |
+
if st.sidebar.button("Reset Selection"):
|
256 |
+
# We want to clear all selected points as well as the figure, and rerun
|
257 |
+
# the app. This will cause all lasso selections to go away and give us
|
258 |
+
# a fresh embedding scatterplot
|
259 |
+
print("exporting")
|
260 |
+
reset_embeddings()
|
261 |
+
|
262 |
+
default = st.session_state.get(SessionKey.marker_size, 2)
|
263 |
+
st.session_state[SessionKey.marker_size] = st.sidebar.slider(
|
264 |
+
"point size", min_value=1, max_value=20, value=default
|
265 |
+
)
|
266 |
+
color_by = ["<select>"]
|
267 |
+
if SessionKey.df in st.session_state:
|
268 |
+
df = st.session_state[SessionKey.df]
|
269 |
+
color_by += [c for c in df.columns if c not in ("id", "text", "hovertext")]
|
270 |
+
default_color = st.session_state.get(SessionKey.color) or "<select>"
|
271 |
+
default_index = color_by.index(default_color)
|
272 |
+
color = st.sidebar.selectbox("Color By", color_by, index=default_index)
|
273 |
+
st.session_state[SessionKey.color] = None if color == "<select>" else color
|
274 |
+
|
275 |
+
def dataframe(self) -> None:
|
276 |
+
st.subheader("DataFrame")
|
277 |
+
|
278 |
+
if SessionKey.df not in st.session_state and self.file is not None:
|
279 |
+
self.df = pd.read_csv(self.file)
|
280 |
+
self.df["id"] = self.df.index
|
281 |
+
self.df["text_length"] = self.df["text"].str.len()
|
282 |
+
self.df["hovertext"] = self.df.text.str.wrap(30).str.replace("\n", "<br>")
|
283 |
+
# Checkpoint the df
|
284 |
+
st.session_state[SessionKey.df] = self.df
|
285 |
+
|
286 |
+
assert self.df is not None
|
287 |
+
# Apply search
|
288 |
+
self.df = self.df[
|
289 |
+
self.df.apply(lambda row: self.search_term in row["text"], axis=1)
|
290 |
+
]
|
291 |
+
if st.session_state.get(SessionKey.selected_points):
|
292 |
+
filter_ids = [
|
293 |
+
i["pointIndex"] for i in st.session_state[SessionKey.selected_points]
|
294 |
+
]
|
295 |
+
self.df = self.df[self.df["id"].isin(filter_ids)]
|
296 |
+
# If this is a new lasso selection (new filter_ids), then we want to
|
297 |
+
# redraw and re-render the embedding scatter.
|
298 |
+
|
299 |
+
self.ids = self.df["id"].tolist()
|
300 |
+
# Checkpoint the filtered ids
|
301 |
+
st.session_state[SessionKey.active_ids] = self.ids
|
302 |
+
|
303 |
+
showcols = [c for c in self.df.columns if c not in ("hovertext", "Unnamed: 0")]
|
304 |
+
st.write(f"({len(self.df)}) active rows")
|
305 |
+
label_assigner = st.expander(
|
306 |
+
"Set label for selection",
|
307 |
+
expanded=st.session_state.get(SessionKey.is_expanded, False),
|
308 |
+
)
|
309 |
+
with label_assigner:
|
310 |
+
avl_labels = ["<select>"] + st.session_state[SessionKey.labels]
|
311 |
+
st.selectbox(
|
312 |
+
"Choose Label",
|
313 |
+
avl_labels,
|
314 |
+
key=st.session_state[SessionKey.label_select_key],
|
315 |
+
on_change=assign_label,
|
316 |
+
)
|
317 |
+
|
318 |
+
st.dataframe(self.df[showcols], height=800)
|
319 |
+
|
320 |
+
def create_figure(self) -> None:
|
321 |
+
p = px.scatter(
|
322 |
+
self.df,
|
323 |
+
x="x",
|
324 |
+
y="y",
|
325 |
+
color=st.session_state[SessionKey.color],
|
326 |
+
hover_data=["hovertext"],
|
327 |
+
)
|
328 |
+
p.update_traces(marker_size=st.session_state[SessionKey.marker_size])
|
329 |
+
# If there's no figure yet or it's changed, refresh and replot it
|
330 |
+
if (
|
331 |
+
SessionKey.fig not in st.session_state
|
332 |
+
or st.session_state[SessionKey.fig] != p
|
333 |
+
):
|
334 |
+
st.session_state[SessionKey.fig] = p
|
335 |
+
print("Forcing refresh")
|
336 |
+
st.experimental_rerun()
|
337 |
+
|
338 |
+
def plot_figure(self) -> None:
|
339 |
+
st.subheader("Embeddings")
|
340 |
+
st.session_state[SessionKey.selected_points] = plotly_events(
|
341 |
+
st.session_state[SessionKey.fig],
|
342 |
+
select_event=True,
|
343 |
+
override_height=800,
|
344 |
+
key=st.session_state[SessionKey.figure_state],
|
345 |
+
)
|
346 |
+
|
347 |
+
def embeddings(self) -> None:
|
348 |
+
# Only calculate the UMAP embeddings once for a given dataframe. If we've
|
349 |
+
# already done it, save the `has_xy` state and don't recalculate
|
350 |
+
if SessionKey.has_xy not in st.session_state and self.df is not None:
|
351 |
+
progress = st.empty()
|
352 |
+
progress.text("Getting embeddings for text")
|
353 |
+
self.embs = get_text_embeddings(self.df.text.tolist())
|
354 |
+
progress.text("Applying UMAP")
|
355 |
+
self.umap_xy = get_umap_embeddings(self.embs)
|
356 |
+
progress.text("")
|
357 |
+
self.df = add_umap_embeddings(self.df, self.umap_xy)
|
358 |
+
st.session_state[SessionKey.df] = self.df
|
359 |
+
# Set so we don't have to recalculate this on every interaction with the app
|
360 |
+
st.session_state[SessionKey.has_xy] = True
|
361 |
+
|
362 |
+
|
363 |
+
if __name__ == "__main__":
|
364 |
+
Laboratory()
|