Spaces:
Running
Running
More efficient embeddings save and representations load/process. Custom visualisation hover option added, formatting improvements. Version 0.1?
Browse files- .gitignore +1 -0
- app.py +78 -50
- funcs/bertopic_vis_documents.py +245 -0
- funcs/representation_model.py +1 -1
.gitignore
CHANGED
@@ -7,6 +7,7 @@
|
|
7 |
*.png
|
8 |
*.safetensors
|
9 |
*.json
|
|
|
10 |
.ipynb_checkpoints/*
|
11 |
old_code/*
|
12 |
model/*
|
|
|
7 |
*.png
|
8 |
*.safetensors
|
9 |
*.json
|
10 |
+
*.html
|
11 |
.ipynb_checkpoints/*
|
12 |
old_code/*
|
13 |
model/*
|
app.py
CHANGED
@@ -80,7 +80,14 @@ hf_model_name = 'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1
|
|
80 |
hf_model_file = 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf' # 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf'
|
81 |
|
82 |
|
83 |
-
def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels, save_topic_model, visualise_topics, reduce_outliers, embeddings_out):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
all_tic = time.perf_counter()
|
86 |
|
@@ -97,8 +104,13 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
97 |
in_label_list_first = in_label[0]
|
98 |
else:
|
99 |
in_label_list_first = in_colnames_list_first
|
|
|
|
|
|
|
|
|
100 |
|
101 |
if anonymise_drop == "Yes":
|
|
|
102 |
anon_tic = time.perf_counter()
|
103 |
time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
|
104 |
in_files_anon_col, anonymisation_success = anon.anonymise_script(in_files, in_colnames_list_first, anon_strat="replace")
|
@@ -111,7 +123,7 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
111 |
time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
|
112 |
|
113 |
docs = list(in_files[in_colnames_list_first].str.lower())
|
114 |
-
|
115 |
|
116 |
# Check if embeddings are being loaded in
|
117 |
## Load in pre-embedded file if exists
|
@@ -144,6 +156,8 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
144 |
|
145 |
umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
|
146 |
|
|
|
|
|
147 |
embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, data_file_name_no_ext, embeddings_out, embedding_model, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels)
|
148 |
|
149 |
vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
|
@@ -151,28 +165,27 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
151 |
from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
|
152 |
from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
|
158 |
if not candidate_topics:
|
159 |
|
160 |
# Generate representation model here if topics won't be changed later
|
161 |
-
if reduce_outliers == "No":
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
|
177 |
topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
|
178 |
|
@@ -189,25 +202,25 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
189 |
zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
|
190 |
|
191 |
# Generate representation model here if topics won't be changed later
|
192 |
-
if reduce_outliers == "No":
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
else:
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
|
212 |
topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
|
213 |
|
@@ -215,35 +228,43 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
215 |
return "No topics found.", data_file_name, None
|
216 |
|
217 |
else:
|
218 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
-
# Reduce outliers if required
|
221 |
if reduce_outliers == "Yes":
|
|
|
222 |
print("Reducing outliers.")
|
223 |
# Assign each outlier document to the topic whose embedding is most similar (cosine similarity) to the document's embedding.
|
224 |
topics_text = topic_model.reduce_outliers(docs, topics_text, strategy="embeddings")
|
225 |
# Then, update the topics to the ones that considered the new data
|
226 |
-
topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
|
227 |
print("Finished reducing outliers.")
|
228 |
|
|
|
|
|
|
|
229 |
topic_dets = topic_model.get_topic_info()
|
230 |
-
#print(topic_dets.columns)
|
231 |
|
232 |
if topic_dets.shape[0] == 1:
|
233 |
topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
|
234 |
topic_dets.to_csv(topic_det_output_name)
|
235 |
output_list.append(topic_det_output_name)
|
236 |
|
237 |
-
return "No topics found, original file returned", output_list, None
|
238 |
|
239 |
# Replace original labels with LLM labels
|
240 |
-
if "
|
241 |
-
llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["
|
242 |
topic_model.set_topic_labels(llm_labels)
|
243 |
else:
|
244 |
topic_model.set_topic_labels(list(topic_dets["Name"]))
|
245 |
|
246 |
# Outputs
|
|
|
247 |
|
248 |
topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
|
249 |
topic_dets.to_csv(topic_det_output_name)
|
@@ -288,10 +309,15 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
288 |
output_list.append(embeddings_file_name)
|
289 |
|
290 |
if visualise_topics == "Yes":
|
|
|
|
|
291 |
# Visualise the topics:
|
292 |
vis_tic = time.perf_counter()
|
293 |
print("Creating visualisation")
|
294 |
-
topics_vis = topic_model
|
|
|
|
|
|
|
295 |
|
296 |
all_toc = time.perf_counter()
|
297 |
time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
|
@@ -304,7 +330,7 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
|
|
304 |
return output_text, output_list, topics_vis, embeddings_out
|
305 |
|
306 |
all_toc = time.perf_counter()
|
307 |
-
time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
|
308 |
print(time_out)
|
309 |
|
310 |
return output_text, output_list, None, embeddings_out
|
@@ -321,7 +347,9 @@ with block:
|
|
321 |
gr.Markdown(
|
322 |
"""
|
323 |
# Topic modeller
|
324 |
-
Generate topics from open text in tabular data. Upload a file (csv, xlsx, or parquet), then specify the
|
|
|
|
|
325 |
""")
|
326 |
|
327 |
with gr.Tab("Load files and find topics"):
|
@@ -329,7 +357,7 @@ with block:
|
|
329 |
in_files = gr.File(label="Input text from file", file_count="multiple")
|
330 |
with gr.Row():
|
331 |
in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
|
332 |
-
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column
|
333 |
|
334 |
with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
|
335 |
candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
|
|
|
80 |
hf_model_file = 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf' # 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf'
|
81 |
|
82 |
|
83 |
+
def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels, save_topic_model, visualise_topics, reduce_outliers, embeddings_out, progress=gr.Progress()):
|
84 |
+
|
85 |
+
progress(0, desc= "Loading data")
|
86 |
+
|
87 |
+
if not in_colnames or not in_label:
|
88 |
+
error_message = "Please enter one column name for the topics and another for the labelling."
|
89 |
+
print(error_message)
|
90 |
+
return error_message, None, None, embeddings_out
|
91 |
|
92 |
all_tic = time.perf_counter()
|
93 |
|
|
|
104 |
in_label_list_first = in_label[0]
|
105 |
else:
|
106 |
in_label_list_first = in_colnames_list_first
|
107 |
+
|
108 |
+
# Make sure format of input series is good
|
109 |
+
in_files[in_colnames_list_first] = in_files[in_colnames_list_first].fillna('').astype(str)
|
110 |
+
in_files[in_label_list_first] = in_files[in_label_list_first].fillna('').astype(str)
|
111 |
|
112 |
if anonymise_drop == "Yes":
|
113 |
+
progress(0.1, desc= "Anonymising data")
|
114 |
anon_tic = time.perf_counter()
|
115 |
time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
|
116 |
in_files_anon_col, anonymisation_success = anon.anonymise_script(in_files, in_colnames_list_first, anon_strat="replace")
|
|
|
123 |
time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
|
124 |
|
125 |
docs = list(in_files[in_colnames_list_first].str.lower())
|
126 |
+
label_list = list(in_files[in_label_list_first])
|
127 |
|
128 |
# Check if embeddings are being loaded in
|
129 |
## Load in pre-embedded file if exists
|
|
|
156 |
|
157 |
umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
|
158 |
|
159 |
+
progress(0.2, desc= "Loading/creating embeddings")
|
160 |
+
|
161 |
embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, data_file_name_no_ext, embeddings_out, embedding_model, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels)
|
162 |
|
163 |
vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
|
|
|
165 |
from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
|
166 |
from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
|
167 |
|
168 |
+
|
169 |
+
progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
|
|
|
170 |
|
171 |
if not candidate_topics:
|
172 |
|
173 |
# Generate representation model here if topics won't be changed later
|
174 |
+
# if reduce_outliers == "No":
|
175 |
+
# topic_model = BERTopic( embedding_model=embedding_model_pipe,
|
176 |
+
# vectorizer_model=vectoriser_model,
|
177 |
+
# umap_model=umap_model,
|
178 |
+
# min_topic_size = min_docs_slider,
|
179 |
+
# nr_topics = max_topics_slider,
|
180 |
+
# representation_model=representation_model,
|
181 |
+
# verbose = True)
|
182 |
+
|
183 |
+
topic_model = BERTopic( embedding_model=embedding_model_pipe,
|
184 |
+
vectorizer_model=vectoriser_model,
|
185 |
+
umap_model=umap_model,
|
186 |
+
min_topic_size = min_docs_slider,
|
187 |
+
nr_topics = max_topics_slider,
|
188 |
+
verbose = True)
|
189 |
|
190 |
topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
|
191 |
|
|
|
202 |
zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
|
203 |
|
204 |
# Generate representation model here if topics won't be changed later
|
205 |
+
# if reduce_outliers == "No":
|
206 |
+
# topic_model = BERTopic( embedding_model=embedding_model_pipe,
|
207 |
+
# vectorizer_model=vectoriser_model,
|
208 |
+
# umap_model=umap_model,
|
209 |
+
# min_topic_size = min_docs_slider,
|
210 |
+
# nr_topics = max_topics_slider,
|
211 |
+
# zeroshot_topic_list = zero_shot_topics_lower,
|
212 |
+
# zeroshot_min_similarity = 0.5,#0.7,
|
213 |
+
# representation_model=representation_model,
|
214 |
+
# verbose = True)
|
215 |
+
# else:
|
216 |
+
topic_model = BERTopic( embedding_model=embedding_model_pipe,
|
217 |
+
vectorizer_model=vectoriser_model,
|
218 |
+
umap_model=umap_model,
|
219 |
+
min_topic_size = min_docs_slider,
|
220 |
+
nr_topics = max_topics_slider,
|
221 |
+
zeroshot_topic_list = zero_shot_topics_lower,
|
222 |
+
zeroshot_min_similarity = 0.5,#0.7,
|
223 |
+
verbose = True)
|
224 |
|
225 |
topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
|
226 |
|
|
|
228 |
return "No topics found.", data_file_name, None
|
229 |
|
230 |
else:
|
231 |
+
print("Topic model created.")
|
232 |
+
|
233 |
+
progress(0.5, desc= "Loading in representation model")
|
234 |
+
print("Create LLM topic labels:", create_llm_topic_labels)
|
235 |
+
representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
|
236 |
+
|
237 |
|
238 |
+
# Reduce outliers if required, then update representation
|
239 |
if reduce_outliers == "Yes":
|
240 |
+
progress(0.6, desc= "Reducing outliers then creating topic representations")
|
241 |
print("Reducing outliers.")
|
242 |
# Assign each outlier document to the topic whose embedding is most similar (cosine similarity) to the document's embedding.
|
243 |
topics_text = topic_model.reduce_outliers(docs, topics_text, strategy="embeddings")
|
244 |
# Then, update the topics to the ones that considered the new data
|
|
|
245 |
print("Finished reducing outliers.")
|
246 |
|
247 |
+
progress(0.6, desc= "Creating topic representations")
|
248 |
+
topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
|
249 |
+
|
250 |
topic_dets = topic_model.get_topic_info()
|
|
|
251 |
|
252 |
if topic_dets.shape[0] == 1:
|
253 |
topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
|
254 |
topic_dets.to_csv(topic_det_output_name)
|
255 |
output_list.append(topic_det_output_name)
|
256 |
|
257 |
+
return "No topics found, original file returned", output_list, None, embeddings_out
|
258 |
|
259 |
# Replace original labels with LLM labels
|
260 |
+
if "Phi" in topic_model.get_topic_info().columns:
|
261 |
+
llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["Phi"].values()]
|
262 |
topic_model.set_topic_labels(llm_labels)
|
263 |
else:
|
264 |
topic_model.set_topic_labels(list(topic_dets["Name"]))
|
265 |
|
266 |
# Outputs
|
267 |
+
progress(0.8, desc= "Saving output")
|
268 |
|
269 |
topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
|
270 |
topic_dets.to_csv(topic_det_output_name)
|
|
|
309 |
output_list.append(embeddings_file_name)
|
310 |
|
311 |
if visualise_topics == "Yes":
|
312 |
+
from funcs.bertopic_vis_documents import visualize_documents_custom
|
313 |
+
progress(0.9, desc= "Creating visualisation (this can take a while)")
|
314 |
# Visualise the topics:
|
315 |
vis_tic = time.perf_counter()
|
316 |
print("Creating visualisation")
|
317 |
+
topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True)
|
318 |
+
topics_vis_name = data_file_name_no_ext + '_' + 'visualisation_' + today_rev + '.html'
|
319 |
+
topics_vis.write_html(topics_vis_name)
|
320 |
+
output_list.append(topics_vis_name)
|
321 |
|
322 |
all_toc = time.perf_counter()
|
323 |
time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
|
|
|
330 |
return output_text, output_list, topics_vis, embeddings_out
|
331 |
|
332 |
all_toc = time.perf_counter()
|
333 |
+
time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
|
334 |
print(time_out)
|
335 |
|
336 |
return output_text, output_list, None, embeddings_out
|
|
|
347 |
gr.Markdown(
|
348 |
"""
|
349 |
# Topic modeller
|
350 |
+
Generate topics from open text in tabular data. Upload a file (csv, xlsx, or parquet), then specify the open text column that you want to use to generate topics, and another for labels in the visualisation. If you have an embeddings .npz file of the text made using the 'jina-embeddings-v2-small-en' model, you can load this in at the same time to skip the first modelling step. If you have a pre-defined list of topics, you can upload this as a csv file under 'I have my own list of topics...'. Further configuration options are available under the 'Options' tab.
|
351 |
+
|
352 |
+
Suggested test dataset: https://huggingface.co/datasets/rag-datasets/mini_wikipedia/tree/main/data (passages.parquet)
|
353 |
""")
|
354 |
|
355 |
with gr.Tab("Load files and find topics"):
|
|
|
357 |
in_files = gr.File(label="Input text from file", file_count="multiple")
|
358 |
with gr.Row():
|
359 |
in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
|
360 |
+
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in the output visualisation.")
|
361 |
|
362 |
with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
|
363 |
candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
|
funcs/bertopic_vis_documents.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
|
5 |
+
from umap import UMAP
|
6 |
+
from typing import List, Union
|
7 |
+
|
8 |
+
# Adapted from the original BERTopic implementation by Maarten Grootendorst: https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_documents.py
|
9 |
+
|
10 |
+
def visualize_documents_custom(topic_model,
                               docs: List[str],
                               hover_labels: List[str],
                               topics: List[int] = None,
                               embeddings: np.ndarray = None,
                               reduced_embeddings: np.ndarray = None,
                               sample: float = None,
                               hide_annotations: bool = False,
                               hide_document_hover: bool = False,
                               custom_labels: Union[bool, str] = False,
                               title: str = "<b>Documents and Topics</b>",
                               width: int = 1200,
                               height: int = 750):
    """Visualize documents and their topics in 2D, with custom hover text.

    Each document is drawn as a point; hovering over a point shows the
    matching entry of `hover_labels` (wrapped every 20 words) rather than
    the document text itself.

    Arguments:
        topic_model: A fitted BERTopic instance.
        docs: The documents you used when calling either `fit` or `fit_transform`.
        hover_labels: One label per entry in `docs`, in the same order. This
                      is the text shown when hovering over a document.
        topics: A selection of topics to visualize.
                Not to be confused with the topics that you get from `.fit_transform`.
                For example, if you want to visualize only topics 1 through 5:
                `topics = [1, 2, 3, 4, 5]`. Defaults to all topics.
        embeddings: The embeddings of all documents in `docs`.
        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
                            When given, no dimensionality reduction is run.
        sample: The fraction of documents in each topic that you would like to keep.
                Value can be between 0 and 1. `None` or a value above 1 keeps
                everything. Topics with fewer than 100 documents are never
                subsampled.
        hide_annotations: Hide the names of the traces on top of each cluster.
        hide_document_hover: Hide the hover labels when hovering over
                             specific points. Helps to speed up generation of visualization.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        The plotly figure (`go.Figure`) holding one trace for outliers and
        non-selected topics plus one trace per selected topic.

    Examples:

    ```python
    fig = visualize_documents_custom(topic_model, docs, hover_labels=labels,
                                     reduced_embeddings=reduced_embeddings)
    fig.write_html("path/to/file.html")
    ```
    """
    topic_per_doc = topic_model.topics_
    topic_array = np.array(topic_per_doc)

    # Add <br> tags to hover labels to get them to appear on multiple lines
    hover_labels = [_wrap_by_word(str(label), n=20) for label in hover_labels]

    # Sample the data to optimize for visualization and dimensionality reduction.
    # From here on `sample` is always a number, never None.
    if sample is None or sample > 1:
        sample = 1

    indices = []
    for topic in set(topic_per_doc):
        members = np.where(topic_array == topic)[0]
        size = len(members) if len(members) < 100 else int(len(members) * sample)
        indices.extend(np.random.choice(members, size=size, replace=False))
    # Explicit integer dtype so an empty selection is still a valid indexer
    indices = np.array(indices, dtype=int)

    df = pd.DataFrame({
        "topic": [topic_per_doc[index] for index in indices],
        "doc": [docs[index] for index in indices],
        "hover_labels": [hover_labels[index] for index in indices],
    })

    # Work out the 2D coordinates of the sampled documents
    if reduced_embeddings is not None:
        embeddings_2d = reduced_embeddings[indices]
    else:
        if embeddings is not None:
            embeddings_to_reduce = embeddings[indices]
        else:
            # Neither full nor reduced embeddings supplied: embed the sampled docs
            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
        embeddings_2d = umap_model.embedding_

    unique_topics = set(topic_per_doc)
    if topics is None:
        topics = unique_topics

    # Combine data
    df["x"] = embeddings_2d[:, 0]
    df["y"] = embeddings_2d[:, 1]

    # Prepare text and names
    if isinstance(custom_labels, str):
        names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
        names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
        names = [label if len(label) < 30 else label[:27] + "..." for label in names]
    elif topic_model.custom_labels_ is not None and custom_labels:
        names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
    else:
        names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]

    # Visualize
    fig = go.Figure()

    # Outliers and non-selected topics
    non_selected_topics = set(unique_topics).difference(topics)
    if len(non_selected_topics) == 0:
        non_selected_topics = [-1]

    # reset_index gives an independent frame indexed 0..n-1, so adding a row at
    # label len(selection) appends instead of overwriting an existing document
    selection = df.loc[df.topic.isin(non_selected_topics), :].reset_index(drop=True)
    selection["text"] = ""
    selection = _append_annotation_row(selection, "Other documents")

    fig.add_trace(
        go.Scattergl(
            x=selection.x,
            y=selection.y,
            hovertext=selection.hover_labels if not hide_document_hover else None,
            hoverinfo="text",
            mode='markers+text',
            name="other",
            showlegend=False,
            marker=dict(color='#CFD8DC', size=5, opacity=0.5),
            hoverlabel=dict(align='left')
        )
    )

    # Selected topics
    for name, topic in zip(names, unique_topics):
        if topic in topics and topic != -1:
            selection = df.loc[df.topic == topic, :].reset_index(drop=True)
            selection["text"] = ""

            if not hide_annotations:
                # Extra point at the cluster centre carrying the topic name
                selection = _append_annotation_row(selection, name)

            fig.add_trace(
                go.Scattergl(
                    x=selection.x,
                    y=selection.y,
                    hovertext=selection.hover_labels if not hide_document_hover else None,
                    hoverinfo="text",
                    text=selection.text,
                    mode='markers+text',
                    name=name,
                    textfont=dict(
                        size=12,
                    ),
                    marker=dict(size=5, opacity=0.5),
                    hoverlabel=dict(align='left')
                ))

    # Add grid in a 'plus' shape
    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
    fig.add_shape(type="line",
                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
                  line=dict(color="#CFD8DC", width=2))
    fig.add_shape(type="line",
                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
                  line=dict(color="#9E9E9E", width=2))
    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)

    # Stylize layout
    fig.update_layout(
        template="simple_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        hoverlabel_align = 'left',
        width=width,
        height=height
    )

    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig


def _wrap_by_word(s, n):
    """Return `s` with a '<br>' tag inserted after every `n` whitespace-separated words."""
    words = s.split()
    return ''.join(' '.join(words[i:i + n]) + '<br>' for i in range(0, len(words), n))


def _append_annotation_row(selection, text):
    """Append a row at the mean (x, y) of `selection` whose "text" column is `text`.

    `selection` must be indexed 0..len-1 so the new label is unused. Values
    are built per column name, so every other column is left empty whatever
    the number of columns in the frame.
    """
    values = {"x": selection.x.mean(), "y": selection.y.mean(), "text": text}
    selection.loc[len(selection), :] = [values.get(column) for column in selection.columns]
    return selection
|
funcs/representation_model.py
CHANGED
@@ -168,7 +168,7 @@ def create_representation_model(create_llm_topic_labels, llm_config, hf_model_na
|
|
168 |
# All representation models
|
169 |
representation_model = {
|
170 |
"KeyBERT": keybert,
|
171 |
-
"
|
172 |
}
|
173 |
|
174 |
elif create_llm_topic_labels == "No":
|
|
|
168 |
# All representation models
|
169 |
representation_model = {
|
170 |
"KeyBERT": keybert,
|
171 |
+
"Phi": llm_model
|
172 |
}
|
173 |
|
174 |
elif create_llm_topic_labels == "No":
|