Sonnyjim commited on
Commit
e1c1f68
1 Parent(s): 0a543a0

Reduce outliers now more efficient and relabels with correct vectoriser. Default topic labels now tidier. Hiearchical topics outputs more useful for joining to df afterwards. Switched low resource reduction algorithm to UMAP as default is not good.

Browse files
app.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
7
  import numpy as np
8
 
9
  from funcs.topic_core_funcs import pre_clean, extract_topics, reduce_outliers, represent_topics, visualise_topics, save_as_pytorch_model
10
- from funcs.helper_functions import dummy_function, initial_file_load, custom_regex_load
11
  from sklearn.feature_extraction.text import CountVectorizer
12
 
13
  # Gradio app
@@ -20,6 +20,7 @@ with block:
20
  embeddings_state = gr.State(np.array([]))
21
  embeddings_type_state = gr.State("")
22
  topic_model_state = gr.State()
 
23
  custom_regex_state = gr.State(pd.DataFrame())
24
  docs_state = gr.State()
25
  data_file_name_no_ext_state = gr.State()
@@ -104,23 +105,22 @@ with block:
104
 
105
  # Load in data. Update column names dropdown when file uploaded
106
  in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
107
- in_colnames.change(dummy_function, in_colnames, None)
108
 
109
  # Clean data
110
  custom_regex.upload(fn=custom_regex_load, inputs=[custom_regex], outputs=[custom_regex_text, custom_regex_state])
111
  clean_btn.click(fn=pre_clean, inputs=[data_state, in_colnames, data_file_name_no_ext_state, custom_regex_state, clean_text, drop_duplicate_text, anonymise_drop], outputs=[output_single_text, output_file, data_state, data_file_name_no_ext_state], api_name="clean")
112
 
113
  # Extract topics
114
- topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, embeddings_type_state, zero_shot_similarity, seed_number, calc_probs, vectoriser_state], outputs=[output_single_text, output_file, embeddings_state, embeddings_type_state, data_file_name_no_ext_state, topic_model_state, docs_state, vectoriser_state], api_name="topics")
115
 
116
  # Reduce outliers
117
- reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
118
 
119
  # Re-represent topic labels
120
  represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model, representation_type, vectoriser_state], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")
121
 
122
  # Save in Pytorch format
123
- save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
124
 
125
  # Visualise topics
126
  plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
 
7
  import numpy as np
8
 
9
  from funcs.topic_core_funcs import pre_clean, extract_topics, reduce_outliers, represent_topics, visualise_topics, save_as_pytorch_model
10
+ from funcs.helper_functions import initial_file_load, custom_regex_load
11
  from sklearn.feature_extraction.text import CountVectorizer
12
 
13
  # Gradio app
 
20
  embeddings_state = gr.State(np.array([]))
21
  embeddings_type_state = gr.State("")
22
  topic_model_state = gr.State()
23
+ assigned_topics_state = gr.State([])
24
  custom_regex_state = gr.State(pd.DataFrame())
25
  docs_state = gr.State()
26
  data_file_name_no_ext_state = gr.State()
 
105
 
106
  # Load in data. Update column names dropdown when file uploaded
107
  in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
 
108
 
109
  # Clean data
110
  custom_regex.upload(fn=custom_regex_load, inputs=[custom_regex], outputs=[custom_regex_text, custom_regex_state])
111
  clean_btn.click(fn=pre_clean, inputs=[data_state, in_colnames, data_file_name_no_ext_state, custom_regex_state, clean_text, drop_duplicate_text, anonymise_drop], outputs=[output_single_text, output_file, data_state, data_file_name_no_ext_state], api_name="clean")
112
 
113
  # Extract topics
114
+ topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, embeddings_type_state, zero_shot_similarity, seed_number, calc_probs, vectoriser_state], outputs=[output_single_text, output_file, embeddings_state, embeddings_type_state, data_file_name_no_ext_state, topic_model_state, docs_state, vectoriser_state, assigned_topics_state], api_name="topics")
115
 
116
  # Reduce outliers
117
+ reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, assigned_topics_state, vectoriser_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
118
 
119
  # Re-represent topic labels
120
  represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model, representation_type, vectoriser_state], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")
121
 
122
  # Save in Pytorch format
123
+ save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file], api_name="pytorch_save")
124
 
125
  # Visualise topics
126
  plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
funcs/bertopic_hierarchical_documents.py DELETED
@@ -1,336 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import plotly.graph_objects as go
4
- import math
5
-
6
- from umap import UMAP
7
- from typing import List, Union
8
-
9
-
10
- def visualize_hierarchical_documents(topic_model,
11
- docs: List[str],
12
- hierarchical_topics: pd.DataFrame,
13
- topics: List[int] = None,
14
- embeddings: np.ndarray = None,
15
- reduced_embeddings: np.ndarray = None,
16
- sample: Union[float, int] = None,
17
- hide_annotations: bool = False,
18
- hide_document_hover: bool = True,
19
- nr_levels: int = 10,
20
- level_scale: str = 'linear',
21
- custom_labels: Union[bool, str] = False,
22
- title: str = "<b>Hierarchical Documents and Topics</b>",
23
- width: int = 1200,
24
- height: int = 750) -> go.Figure:
25
- """ Visualize documents and their topics in 2D at different levels of hierarchy
26
-
27
- Arguments:
28
- docs: The documents you used when calling either `fit` or `fit_transform`
29
- hierarchical_topics: A dataframe that contains a hierarchy of topics
30
- represented by their parents and their children
31
- topics: A selection of topics to visualize.
32
- Not to be confused with the topics that you get from `.fit_transform`.
33
- For example, if you want to visualize only topics 1 through 5:
34
- `topics = [1, 2, 3, 4, 5]`.
35
- embeddings: The embeddings of all documents in `docs`.
36
- reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
37
- sample: The percentage of documents in each topic that you would like to keep.
38
- Value can be between 0 and 1. Setting this value to, for example,
39
- 0.1 (10% of documents in each topic) makes it easier to visualize
40
- millions of documents as a subset is chosen.
41
- hide_annotations: Hide the names of the traces on top of each cluster.
42
- hide_document_hover: Hide the content of the documents when hovering over
43
- specific points. Helps to speed up generation of visualizations.
44
- nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
45
- in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
46
- Then, for each list of distances, the merged topics are selected that have a
47
- distance less or equal to the maximum distance of the selected list of distances.
48
- NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
49
- the length of `hierarchical_topics`.
50
- level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance
51
- vector. Linear scaling will perform an equal number of merges at each level
52
- while logarithmic scaling will perform more mergers in earlier levels to
53
- provide more resolution at higher levels (this can be used for when the number
54
- of topics is large).
55
- custom_labels: If bool, whether to use custom topic labels that were defined using
56
- `topic_model.set_topic_labels`.
57
- If `str`, it uses labels from other aspects, e.g., "Aspect1".
58
- NOTE: Custom labels are only generated for the original
59
- un-merged topics.
60
- title: Title of the plot.
61
- width: The width of the figure.
62
- height: The height of the figure.
63
-
64
- Examples:
65
-
66
- To visualize the topics simply run:
67
-
68
- ```python
69
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
70
- ```
71
-
72
- Do note that this re-calculates the embeddings and reduces them to 2D.
73
- The advised and prefered pipeline for using this function is as follows:
74
-
75
- ```python
76
- from sklearn.datasets import fetch_20newsgroups
77
- from sentence_transformers import SentenceTransformer
78
- from bertopic import BERTopic
79
- from umap import UMAP
80
-
81
- # Prepare embeddings
82
- docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
83
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
84
- embeddings = sentence_model.encode(docs, show_progress_bar=False)
85
-
86
- # Train BERTopic and extract hierarchical topics
87
- topic_model = BERTopic().fit(docs, embeddings)
88
- hierarchical_topics = topic_model.hierarchical_topics(docs)
89
-
90
- # Reduce dimensionality of embeddings, this step is optional
91
- # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
92
-
93
- # Run the visualization with the original embeddings
94
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
95
-
96
- # Or, if you have reduced the original embeddings already:
97
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
98
- ```
99
-
100
- Or if you want to save the resulting figure:
101
-
102
- ```python
103
- fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
104
- fig.write_html("path/to/file.html")
105
- ```
106
-
107
- NOTE:
108
- This visualization was inspired by the scatter plot representation of Doc2Map:
109
- https://github.com/louisgeisler/Doc2Map
110
-
111
- <iframe src="../../getting_started/visualization/hierarchical_documents.html"
112
- style="width:1000px; height: 770px; border: 0px;""></iframe>
113
- """
114
- topic_per_doc = topic_model.topics_
115
-
116
- # Sample the data to optimize for visualization and dimensionality reduction
117
- if sample is None or sample > 1:
118
- sample = 1
119
-
120
- indices = []
121
- for topic in set(topic_per_doc):
122
- s = np.where(np.array(topic_per_doc) == topic)[0]
123
- size = len(s) if len(s) < 100 else int(len(s)*sample)
124
- indices.extend(np.random.choice(s, size=size, replace=False))
125
- indices = np.array(indices)
126
-
127
- df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
128
- df["doc"] = [docs[index] for index in indices]
129
- df["topic"] = [topic_per_doc[index] for index in indices]
130
-
131
- # Extract embeddings if not already done
132
- if sample is None:
133
- if embeddings is None and reduced_embeddings is None:
134
- embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
135
- else:
136
- embeddings_to_reduce = embeddings
137
- else:
138
- if embeddings is not None:
139
- embeddings_to_reduce = embeddings[indices]
140
- elif embeddings is None and reduced_embeddings is None:
141
- embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
142
-
143
- # Reduce input embeddings
144
- if reduced_embeddings is None:
145
- umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
146
- embeddings_2d = umap_model.embedding_
147
- elif sample is not None and reduced_embeddings is not None:
148
- embeddings_2d = reduced_embeddings[indices]
149
- elif sample is None and reduced_embeddings is not None:
150
- embeddings_2d = reduced_embeddings
151
-
152
- # Combine data
153
- df["x"] = embeddings_2d[:, 0]
154
- df["y"] = embeddings_2d[:, 1]
155
-
156
- # Create topic list for each level, levels are created by calculating the distance
157
- distances = hierarchical_topics.Distance.to_list()
158
- if level_scale == 'log' or level_scale == 'logarithmic':
159
- log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
160
- log_indices.reverse()
161
- max_distances = [distances[i] for i in log_indices]
162
- elif level_scale == 'lin' or level_scale == 'linear':
163
- max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
164
- else:
165
- raise ValueError("level_scale needs to be one of 'log' or 'linear'")
166
-
167
- for index, max_distance in enumerate(max_distances):
168
-
169
- # Get topics below `max_distance`
170
- mapping = {topic: topic for topic in df.topic.unique()}
171
- selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
172
- selection.Parent_ID = selection.Parent_ID.astype(int)
173
- selection = selection.sort_values("Parent_ID")
174
-
175
- for row in selection.iterrows():
176
- for topic in row[1].Topics:
177
- mapping[topic] = row[1].Parent_ID
178
-
179
- # Make sure the mappings are mapped 1:1
180
- mappings = [True for _ in mapping]
181
- while any(mappings):
182
- for i, (key, value) in enumerate(mapping.items()):
183
- if value in mapping.keys() and key != value:
184
- mapping[key] = mapping[value]
185
- else:
186
- mappings[i] = False
187
-
188
- # Create new column
189
- df[f"level_{index+1}"] = df.topic.map(mapping)
190
- df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
191
-
192
- # Prepare topic names of original and merged topics
193
- trace_names = []
194
- topic_names = {}
195
- for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
196
- if topic < hierarchical_topics.Parent_ID.astype(int).min():
197
- if topic_model.get_topic(topic):
198
- if isinstance(custom_labels, str):
199
- trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
200
- elif topic_model.custom_labels_ is not None and custom_labels:
201
- trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
202
- else:
203
- trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
204
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
205
- trace_names.append(trace_name)
206
- else:
207
- trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
208
- plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
209
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
210
- trace_names.append(trace_name)
211
-
212
- # Prepare traces
213
- all_traces = []
214
- for level in range(len(max_distances)):
215
- traces = []
216
-
217
- # Outliers
218
- if topic_model._outliers:
219
- traces.append(
220
- go.Scattergl(
221
- x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
222
- y=df.loc[df[f"level_{level+1}"] == -1, "y"],
223
- mode='markers+text',
224
- name="other",
225
- hoverinfo="text",
226
- hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
227
- showlegend=False,
228
- marker=dict(color='#CFD8DC', size=5, opacity=0.5)
229
- )
230
- )
231
-
232
- # Selected topics
233
- if topics:
234
- selection = df.loc[(df.topic.isin(topics)), :]
235
- unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
236
- else:
237
- unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
238
-
239
- for topic in unique_topics:
240
- if topic != -1:
241
- if topics:
242
- selection = df.loc[(df[f"level_{level+1}"] == topic) &
243
- (df.topic.isin(topics)), :]
244
- else:
245
- selection = df.loc[df[f"level_{level+1}"] == topic, :]
246
-
247
- if not hide_annotations:
248
- selection.loc[len(selection), :] = None
249
- selection["text"] = ""
250
- selection.loc[len(selection) - 1, "x"] = selection.x.mean()
251
- selection.loc[len(selection) - 1, "y"] = selection.y.mean()
252
- selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
253
-
254
- traces.append(
255
- go.Scattergl(
256
- x=selection.x,
257
- y=selection.y,
258
- text=selection.text if not hide_annotations else None,
259
- hovertext=selection.doc if not hide_document_hover else None,
260
- hoverinfo="text",
261
- name=topic_names[int(topic)]["trace_name"],
262
- mode='markers+text',
263
- marker=dict(size=5, opacity=0.5)
264
- )
265
- )
266
-
267
- all_traces.append(traces)
268
-
269
- # Track and count traces
270
- nr_traces_per_set = [len(traces) for traces in all_traces]
271
- trace_indices = [(0, nr_traces_per_set[0])]
272
- for index, nr_traces in enumerate(nr_traces_per_set[1:]):
273
- start = trace_indices[index][1]
274
- end = nr_traces + start
275
- trace_indices.append((start, end))
276
-
277
- # Visualization
278
- fig = go.Figure()
279
- for traces in all_traces:
280
- for trace in traces:
281
- fig.add_trace(trace)
282
-
283
- for index in range(len(fig.data)):
284
- if index >= nr_traces_per_set[0]:
285
- fig.data[index].visible = False
286
-
287
- # Create and add slider
288
- steps = []
289
- for index, indices in enumerate(trace_indices):
290
- step = dict(
291
- method="update",
292
- label=str(index),
293
- args=[{"visible": [False] * len(fig.data)}]
294
- )
295
- for index in range(indices[1]-indices[0]):
296
- step["args"][0]["visible"][index+indices[0]] = True
297
- steps.append(step)
298
-
299
- sliders = [dict(
300
- currentvalue={"prefix": "Level: "},
301
- pad={"t": 20},
302
- steps=steps
303
- )]
304
-
305
- # Add grid in a 'plus' shape
306
- x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
307
- y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
308
- fig.add_shape(type="line",
309
- x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
310
- line=dict(color="#CFD8DC", width=2))
311
- fig.add_shape(type="line",
312
- x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
313
- line=dict(color="#9E9E9E", width=2))
314
- fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
315
- fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
316
-
317
- # Stylize layout
318
- fig.update_layout(
319
- sliders=sliders,
320
- template="simple_white",
321
- title={
322
- 'text': f"{title}",
323
- 'x': 0.5,
324
- 'xanchor': 'center',
325
- 'yanchor': 'top',
326
- 'font': dict(
327
- size=22,
328
- color="Black")
329
- },
330
- width=width,
331
- height=height,
332
- )
333
-
334
- fig.update_xaxes(visible=False)
335
- fig.update_yaxes(visible=False)
336
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
funcs/bertopic_hierarchical_documents_to_df.py DELETED
@@ -1,250 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import plotly.graph_objects as go
4
- import math
5
-
6
- from umap import UMAP
7
- from typing import List, Union
8
-
9
-
10
- def visualize_hierarchical_documents_to_df(topic_model,
11
- docs: List[str],
12
- hierarchical_topics: pd.DataFrame,
13
- topics: List[int] = None,
14
- embeddings: np.ndarray = None,
15
- reduced_embeddings: np.ndarray = None,
16
- sample: Union[float, int] = None,
17
- hide_annotations: bool = False,
18
- hide_document_hover: bool = True,
19
- nr_levels: int = 10,
20
- level_scale: str = 'linear',
21
- custom_labels: Union[bool, str] = False,
22
- title: str = "<b>Hierarchical Documents and Topics</b>",
23
- width: int = 1200,
24
- height: int = 750) -> go.Figure:
25
- """ Visualize documents and their topics in 2D at different levels of hierarchy
26
-
27
- Arguments:
28
- docs: The documents you used when calling either `fit` or `fit_transform`
29
- hierarchical_topics: A dataframe that contains a hierarchy of topics
30
- represented by their parents and their children
31
- topics: A selection of topics to visualize.
32
- Not to be confused with the topics that you get from `.fit_transform`.
33
- For example, if you want to visualize only topics 1 through 5:
34
- `topics = [1, 2, 3, 4, 5]`.
35
- embeddings: The embeddings of all documents in `docs`.
36
- reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
37
- sample: The percentage of documents in each topic that you would like to keep.
38
- Value can be between 0 and 1. Setting this value to, for example,
39
- 0.1 (10% of documents in each topic) makes it easier to visualize
40
- millions of documents as a subset is chosen.
41
- hide_annotations: Hide the names of the traces on top of each cluster.
42
- hide_document_hover: Hide the content of the documents when hovering over
43
- specific points. Helps to speed up generation of visualizations.
44
- nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
45
- in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
46
- Then, for each list of distances, the merged topics are selected that have a
47
- distance less or equal to the maximum distance of the selected list of distances.
48
- NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
49
- the length of `hierarchical_topics`.
50
- level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance
51
- vector. Linear scaling will perform an equal number of merges at each level
52
- while logarithmic scaling will perform more mergers in earlier levels to
53
- provide more resolution at higher levels (this can be used for when the number
54
- of topics is large).
55
- custom_labels: If bool, whether to use custom topic labels that were defined using
56
- `topic_model.set_topic_labels`.
57
- If `str`, it uses labels from other aspects, e.g., "Aspect1".
58
- NOTE: Custom labels are only generated for the original
59
- un-merged topics.
60
- title: Title of the plot.
61
- width: The width of the figure.
62
- height: The height of the figure.
63
-
64
- Examples:
65
-
66
- To visualize the topics simply run:
67
-
68
- ```python
69
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
70
- ```
71
-
72
- Do note that this re-calculates the embeddings and reduces them to 2D.
73
- The advised and prefered pipeline for using this function is as follows:
74
-
75
- ```python
76
- from sklearn.datasets import fetch_20newsgroups
77
- from sentence_transformers import SentenceTransformer
78
- from bertopic import BERTopic
79
- from umap import UMAP
80
-
81
- # Prepare embeddings
82
- docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
83
- sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
84
- embeddings = sentence_model.encode(docs, show_progress_bar=False)
85
-
86
- # Train BERTopic and extract hierarchical topics
87
- topic_model = BERTopic().fit(docs, embeddings)
88
- hierarchical_topics = topic_model.hierarchical_topics(docs)
89
-
90
- # Reduce dimensionality of embeddings, this step is optional
91
- # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
92
-
93
- # Run the visualization with the original embeddings
94
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
95
-
96
- # Or, if you have reduced the original embeddings already:
97
- topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
98
- ```
99
-
100
- Or if you want to save the resulting figure:
101
-
102
- ```python
103
- fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
104
- fig.write_html("path/to/file.html")
105
- ```
106
-
107
- NOTE:
108
- This visualization was inspired by the scatter plot representation of Doc2Map:
109
- https://github.com/louisgeisler/Doc2Map
110
-
111
- <iframe src="../../getting_started/visualization/hierarchical_documents.html"
112
- style="width:1000px; height: 770px; border: 0px;""></iframe>
113
- """
114
- topic_per_doc = topic_model.topics_
115
-
116
- # Sample the data to optimize for visualization and dimensionality reduction
117
- if sample is None or sample > 1:
118
- sample = 1
119
-
120
- indices = []
121
- for topic in set(topic_per_doc):
122
- s = np.where(np.array(topic_per_doc) == topic)[0]
123
- size = len(s) if len(s) < 100 else int(len(s)*sample)
124
- indices.extend(np.random.choice(s, size=size, replace=False))
125
- indices = np.array(indices)
126
-
127
- df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
128
- df["doc"] = [docs[index] for index in indices]
129
- df["topic"] = [topic_per_doc[index] for index in indices]
130
-
131
- # Extract embeddings if not already done
132
- if sample is None:
133
- if embeddings is None and reduced_embeddings is None:
134
- embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
135
- else:
136
- embeddings_to_reduce = embeddings
137
- else:
138
- if embeddings is not None:
139
- embeddings_to_reduce = embeddings[indices]
140
- elif embeddings is None and reduced_embeddings is None:
141
- embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
142
-
143
- # Reduce input embeddings
144
- if reduced_embeddings is None:
145
- umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
146
- embeddings_2d = umap_model.embedding_
147
- elif sample is not None and reduced_embeddings is not None:
148
- embeddings_2d = reduced_embeddings[indices]
149
- elif sample is None and reduced_embeddings is not None:
150
- embeddings_2d = reduced_embeddings
151
-
152
- # Combine data
153
- df["x"] = embeddings_2d[:, 0]
154
- df["y"] = embeddings_2d[:, 1]
155
-
156
- # Create topic list for each level, levels are created by calculating the distance
157
- distances = hierarchical_topics.Distance.to_list()
158
- if level_scale == 'log' or level_scale == 'logarithmic':
159
- log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
160
- log_indices.reverse()
161
- max_distances = [distances[i] for i in log_indices]
162
- elif level_scale == 'lin' or level_scale == 'linear':
163
- max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
164
- else:
165
- raise ValueError("level_scale needs to be one of 'log' or 'linear'")
166
-
167
- for index, max_distance in enumerate(max_distances):
168
-
169
- # Get topics below `max_distance`
170
- mapping = {topic: topic for topic in df.topic.unique()}
171
- selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
172
- selection.Parent_ID = selection.Parent_ID.astype(int)
173
- selection = selection.sort_values("Parent_ID")
174
-
175
- for row in selection.iterrows():
176
- for topic in row[1].Topics:
177
- mapping[topic] = row[1].Parent_ID
178
-
179
- # Make sure the mappings are mapped 1:1
180
- mappings = [True for _ in mapping]
181
- while any(mappings):
182
- for i, (key, value) in enumerate(mapping.items()):
183
- if value in mapping.keys() and key != value:
184
- mapping[key] = mapping[value]
185
- else:
186
- mappings[i] = False
187
-
188
- # Create new column
189
- df[f"level_{index+1}"] = df.topic.map(mapping)
190
- df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
191
-
192
- # Prepare topic names of original and merged topics
193
- trace_names = []
194
- topic_names = {}
195
- for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
196
- if topic < hierarchical_topics.Parent_ID.astype(int).min():
197
- if topic_model.get_topic(topic):
198
- if isinstance(custom_labels, str):
199
- trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
200
- elif topic_model.custom_labels_ is not None and custom_labels:
201
- trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
202
- else:
203
- trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
204
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
205
- trace_names.append(trace_name)
206
- else:
207
- trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
208
- plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
209
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
210
- trace_names.append(trace_name)
211
-
212
- # Prepare traces
213
- all_traces = []
214
- for level in range(len(max_distances)):
215
- traces = []
216
-
217
- # Selected topics
218
- if topics:
219
- selection = df.loc[(df.topic.isin(topics)), :]
220
- unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
221
- else:
222
- unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
223
-
224
- for topic in unique_topics:
225
- if topic != -1:
226
- if topics:
227
- selection = df.loc[(df[f"level_{level+1}"] == topic) &
228
- (df.topic.isin(topics)), :]
229
- else:
230
- selection = df.loc[df[f"level_{level+1}"] == topic, :]
231
-
232
- if not hide_annotations:
233
- selection.loc[len(selection), :] = None
234
- selection["text"] = ""
235
- selection.loc[len(selection) - 1, "x"] = selection.x.mean()
236
- selection.loc[len(selection) - 1, "y"] = selection.y.mean()
237
- selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
238
-
239
- all_traces.append(traces)
240
-
241
- # Track and count traces
242
- nr_traces_per_set = [len(traces) for traces in all_traces]
243
- trace_indices = [(0, nr_traces_per_set[0])]
244
- for index, nr_traces in enumerate(nr_traces_per_set[1:]):
245
- start = trace_indices[index][1]
246
- end = nr_traces + start
247
- trace_indices.append((start, end))
248
-
249
-
250
- return all_traces, selection, df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
funcs/bertopic_vis_documents.py CHANGED
@@ -1,10 +1,23 @@
1
  import numpy as np
2
  import pandas as pd
 
3
  import plotly.graph_objects as go
4
  from plotly.subplots import make_subplots
5
 
 
 
 
 
 
 
6
  from umap import UMAP
7
- from typing import List, Union
 
 
 
 
 
 
8
 
9
  import itertools
10
  import numpy as np
@@ -23,7 +36,7 @@ def visualize_documents_custom(topic_model,
23
  custom_labels: Union[bool, str] = False,
24
  title: str = "<b>Documents and Topics</b>",
25
  width: int = 1200,
26
- height: int = 750):
27
  """ Visualize documents and their topics in 2D
28
 
29
  Arguments:
@@ -164,9 +177,9 @@ def visualize_documents_custom(topic_model,
164
  names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
165
  else:
166
  print("Not using custom labels")
167
- names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
168
 
169
- print(names)
170
 
171
  # Visualize
172
  fig = go.Figure()
@@ -254,6 +267,350 @@ def visualize_documents_custom(topic_model,
254
  fig.update_yaxes(visible=False)
255
  return fig
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def visualize_hierarchical_documents_custom(topic_model,
258
  docs: List[str],
259
  hover_labels: List[str],
@@ -269,7 +626,7 @@ def visualize_hierarchical_documents_custom(topic_model,
269
  custom_labels: Union[bool, str] = False,
270
  title: str = "<b>Hierarchical Documents and Topics</b>",
271
  width: int = 1200,
272
- height: int = 750) -> go.Figure:
273
  """ Visualize documents and their topics in 2D at different levels of hierarchy
274
 
275
  Arguments:
@@ -455,21 +812,22 @@ def visualize_hierarchical_documents_custom(topic_model,
455
  # Prepare topic names of original and merged topics
456
  trace_names = []
457
  topic_names = {}
 
458
  for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
459
  if topic < hierarchical_topics.Parent_ID.astype(int).min():
460
  if topic_model.get_topic(topic):
461
  if isinstance(custom_labels, str):
462
- trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
463
  elif topic_model.custom_labels_ is not None and custom_labels:
464
  trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
465
  else:
466
- trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
467
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
468
  trace_names.append(trace_name)
469
  else:
470
- trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
471
- plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
472
- topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
473
  trace_names.append(trace_name)
474
 
475
  # Prepare traces
@@ -598,7 +956,13 @@ def visualize_hierarchical_documents_custom(topic_model,
598
 
599
  fig.update_xaxes(visible=False)
600
  fig.update_yaxes(visible=False)
601
- return fig
 
 
 
 
 
 
602
 
603
  def visualize_barchart_custom(topic_model,
604
  topics: List[int] = None,
@@ -607,7 +971,7 @@ def visualize_barchart_custom(topic_model,
607
  custom_labels: Union[bool, str] = False,
608
  title: str = "<b>Topic Word Scores</b>",
609
  width: int = 250,
610
- height: int = 250) -> go.Figure:
611
  """ Visualize a barchart of selected topics
612
 
613
  Arguments:
 
1
  import numpy as np
2
  import pandas as pd
3
+ import gradio as gr
4
  import plotly.graph_objects as go
5
  from plotly.subplots import make_subplots
6
 
7
+ from bertopic._utils import check_documents_type, validate_distance_matrix
8
+ from bertopic.plotting._hierarchy import _get_annotations
9
+ import plotly.figure_factory as ff
10
+ from packaging import version
11
+
12
+ import math
13
  from umap import UMAP
14
+ from typing import List, Union, Callable
15
+
16
+ from scipy.sparse import csr_matrix
17
+ from scipy.cluster import hierarchy as sch
18
+ from sklearn.metrics.pairwise import cosine_similarity
19
+ from sklearn import __version__ as sklearn_version
20
+ from tqdm import tqdm
21
 
22
  import itertools
23
  import numpy as np
 
36
  custom_labels: Union[bool, str] = False,
37
  title: str = "<b>Documents and Topics</b>",
38
  width: int = 1200,
39
+ height: int = 750, progress=gr.Progress(track_tqdm=True)):
40
  """ Visualize documents and their topics in 2D
41
 
42
  Arguments:
 
177
  names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
178
  else:
179
  print("Not using custom labels")
180
+ names = [f"{topic} " + ", ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
181
 
182
+ #print(names)
183
 
184
  # Visualize
185
  fig = go.Figure()
 
267
  fig.update_yaxes(visible=False)
268
  return fig
269
 
270
+ def hierarchical_topics_custom(self,
271
+ docs: List[str],
272
+ linkage_function: Callable[[csr_matrix], np.ndarray] = None,
273
+ distance_function: Callable[[csr_matrix], csr_matrix] = None, progress=gr.Progress(track_tqdm=True)) -> pd.DataFrame:
274
+ """ Create a hierarchy of topics
275
+
276
+ To create this hierarchy, BERTopic needs to be already fitted once.
277
+ Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF
278
+ representation using `scipy.cluster.hierarchy.linkage`.
279
+
280
+ Based on that hierarchy, we calculate the topic representation at each
281
+ merged step. This is a local representation, as we only assume that the
282
+ chosen step is merged and not all others which typically improves the
283
+ topic representation.
284
+
285
+ Arguments:
286
+ docs: The documents you used when calling either `fit` or `fit_transform`
287
+ linkage_function: The linkage function to use. Default is:
288
+ `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
289
+ distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
290
+ `lambda x: 1 - cosine_similarity(x)`.
291
+ You can pass any function that returns either a square matrix of
292
+ shape (n_samples, n_samples) with zeros on the diagonal and
293
+ non-negative values or condensed distance matrix of shape
294
+ (n_samples * (n_samples - 1) / 2,) containing the upper
295
+ triangular of the distance matrix.
296
+
297
+ Returns:
298
+ hierarchical_topics: A dataframe that contains a hierarchy of topics
299
+ represented by their parents and their children
300
+
301
+ Examples:
302
+
303
+ ```python
304
+ from bertopic import BERTopic
305
+ topic_model = BERTopic()
306
+ topics, probs = topic_model.fit_transform(docs)
307
+ hierarchical_topics = topic_model.hierarchical_topics(docs)
308
+ ```
309
+
310
+ A custom linkage function can be used as follows:
311
+
312
+ ```python
313
+ from scipy.cluster import hierarchy as sch
314
+ from bertopic import BERTopic
315
+ topic_model = BERTopic()
316
+ topics, probs = topic_model.fit_transform(docs)
317
+
318
+ # Hierarchical topics
319
+ linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
320
+ hierarchical_topics = topic_model.hierarchical_topics(docs, linkage_function=linkage_function)
321
+ ```
322
+ """
323
+ check_documents_type(docs)
324
+ if distance_function is None:
325
+ distance_function = lambda x: 1 - cosine_similarity(x)
326
+
327
+ if linkage_function is None:
328
+ linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
329
+
330
+ # Calculate distance
331
+ embeddings = self.c_tf_idf_[self._outliers:]
332
+ X = distance_function(embeddings)
333
+ X = validate_distance_matrix(X, embeddings.shape[0])
334
+
335
+ # Use the 1-D condensed distance matrix as an input instead of the raw distance matrix
336
+ Z = linkage_function(X)
337
+
338
+ # Calculate basic bag-of-words to be iteratively merged later
339
+ documents = pd.DataFrame({"Document": docs,
340
+ "ID": range(len(docs)),
341
+ "Topic": self.topics_})
342
+ documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
343
+ documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :]
344
+ clean_documents = self._preprocess_text(documents_per_topic.Document.values)
345
+
346
+ # Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
347
+ # and will be removed in 1.2. Please use get_feature_names_out instead.
348
+ if version.parse(sklearn_version) >= version.parse("1.0.0"):
349
+ words = self.vectorizer_model.get_feature_names_out()
350
+ else:
351
+ words = self.vectorizer_model.get_feature_names()
352
+
353
+ bow = self.vectorizer_model.transform(clean_documents)
354
+
355
+ # Extract clusters
356
+ hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics",
357
+ "Child_Left_ID", "Child_Left_Name",
358
+ "Child_Right_ID", "Child_Right_Name"])
359
+ for index in tqdm(range(len(Z))):
360
+
361
+ # Find clustered documents
362
+ clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers
363
+ nr_clusters = len(clusters)
364
+
365
+ # Extract first topic we find to get the set of topics in a merged topic
366
+ topic = None
367
+ val = Z[index][0]
368
+ while topic is None:
369
+ if val - len(clusters) < 0:
370
+ topic = int(val)
371
+ else:
372
+ val = Z[int(val - len(clusters))][0]
373
+ clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]]
374
+
375
+ # Group bow per cluster, calculate c-TF-IDF and extract words
376
+ grouped = csr_matrix(bow[clustered_topics].sum(axis=0))
377
+ c_tf_idf = self.ctfidf_model.transform(grouped)
378
+ selection = documents.loc[documents.Topic.isin(clustered_topics), :]
379
+ selection.Topic = 0
380
+ words_per_topic = self._extract_words_per_topic(words, selection, c_tf_idf, calculate_aspects=False)
381
+
382
+ # Extract parent's name and ID
383
+ parent_id = index + len(clusters)
384
+ parent_name = ", ".join([x[0] for x in words_per_topic[0]][:5])
385
+
386
+ # Extract child's name and ID
387
+ Z_id = Z[index][0]
388
+ child_left_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
389
+
390
+ if Z_id - nr_clusters < 0:
391
+ child_left_name = ", ".join([x[0] for x in self.get_topic(Z_id)][:5])
392
+ else:
393
+ child_left_name = hier_topics.iloc[int(child_left_id)].Parent_Name
394
+
395
+ # Extract child's name and ID
396
+ Z_id = Z[index][1]
397
+ child_right_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
398
+
399
+ if Z_id - nr_clusters < 0:
400
+ child_right_name = ", ".join([x[0] for x in self.get_topic(Z_id)][:5])
401
+ else:
402
+ child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name
403
+
404
+ # Save results
405
+ hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name,
406
+ clustered_topics,
407
+ int(Z[index][0]), child_left_name,
408
+ int(Z[index][1]), child_right_name]
409
+
410
+ hier_topics["Distance"] = Z[:, 2]
411
+ hier_topics = hier_topics.sort_values("Parent_ID", ascending=False)
412
+ hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]].astype(str)
413
+
414
+ return hier_topics
415
+
416
+ def visualize_hierarchy_custom(topic_model,
417
+ orientation: str = "left",
418
+ topics: List[int] = None,
419
+ top_n_topics: int = None,
420
+ custom_labels: Union[bool, str] = False,
421
+ title: str = "<b>Hierarchical Clustering</b>",
422
+ width: int = 1000,
423
+ height: int = 600,
424
+ hierarchical_topics: pd.DataFrame = None,
425
+ linkage_function: Callable[[csr_matrix], np.ndarray] = None,
426
+ distance_function: Callable[[csr_matrix], csr_matrix] = None,
427
+ color_threshold: int = 1) -> go.Figure:
428
+ """ Visualize a hierarchical structure of the topics
429
+
430
+ A ward linkage function is used to perform the
431
+ hierarchical clustering based on the cosine distance
432
+ matrix between topic embeddings.
433
+
434
+ Arguments:
435
+ topic_model: A fitted BERTopic instance.
436
+ orientation: The orientation of the figure.
437
+ Either 'left' or 'bottom'
438
+ topics: A selection of topics to visualize
439
+ top_n_topics: Only select the top n most frequent topics
440
+ custom_labels: If bool, whether to use custom topic labels that were defined using
441
+ `topic_model.set_topic_labels`.
442
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
443
+ NOTE: Custom labels are only generated for the original
444
+ un-merged topics.
445
+ title: Title of the plot.
446
+ width: The width of the figure. Only works if orientation is set to 'left'
447
+ height: The height of the figure. Only works if orientation is set to 'bottom'
448
+ hierarchical_topics: A dataframe that contains a hierarchy of topics
449
+ represented by their parents and their children.
450
+ NOTE: The hierarchical topic names are only visualized
451
+ if both `topics` and `top_n_topics` are not set.
452
+ linkage_function: The linkage function to use. Default is:
453
+ `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
454
+ NOTE: Make sure to use the same `linkage_function` as used
455
+ in `topic_model.hierarchical_topics`.
456
+ distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
457
+ `lambda x: 1 - cosine_similarity(x)`.
458
+ You can pass any function that returns either a square matrix of
459
+ shape (n_samples, n_samples) with zeros on the diagonal and
460
+ non-negative values or condensed distance matrix of shape
461
+ (n_samples * (n_samples - 1) / 2,) containing the upper
462
+ triangular of the distance matrix.
463
+ NOTE: Make sure to use the same `distance_function` as used
464
+ in `topic_model.hierarchical_topics`.
465
+ color_threshold: Value at which the separation of clusters will be made which
466
+ will result in different colors for different clusters.
467
+ A higher value will typically lead in less colored clusters.
468
+
469
+ Returns:
470
+ fig: A plotly figure
471
+
472
+ Examples:
473
+
474
+ To visualize the hierarchical structure of
475
+ topics simply run:
476
+
477
+ ```python
478
+ topic_model.visualize_hierarchy()
479
+ ```
480
+
481
+ If you also want the labels visualized of hierarchical topics,
482
+ run the following:
483
+
484
+ ```python
485
+ # Extract hierarchical topics and their representations
486
+ hierarchical_topics = topic_model.hierarchical_topics(docs)
487
+
488
+ # Visualize these representations
489
+ topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
490
+ ```
491
+
492
+ If you want to save the resulting figure:
493
+
494
+ ```python
495
+ fig = topic_model.visualize_hierarchy()
496
+ fig.write_html("path/to/file.html")
497
+ ```
498
+ <iframe src="../../getting_started/visualization/hierarchy.html"
499
+ style="width:1000px; height: 680px; border: 0px;""></iframe>
500
+ """
501
+ if distance_function is None:
502
+ distance_function = lambda x: 1 - cosine_similarity(x)
503
+
504
+ if linkage_function is None:
505
+ linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
506
+
507
+ # Select topics based on top_n and topics args
508
+ freq_df = topic_model.get_topic_freq()
509
+ freq_df = freq_df.loc[freq_df.Topic != -1, :]
510
+ if topics is not None:
511
+ topics = list(topics)
512
+ elif top_n_topics is not None:
513
+ topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
514
+ else:
515
+ topics = sorted(freq_df.Topic.to_list())
516
+
517
+ # Select embeddings
518
+ all_topics = sorted(list(topic_model.get_topics().keys()))
519
+ indices = np.array([all_topics.index(topic) for topic in topics])
520
+
521
+ # Select topic embeddings
522
+ if topic_model.c_tf_idf_ is not None:
523
+ embeddings = topic_model.c_tf_idf_[indices]
524
+ else:
525
+ embeddings = np.array(topic_model.topic_embeddings_)[indices]
526
+
527
+ # Annotations
528
+ if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
529
+ annotations = _get_annotations(topic_model=topic_model,
530
+ hierarchical_topics=hierarchical_topics,
531
+ embeddings=embeddings,
532
+ distance_function=distance_function,
533
+ linkage_function=linkage_function,
534
+ orientation=orientation,
535
+ custom_labels=custom_labels)
536
+ else:
537
+ annotations = None
538
+
539
+ # wrap distance function to validate input and return a condensed distance matrix
540
+ distance_function_viz = lambda x: validate_distance_matrix(
541
+ distance_function(x), embeddings.shape[0])
542
+ # Create dendogram
543
+ fig = ff.create_dendrogram(embeddings,
544
+ orientation=orientation,
545
+ distfun=distance_function_viz,
546
+ linkagefun=linkage_function,
547
+ hovertext=annotations,
548
+ color_threshold=color_threshold)
549
+
550
+ # Create nicer labels
551
+ axis = "yaxis" if orientation == "left" else "xaxis"
552
+ if isinstance(custom_labels, str):
553
+ new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
554
+ new_labels = [", ".join([label[0] for label in labels[:4]]) for labels in new_labels]
555
+ new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
556
+ elif topic_model.custom_labels_ is not None and custom_labels:
557
+ new_labels = [topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]]
558
+ else:
559
+ new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)])
560
+ for x in fig.layout[axis]["ticktext"]]
561
+ new_labels = [", ".join([label[0] for label in labels[:4]]) for labels in new_labels]
562
+ new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
563
+
564
+ # Stylize layout
565
+ fig.update_layout(
566
+ plot_bgcolor='#ECEFF1',
567
+ template="plotly_white",
568
+ title={
569
+ 'text': f"{title}",
570
+ 'x': 0.5,
571
+ 'xanchor': 'center',
572
+ 'yanchor': 'top',
573
+ 'font': dict(
574
+ size=22,
575
+ color="Black")
576
+ },
577
+ hoverlabel=dict(
578
+ bgcolor="white",
579
+ font_size=16,
580
+ font_family="Rockwell"
581
+ ),
582
+ )
583
+
584
+ # Stylize orientation
585
+ if orientation == "left":
586
+ fig.update_layout(height=200 + (15 * len(topics)),
587
+ width=width,
588
+ yaxis=dict(tickmode="array",
589
+ ticktext=new_labels))
590
+
591
+ # Fix empty space on the bottom of the graph
592
+ y_max = max([trace['y'].max() + 5 for trace in fig['data']])
593
+ y_min = min([trace['y'].min() - 5 for trace in fig['data']])
594
+ fig.update_layout(yaxis=dict(range=[y_min, y_max]))
595
+
596
+ else:
597
+ fig.update_layout(width=200 + (15 * len(topics)),
598
+ height=height,
599
+ xaxis=dict(tickmode="array",
600
+ ticktext=new_labels))
601
+
602
+ if hierarchical_topics is not None:
603
+ for index in [0, 3]:
604
+ axis = "x" if orientation == "left" else "y"
605
+ xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
606
+ ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
607
+ hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
608
+
609
+ fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
610
+ hovertext=hovertext, hoverinfo="text",
611
+ mode='markers', showlegend=False))
612
+ return fig
613
+
614
  def visualize_hierarchical_documents_custom(topic_model,
615
  docs: List[str],
616
  hover_labels: List[str],
 
626
  custom_labels: Union[bool, str] = False,
627
  title: str = "<b>Hierarchical Documents and Topics</b>",
628
  width: int = 1200,
629
+ height: int = 750, progress=gr.Progress(track_tqdm=True)) -> go.Figure:
630
  """ Visualize documents and their topics in 2D at different levels of hierarchy
631
 
632
  Arguments:
 
812
  # Prepare topic names of original and merged topics
813
  trace_names = []
814
  topic_names = {}
815
+ trace_name_char_length = 60
816
  for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
817
  if topic < hierarchical_topics.Parent_ID.astype(int).min():
818
  if topic_model.get_topic(topic):
819
  if isinstance(custom_labels, str):
820
+ trace_name = f"{topic} " + ", ".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:5])
821
  elif topic_model.custom_labels_ is not None and custom_labels:
822
  trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
823
  else:
824
+ trace_name = f"{topic} " + ", ".join([word[:20] for word, _ in topic_model.get_topic(topic)][:5])
825
+ topic_names[topic] = {"trace_name": trace_name[:trace_name_char_length], "plot_text": trace_name[:trace_name_char_length]}
826
  trace_names.append(trace_name)
827
  else:
828
+ trace_name = f"{topic} " + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
829
+ plot_text = ", ".join([name[:20] for name in trace_name.split(" ")[:5]])
830
+ topic_names[topic] = {"trace_name": trace_name[:trace_name_char_length], "plot_text": plot_text[:trace_name_char_length]}
831
  trace_names.append(trace_name)
832
 
833
  # Prepare traces
 
956
 
957
  fig.update_xaxes(visible=False)
958
  fig.update_yaxes(visible=False)
959
+
960
+ hierarchy_topics_df = df.filter(regex=r'topic|^level').drop_duplicates(subset="topic")
961
+
962
+ topic_names = pd.DataFrame(topic_names).T
963
+
964
+
965
+ return fig, hierarchy_topics_df, topic_names
966
 
967
  def visualize_barchart_custom(topic_model,
968
  topics: List[int] = None,
 
971
  custom_labels: Union[bool, str] = False,
972
  title: str = "<b>Topic Word Scores</b>",
973
  width: int = 250,
974
+ height: int = 250, progress=gr.Progress(track_tqdm=True)) -> go.Figure:
975
  """ Visualize a barchart of selected topics
976
 
977
  Arguments:
funcs/clean_funcs.py CHANGED
@@ -33,18 +33,19 @@ multiple_spaces_regex = r'\s{2,}'
33
 
34
  def initial_clean(texts, custom_regex, progress=gr.Progress()):
35
  texts = pl.Series(texts).str.strip_chars()
36
- text = texts.str.replace_all(html_pattern_regex, '')
37
- text = text.str.replace_all(email_pattern_regex, '')
38
- text = text.str.replace_all(nums_two_more_regex, '')
39
- text = text.str.replace_all(postcode_pattern_regex, '')
40
- text = text.str.replace_all(multiple_spaces_regex, '')
41
 
42
  # Allow for custom regex patterns to be removed
43
  if len(custom_regex) > 0:
44
  for pattern in custom_regex:
45
  raw_string_pattern = r'{}'.format(pattern)
46
  print("Removing regex pattern: ", raw_string_pattern)
47
- text = text.str.replace_all(raw_string_pattern, '')
 
 
48
 
49
  text = text.to_list()
50
 
 
33
 
34
  def initial_clean(texts, custom_regex, progress=gr.Progress()):
35
  texts = pl.Series(texts).str.strip_chars()
36
+ text = texts.str.replace_all(html_pattern_regex, ' ')
37
+ text = text.str.replace_all(email_pattern_regex, ' ')
38
+ text = text.str.replace_all(nums_two_more_regex, ' ')
39
+ text = text.str.replace_all(postcode_pattern_regex, ' ')
 
40
 
41
  # Allow for custom regex patterns to be removed
42
  if len(custom_regex) > 0:
43
  for pattern in custom_regex:
44
  raw_string_pattern = r'{}'.format(pattern)
45
  print("Removing regex pattern: ", raw_string_pattern)
46
+ text = text.str.replace_all(raw_string_pattern, ' ')
47
+
48
+ text = text.str.replace_all(multiple_spaces_regex, ' ')
49
 
50
  text = text.to_list()
51
 
funcs/topic_core_funcs.py CHANGED
@@ -11,6 +11,8 @@ from bertopic import BERTopic
11
  from funcs.clean_funcs import initial_clean
12
  from funcs.helper_functions import read_file, zip_folder, delete_files_in_folder, save_topic_outputs
13
  from funcs.embeddings import make_or_load_embeddings
 
 
14
 
15
  from sentence_transformers import SentenceTransformer
16
  from sklearn.pipeline import make_pipeline
@@ -145,9 +147,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
145
  if not in_colnames:
146
  error_message = "Please enter one column name to use for cleaning and finding topics."
147
  print(error_message)
148
- return error_message, None, data_file_name_no_ext, embeddings_out, None, None
149
-
150
-
151
 
152
  in_colnames_list_first = in_colnames[0]
153
 
@@ -186,7 +186,9 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
186
 
187
  embeddings_type_state = "tfidf"
188
 
189
- umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
 
 
190
 
191
  embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
192
 
@@ -195,7 +197,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
195
 
196
  progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
197
 
198
- fail_error_message = "Topic model creation failed. Try reducing minimum documents per topic on the slider above (try 15 or less), then click 'Extract topics' again."
199
 
200
  if not candidate_topics:
201
 
@@ -217,10 +219,11 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
217
  topics_probs_out.to_csv(topics_probs_out_name)
218
  output_list.append(topics_probs_out_name)
219
 
220
- except:
 
221
  print(fail_error_message)
222
 
223
- return fail_error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model
224
 
225
 
226
  # Do this if you have pre-defined topics
@@ -229,7 +232,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
229
  error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
230
  print(error_message)
231
 
232
- return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model
233
 
234
  zero_shot_topics = read_file(candidate_topics.name)
235
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
@@ -254,17 +257,21 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
254
  topics_probs_out.to_csv(topics_probs_out_name)
255
  output_list.append(topics_probs_out_name)
256
 
257
- except:
 
258
  print(fail_error_message)
259
 
260
- return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model
261
 
262
  # For some reason, zero topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
263
  if isinstance(assigned_topics, np.ndarray):
264
  assigned_topics = assigned_topics.tolist()
265
 
266
- # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot, we are merging two models that have exactly the same set of documents, the vocubulary should be the same, and so recreating the cf_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700
 
 
267
 
 
268
  doc_dets = topic_model.get_document_info(docs)
269
 
270
  documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
@@ -277,13 +284,19 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
277
  c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
278
  topic_model.c_tf_idf_ = c_tf_idf
279
 
 
 
 
 
280
  if not assigned_topics:
281
- # Handle the empty array case
282
- return "No topics found.", output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs
283
-
284
  else:
285
  print("Topic model created.")
286
 
 
 
 
 
287
  # Replace current topic labels if new ones loaded in
288
  if not custom_labels_df.empty:
289
  #custom_label_list = list(custom_labels_df.iloc[:,0])
@@ -315,9 +328,9 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
315
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
316
  print(time_out)
317
 
318
- return output_text, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model
319
 
320
- def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):
321
 
322
  progress(0, desc= "Preparing data")
323
 
@@ -325,7 +338,8 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, sa
325
 
326
  all_tic = time.perf_counter()
327
 
328
- assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
329
 
330
  if isinstance(assigned_topics, np.ndarray):
331
  assigned_topics = assigned_topics.tolist()
@@ -339,7 +353,12 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, sa
339
  # Then, update the topics to the ones that considered the new data
340
 
341
  progress(0.6, desc= "Updating original model")
342
- topic_model.update_topics(docs, topics=assigned_topics)
 
 
 
 
 
343
 
344
  print("Finished reducing outliers.")
345
 
@@ -375,7 +394,7 @@ def represent_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
375
 
376
  representation_model = create_representation_model(representation_type, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
377
 
378
- progress(0.6, desc= "Updating existing topics")
379
  topic_model.update_topics(docs, vectorizer_model=vectoriser_model, representation_model=representation_model)
380
 
381
  topic_dets = topic_model.get_topic_info()
@@ -394,8 +413,7 @@ def represent_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
394
  else:
395
  new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ", aspect = representation_type)
396
 
397
- topic_model.set_topic_labels(new_topic_labels)#list(topic_dets[representation_type]))
398
- #topic_model.set_topic_labels(list(topic_dets["Name"]))
399
 
400
  # Outputs
401
  progress(0.8, desc= "Saving outputs")
@@ -414,8 +432,7 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
414
  output_list = []
415
  vis_tic = time.perf_counter()
416
 
417
- from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, visualize_barchart_custom
418
-
419
  if not visualisation_type_radio:
420
  return "Please choose a visualisation type above.", output_list, None, None
421
 
@@ -475,7 +492,7 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
475
 
476
  elif visualisation_type_radio == "Hierarchical view":
477
 
478
- hierarchical_topics = topic_model.hierarchical_topics(docs)
479
 
480
  # Print topic tree
481
  tree = topic_model.get_topic_tree(hierarchical_topics, tight_layout = True)
@@ -488,16 +505,28 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
488
  output_list.append(tree_name)
489
 
490
  # Save new hierarchical topic model to file
491
- hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_' + today_rev + '.csv'
492
  hierarchical_topics.to_csv(hierarchical_topics_name)
493
  output_list.append(hierarchical_topics_name)
494
 
495
- try:
496
- topics_vis = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
497
- topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
498
- except:
499
- error_message = "Visualisation preparation failed. Perhaps you need more topics to create the full hierarchy (more than 10)?"
500
- return error_message, output_list, None, None
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
503
  topics_vis.write_html(topics_vis_name)
 
11
  from funcs.clean_funcs import initial_clean
12
  from funcs.helper_functions import read_file, zip_folder, delete_files_in_folder, save_topic_outputs
13
  from funcs.embeddings import make_or_load_embeddings
14
+ from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, hierarchical_topics_custom, visualize_hierarchy_custom
15
+
16
 
17
  from sentence_transformers import SentenceTransformer
18
  from sklearn.pipeline import make_pipeline
 
147
  if not in_colnames:
148
  error_message = "Please enter one column name to use for cleaning and finding topics."
149
  print(error_message)
150
+ return error_message, None, data_file_name_no_ext, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, None, vectoriser_state, []
 
 
151
 
152
  in_colnames_list_first = in_colnames[0]
153
 
 
186
 
187
  embeddings_type_state = "tfidf"
188
 
189
+ #umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
190
+ # UMAP model uses Bertopic defaults
191
+ umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', low_memory=True, random_state=random_seed)
192
 
193
  embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
194
 
 
197
 
198
  progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
199
 
200
+ fail_error_message = "Topic model creation failed. Try reducing minimum documents per topic on the slider above (try 15 or less), then click 'Extract topics' again. If that doesn't work, try running the first two clean steps on your data first (see Clean data above) to ensure there are no NaNs/missing texts in your data."
201
 
202
  if not candidate_topics:
203
 
 
219
  topics_probs_out.to_csv(topics_probs_out_name)
220
  output_list.append(topics_probs_out_name)
221
 
222
+ except Exception as error:
223
+ print(error)
224
  print(fail_error_message)
225
 
226
+ return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
227
 
228
 
229
  # Do this if you have pre-defined topics
 
232
  error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
233
  print(error_message)
234
 
235
+ return error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
236
 
237
  zero_shot_topics = read_file(candidate_topics.name)
238
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
 
257
  topics_probs_out.to_csv(topics_probs_out_name)
258
  output_list.append(topics_probs_out_name)
259
 
260
+ except Exception as error:
261
+ print("An exception occurred:", error)
262
  print(fail_error_message)
263
 
264
+ return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
265
 
266
  # For some reason, zero topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
267
  if isinstance(assigned_topics, np.ndarray):
268
  assigned_topics = assigned_topics.tolist()
269
 
270
+
271
+
272
+ # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot, we are merging two models that have exactly the same set of documents, the vocubulary should be the same, and so recreating the cf_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700
273
 
274
+ # Get document info
275
  doc_dets = topic_model.get_document_info(docs)
276
 
277
  documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
 
284
  c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
285
  topic_model.c_tf_idf_ = c_tf_idf
286
 
287
+ ###
288
+
289
+
290
+ # Check we have topics
291
  if not assigned_topics:
292
+ return "No topics found.", output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model,[]
 
 
293
  else:
294
  print("Topic model created.")
295
 
296
+ # Tidy up topic label format a bit to have commas and spaces by default
297
+ new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ")
298
+ topic_model.set_topic_labels(new_topic_labels)
299
+
300
  # Replace current topic labels if new ones loaded in
301
  if not custom_labels_df.empty:
302
  #custom_label_list = list(custom_labels_df.iloc[:,0])
 
328
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
329
  print(time_out)
330
 
331
+ return output_text, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model, assigned_topics
332
 
333
+ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, assigned_topics, vectoriser_model, save_topic_model, progress=gr.Progress(track_tqdm=True)):
334
 
335
  progress(0, desc= "Preparing data")
336
 
 
338
 
339
  all_tic = time.perf_counter()
340
 
341
+ # This step not necessary?
342
+ #assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
343
 
344
  if isinstance(assigned_topics, np.ndarray):
345
  assigned_topics = assigned_topics.tolist()
 
353
  # Then, update the topics to the ones that considered the new data
354
 
355
  progress(0.6, desc= "Updating original model")
356
+
357
+ topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model = vectoriser_model)
358
+
359
+ # Tidy up topic label format a bit to have commas and spaces by default
360
+ new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ")
361
+ topic_model.set_topic_labels(new_topic_labels)
362
 
363
  print("Finished reducing outliers.")
364
 
 
394
 
395
  representation_model = create_representation_model(representation_type, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
396
 
397
+ progress(0.3, desc= "Updating existing topics")
398
  topic_model.update_topics(docs, vectorizer_model=vectoriser_model, representation_model=representation_model)
399
 
400
  topic_dets = topic_model.get_topic_info()
 
413
  else:
414
  new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ", aspect = representation_type)
415
 
416
+ topic_model.set_topic_labels(new_topic_labels)
 
417
 
418
  # Outputs
419
  progress(0.8, desc= "Saving outputs")
 
432
  output_list = []
433
  vis_tic = time.perf_counter()
434
 
435
+
 
436
  if not visualisation_type_radio:
437
  return "Please choose a visualisation type above.", output_list, None, None
438
 
 
492
 
493
  elif visualisation_type_radio == "Hierarchical view":
494
 
495
+ hierarchical_topics = hierarchical_topics_custom(topic_model, docs)
496
 
497
  # Print topic tree
498
  tree = topic_model.get_topic_tree(hierarchical_topics, tight_layout = True)
 
505
  output_list.append(tree_name)
506
 
507
  # Save new hierarchical topic model to file
508
+ hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_distz_' + today_rev + '.csv'
509
  hierarchical_topics.to_csv(hierarchical_topics_name)
510
  output_list.append(hierarchical_topics_name)
511
 
512
+
513
+ #try:
514
+ topics_vis, hierarchy_df, hierarchy_topic_names = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
515
+ topics_vis_2 = visualize_hierarchy_custom(topic_model, hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
516
+
517
+ # Write hierarchical topics levels to df
518
+ hierarchy_df_name = data_file_name_no_ext + '_' + 'hierarchy_topics_df_' + today_rev + '.csv'
519
+ hierarchy_df.to_csv(hierarchy_df_name)
520
+ output_list.append(hierarchy_df_name)
521
+
522
+ # Write hierarchical topics names to df
523
+ hierarchy_topic_names_name = data_file_name_no_ext + '_' + 'hierarchy_topics_names_' + today_rev + '.csv'
524
+ hierarchy_topic_names.to_csv(hierarchy_topic_names_name)
525
+ output_list.append(hierarchy_topic_names_name)
526
+
527
+ #except:
528
+ # error_message = "Visualisation preparation failed. Perhaps you need more topics to create the full hierarchy (more than 10)?"
529
+ # return error_message, output_list, None, None
530
 
531
  topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
532
  topics_vis.write_html(topics_vis_name)