Sonnyjim commited on
Commit
ffe5eb2
1 Parent(s): 9eeba1e

More efficient embeddings save and representations load/process. Custom visualisation hover option added, formatting improvements. Version 0.1?

Browse files
.gitignore CHANGED
@@ -7,6 +7,7 @@
7
  *.png
8
  *.safetensors
9
  *.json
 
10
  .ipynb_checkpoints/*
11
  old_code/*
12
  model/*
 
7
  *.png
8
  *.safetensors
9
  *.json
10
+ *.html
11
  .ipynb_checkpoints/*
12
  old_code/*
13
  model/*
app.py CHANGED
@@ -80,7 +80,14 @@ hf_model_name = 'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1
80
  hf_model_file = 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf' # 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf'
81
 
82
 
83
- def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels, save_topic_model, visualise_topics, reduce_outliers, embeddings_out):
 
 
 
 
 
 
 
84
 
85
  all_tic = time.perf_counter()
86
 
@@ -97,8 +104,13 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
97
  in_label_list_first = in_label[0]
98
  else:
99
  in_label_list_first = in_colnames_list_first
 
 
 
 
100
 
101
  if anonymise_drop == "Yes":
 
102
  anon_tic = time.perf_counter()
103
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
104
  in_files_anon_col, anonymisation_success = anon.anonymise_script(in_files, in_colnames_list_first, anon_strat="replace")
@@ -111,7 +123,7 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
111
  time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
112
 
113
  docs = list(in_files[in_colnames_list_first].str.lower())
114
- label_col = in_files[in_label_list_first]
115
 
116
  # Check if embeddings are being loaded in
117
  ## Load in pre-embedded file if exists
@@ -144,6 +156,8 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
144
 
145
  umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
146
 
 
 
147
  embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, data_file_name_no_ext, embeddings_out, embedding_model, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels)
148
 
149
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
@@ -151,28 +165,27 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
151
  from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
152
  from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
153
 
154
- print("Create LLM topic labels:", create_llm_topic_labels)
155
- representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
156
-
157
 
158
  if not candidate_topics:
159
 
160
  # Generate representation model here if topics won't be changed later
161
- if reduce_outliers == "No":
162
- topic_model = BERTopic( embedding_model=embedding_model_pipe,
163
- vectorizer_model=vectoriser_model,
164
- umap_model=umap_model,
165
- min_topic_size = min_docs_slider,
166
- nr_topics = max_topics_slider,
167
- representation_model=representation_model,
168
- verbose = True)
169
- else:
170
- topic_model = BERTopic( embedding_model=embedding_model_pipe,
171
- vectorizer_model=vectoriser_model,
172
- umap_model=umap_model,
173
- min_topic_size = min_docs_slider,
174
- nr_topics = max_topics_slider,
175
- verbose = True)
176
 
177
  topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
178
 
@@ -189,25 +202,25 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
189
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
190
 
191
  # Generate representation model here if topics won't be changed later
192
- if reduce_outliers == "No":
193
- topic_model = BERTopic( embedding_model=embedding_model_pipe,
194
- vectorizer_model=vectoriser_model,
195
- umap_model=umap_model,
196
- min_topic_size = min_docs_slider,
197
- nr_topics = max_topics_slider,
198
- zeroshot_topic_list = zero_shot_topics_lower,
199
- zeroshot_min_similarity = 0.5,#0.7,
200
- representation_model=representation_model,
201
- verbose = True)
202
- else:
203
- topic_model = BERTopic( embedding_model=embedding_model_pipe,
204
- vectorizer_model=vectoriser_model,
205
- umap_model=umap_model,
206
- min_topic_size = min_docs_slider,
207
- nr_topics = max_topics_slider,
208
- zeroshot_topic_list = zero_shot_topics_lower,
209
- zeroshot_min_similarity = 0.5,#0.7,
210
- verbose = True)
211
 
212
  topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
213
 
@@ -215,35 +228,43 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
215
  return "No topics found.", data_file_name, None
216
 
217
  else:
218
- print("Preparing topic model outputs.")
 
 
 
 
 
219
 
220
- # Reduce outliers if required
221
  if reduce_outliers == "Yes":
 
222
  print("Reducing outliers.")
223
  # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
224
  topics_text = topic_model.reduce_outliers(docs, topics_text, strategy="embeddings")
225
  # Then, update the topics to the ones that considered the new data
226
- topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
227
  print("Finished reducing outliers.")
228
 
 
 
 
229
  topic_dets = topic_model.get_topic_info()
230
- #print(topic_dets.columns)
231
 
232
  if topic_dets.shape[0] == 1:
233
  topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
234
  topic_dets.to_csv(topic_det_output_name)
235
  output_list.append(topic_det_output_name)
236
 
237
- return "No topics found, original file returned", output_list, None
238
 
239
  # Replace original labels with LLM labels
240
- if "Mistral" in topic_model.get_topic_info().columns:
241
- llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["Mistral"].values()]
242
  topic_model.set_topic_labels(llm_labels)
243
  else:
244
  topic_model.set_topic_labels(list(topic_dets["Name"]))
245
 
246
  # Outputs
 
247
 
248
  topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
249
  topic_dets.to_csv(topic_det_output_name)
@@ -288,10 +309,15 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
288
  output_list.append(embeddings_file_name)
289
 
290
  if visualise_topics == "Yes":
 
 
291
  # Visualise the topics:
292
  vis_tic = time.perf_counter()
293
  print("Creating visualisation")
294
- topics_vis = topic_model.visualize_documents(label_col, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True)
 
 
 
295
 
296
  all_toc = time.perf_counter()
297
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
@@ -304,7 +330,7 @@ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_s
304
  return output_text, output_list, topics_vis, embeddings_out
305
 
306
  all_toc = time.perf_counter()
307
- time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
308
  print(time_out)
309
 
310
  return output_text, output_list, None, embeddings_out
@@ -321,7 +347,9 @@ with block:
321
  gr.Markdown(
322
  """
323
  # Topic modeller
324
- Generate topics from open text in tabular data. Upload a file (csv, xlsx, or parquet), then specify the columns that you want to use to generate topics and use for labels in the visualisation. If you have an embeddings .npz file of the text made using the 'jina-embeddings-v2-small-en' model, you can load this in at the same time to skip the first modelling step. If you have a pre-defined list of topics, you can upload this as a csv file under 'I have my own list of topics...'. Further configuration options are available under the 'Options' tab.
 
 
325
  """)
326
 
327
  with gr.Tab("Load files and find topics"):
@@ -329,7 +357,7 @@ with block:
329
  in_files = gr.File(label="Input text from file", file_count="multiple")
330
  with gr.Row():
331
  in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
332
- in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to for labelling documents in the output visualisation.")
333
 
334
  with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
335
  candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
 
80
  hf_model_file = 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf' # 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf'
81
 
82
 
83
+ def extract_topics(in_files, in_file, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels, save_topic_model, visualise_topics, reduce_outliers, embeddings_out, progress=gr.Progress()):
84
+
85
+ progress(0, desc= "Loading data")
86
+
87
+ if not in_colnames or not in_label:
88
+ error_message = "Please enter one column name for the topics and another for the labelling."
89
+ print(error_message)
90
+ return error_message, None, None, embeddings_out
91
 
92
  all_tic = time.perf_counter()
93
 
 
104
  in_label_list_first = in_label[0]
105
  else:
106
  in_label_list_first = in_colnames_list_first
107
+
108
+ # Make sure format of input series is good
109
+ in_files[in_colnames_list_first] = in_files[in_colnames_list_first].fillna('').astype(str)
110
+ in_files[in_label_list_first] = in_files[in_label_list_first].fillna('').astype(str)
111
 
112
  if anonymise_drop == "Yes":
113
+ progress(0.1, desc= "Anonymising data")
114
  anon_tic = time.perf_counter()
115
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
116
  in_files_anon_col, anonymisation_success = anon.anonymise_script(in_files, in_colnames_list_first, anon_strat="replace")
 
123
  time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
124
 
125
  docs = list(in_files[in_colnames_list_first].str.lower())
126
+ label_list = list(in_files[in_label_list_first])
127
 
128
  # Check if embeddings are being loaded in
129
  ## Load in pre-embedded file if exists
 
156
 
157
  umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
158
 
159
+ progress(0.2, desc= "Loading/creating embeddings")
160
+
161
  embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, data_file_name_no_ext, embeddings_out, embedding_model, return_intermediate_files, embeddings_super_compress, low_resource_mode, create_llm_topic_labels)
162
 
163
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
 
165
  from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
166
  from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
167
 
168
+
169
+ progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
 
170
 
171
  if not candidate_topics:
172
 
173
  # Generate representation model here if topics won't be changed later
174
+ # if reduce_outliers == "No":
175
+ # topic_model = BERTopic( embedding_model=embedding_model_pipe,
176
+ # vectorizer_model=vectoriser_model,
177
+ # umap_model=umap_model,
178
+ # min_topic_size = min_docs_slider,
179
+ # nr_topics = max_topics_slider,
180
+ # representation_model=representation_model,
181
+ # verbose = True)
182
+
183
+ topic_model = BERTopic( embedding_model=embedding_model_pipe,
184
+ vectorizer_model=vectoriser_model,
185
+ umap_model=umap_model,
186
+ min_topic_size = min_docs_slider,
187
+ nr_topics = max_topics_slider,
188
+ verbose = True)
189
 
190
  topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
191
 
 
202
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
203
 
204
  # Generate representation model here if topics won't be changed later
205
+ # if reduce_outliers == "No":
206
+ # topic_model = BERTopic( embedding_model=embedding_model_pipe,
207
+ # vectorizer_model=vectoriser_model,
208
+ # umap_model=umap_model,
209
+ # min_topic_size = min_docs_slider,
210
+ # nr_topics = max_topics_slider,
211
+ # zeroshot_topic_list = zero_shot_topics_lower,
212
+ # zeroshot_min_similarity = 0.5,#0.7,
213
+ # representation_model=representation_model,
214
+ # verbose = True)
215
+ # else:
216
+ topic_model = BERTopic( embedding_model=embedding_model_pipe,
217
+ vectorizer_model=vectoriser_model,
218
+ umap_model=umap_model,
219
+ min_topic_size = min_docs_slider,
220
+ nr_topics = max_topics_slider,
221
+ zeroshot_topic_list = zero_shot_topics_lower,
222
+ zeroshot_min_similarity = 0.5,#0.7,
223
+ verbose = True)
224
 
225
  topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
226
 
 
228
  return "No topics found.", data_file_name, None
229
 
230
  else:
231
+ print("Topic model created.")
232
+
233
+ progress(0.5, desc= "Loading in representation model")
234
+ print("Create LLM topic labels:", create_llm_topic_labels)
235
+ representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
236
+
237
 
238
+ # Reduce outliers if required, then update representation
239
  if reduce_outliers == "Yes":
240
+ progress(0.6, desc= "Reducing outliers then creating topic representations")
241
  print("Reducing outliers.")
242
  # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
243
  topics_text = topic_model.reduce_outliers(docs, topics_text, strategy="embeddings")
244
  # Then, update the topics to the ones that considered the new data
 
245
  print("Finished reducing outliers.")
246
 
247
+ progress(0.6, desc= "Creating topic representations")
248
+ topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
249
+
250
  topic_dets = topic_model.get_topic_info()
 
251
 
252
  if topic_dets.shape[0] == 1:
253
  topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
254
  topic_dets.to_csv(topic_det_output_name)
255
  output_list.append(topic_det_output_name)
256
 
257
+ return "No topics found, original file returned", output_list, None, embeddings_out
258
 
259
  # Replace original labels with LLM labels
260
+ if "Phi" in topic_model.get_topic_info().columns:
261
+ llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["Phi"].values()]
262
  topic_model.set_topic_labels(llm_labels)
263
  else:
264
  topic_model.set_topic_labels(list(topic_dets["Name"]))
265
 
266
  # Outputs
267
+ progress(0.8, desc= "Saving output")
268
 
269
  topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
270
  topic_dets.to_csv(topic_det_output_name)
 
309
  output_list.append(embeddings_file_name)
310
 
311
  if visualise_topics == "Yes":
312
+ from funcs.bertopic_vis_documents import visualize_documents_custom
313
+ progress(0.9, desc= "Creating visualisation (this can take a while)")
314
  # Visualise the topics:
315
  vis_tic = time.perf_counter()
316
  print("Creating visualisation")
317
+ topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True)
318
+ topics_vis_name = data_file_name_no_ext + '_' + 'visualisation_' + today_rev + '.html'
319
+ topics_vis.write_html(topics_vis_name)
320
+ output_list.append(topics_vis_name)
321
 
322
  all_toc = time.perf_counter()
323
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
 
330
  return output_text, output_list, topics_vis, embeddings_out
331
 
332
  all_toc = time.perf_counter()
333
+ time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
334
  print(time_out)
335
 
336
  return output_text, output_list, None, embeddings_out
 
347
  gr.Markdown(
348
  """
349
  # Topic modeller
350
+ Generate topics from open text in tabular data. Upload a file (csv, xlsx, or parquet), then specify the open text column that you want to use to generate topics, and another for labels in the visualisation. If you have an embeddings .npz file of the text made using the 'jina-embeddings-v2-small-en' model, you can load this in at the same time to skip the first modelling step. If you have a pre-defined list of topics, you can upload this as a csv file under 'I have my own list of topics...'. Further configuration options are available under the 'Options' tab.
351
+
352
+ Suggested test dataset: https://huggingface.co/datasets/rag-datasets/mini_wikipedia/tree/main/data (passages.parquet)
353
  """)
354
 
355
  with gr.Tab("Load files and find topics"):
 
357
  in_files = gr.File(label="Input text from file", file_count="multiple")
358
  with gr.Row():
359
  in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
360
+ in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in the output visualisation.")
361
 
362
  with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
363
  candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
funcs/bertopic_vis_documents.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import plotly.graph_objects as go
4
+
5
+ from umap import UMAP
6
+ from typing import List, Union
7
+
8
+ # Shamelessly taken and adapted from Bertopic original implementation here (Maarten Grootendorst): https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_documents.py
9
+
10
+ def visualize_documents_custom(topic_model,
11
+ docs: List[str],
12
+ hover_labels: List[str],
13
+ topics: List[int] = None,
14
+ embeddings: np.ndarray = None,
15
+ reduced_embeddings: np.ndarray = None,
16
+ sample: float = None,
17
+ hide_annotations: bool = False,
18
+ hide_document_hover: bool = False,
19
+ custom_labels: Union[bool, str] = False,
20
+ title: str = "<b>Documents and Topics</b>",
21
+ width: int = 1200,
22
+ height: int = 750):
23
+ """ Visualize documents and their topics in 2D
24
+
25
+ Arguments:
26
+ topic_model: A fitted BERTopic instance.
27
+ docs: The documents you used when calling either `fit` or `fit_transform`
28
+ topics: A selection of topics to visualize.
29
+ Not to be confused with the topics that you get from `.fit_transform`.
30
+ For example, if you want to visualize only topics 1 through 5:
31
+ `topics = [1, 2, 3, 4, 5]`.
32
+ embeddings: The embeddings of all documents in `docs`.
33
+ reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
34
+ sample: The percentage of documents in each topic that you would like to keep.
35
+ Value can be between 0 and 1. Setting this value to, for example,
36
+ 0.1 (10% of documents in each topic) makes it easier to visualize
37
+ millions of documents as a subset is chosen.
38
+ hide_annotations: Hide the names of the traces on top of each cluster.
39
+ hide_document_hover: Hide the content of the documents when hovering over
40
+ specific points. Helps to speed up generation of visualization.
41
+ custom_labels: If bool, whether to use custom topic labels that were defined using
42
+ `topic_model.set_topic_labels`.
43
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
44
+ title: Title of the plot.
45
+ width: The width of the figure.
46
+ height: The height of the figure.
47
+
48
+ Examples:
49
+
50
+ To visualize the topics simply run:
51
+
52
+ ```python
53
+ topic_model.visualize_documents(docs)
54
+ ```
55
+
56
+ Do note that this re-calculates the embeddings and reduces them to 2D.
57
+ The advised and prefered pipeline for using this function is as follows:
58
+
59
+ ```python
60
+ from sklearn.datasets import fetch_20newsgroups
61
+ from sentence_transformers import SentenceTransformer
62
+ from bertopic import BERTopic
63
+ from umap import UMAP
64
+
65
+ # Prepare embeddings
66
+ docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
67
+ sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
68
+ embeddings = sentence_model.encode(docs, show_progress_bar=False)
69
+
70
+ # Train BERTopic
71
+ topic_model = BERTopic().fit(docs, embeddings)
72
+
73
+ # Reduce dimensionality of embeddings, this step is optional
74
+ # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
75
+
76
+ # Run the visualization with the original embeddings
77
+ topic_model.visualize_documents(docs, embeddings=embeddings)
78
+
79
+ # Or, if you have reduced the original embeddings already:
80
+ topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
81
+ ```
82
+
83
+ Or if you want to save the resulting figure:
84
+
85
+ ```python
86
+ fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
87
+ fig.write_html("path/to/file.html")
88
+ ```
89
+
90
+ <iframe src="../../getting_started/visualization/documents.html"
91
+ style="width:1000px; height: 800px; border: 0px;""></iframe>
92
+ """
93
+ topic_per_doc = topic_model.topics_
94
+
95
+ # Add <br> tags to hover labels to get them to appear on multiple lines
96
+ def wrap_by_word(s, n):
97
+ '''returns a string where \\n is inserted between every n words'''
98
+ a = s.split()
99
+ ret = ''
100
+ for i in range(0, len(a), n):
101
+ ret += ' '.join(a[i:i+n]) + '<br>'
102
+ return ret
103
+
104
+ # Apply the function to every element in the list
105
+ hover_labels = [wrap_by_word(s, n=20) for s in hover_labels]
106
+
107
+
108
+ # Sample the data to optimize for visualization and dimensionality reduction
109
+ if sample is None or sample > 1:
110
+ sample = 1
111
+
112
+ indices = []
113
+ for topic in set(topic_per_doc):
114
+ s = np.where(np.array(topic_per_doc) == topic)[0]
115
+ size = len(s) if len(s) < 100 else int(len(s) * sample)
116
+ indices.extend(np.random.choice(s, size=size, replace=False))
117
+ indices = np.array(indices)
118
+
119
+ df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
120
+ df["doc"] = [docs[index] for index in indices]
121
+ df["hover_labels"] = [hover_labels[index] for index in indices]
122
+ df["topic"] = [topic_per_doc[index] for index in indices]
123
+
124
+ # Extract embeddings if not already done
125
+ if sample is None:
126
+ if embeddings is None and reduced_embeddings is None:
127
+ embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
128
+ else:
129
+ embeddings_to_reduce = embeddings
130
+ else:
131
+ if embeddings is not None:
132
+ embeddings_to_reduce = embeddings[indices]
133
+ elif embeddings is None and reduced_embeddings is None:
134
+ embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
135
+
136
+ # Reduce input embeddings
137
+ if reduced_embeddings is None:
138
+ umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
139
+ embeddings_2d = umap_model.embedding_
140
+ elif sample is not None and reduced_embeddings is not None:
141
+ embeddings_2d = reduced_embeddings[indices]
142
+ elif sample is None and reduced_embeddings is not None:
143
+ embeddings_2d = reduced_embeddings
144
+
145
+ unique_topics = set(topic_per_doc)
146
+ if topics is None:
147
+ topics = unique_topics
148
+
149
+ # Combine data
150
+ df["x"] = embeddings_2d[:, 0]
151
+ df["y"] = embeddings_2d[:, 1]
152
+
153
+ # Prepare text and names
154
+ if isinstance(custom_labels, str):
155
+ names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
156
+ names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
157
+ names = [label if len(label) < 30 else label[:27] + "..." for label in names]
158
+ elif topic_model.custom_labels_ is not None and custom_labels:
159
+ names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
160
+ else:
161
+ names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
162
+
163
+ # Visualize
164
+ fig = go.Figure()
165
+
166
+ # Outliers and non-selected topics
167
+ non_selected_topics = set(unique_topics).difference(topics)
168
+ if len(non_selected_topics) == 0:
169
+ non_selected_topics = [-1]
170
+
171
+ selection = df.loc[df.topic.isin(non_selected_topics), :]
172
+ selection["text"] = ""
173
+ selection.loc[len(selection), :] = [None, None, None, selection.x.mean(), selection.y.mean(), "Other documents"]
174
+
175
+ fig.add_trace(
176
+ go.Scattergl(
177
+ x=selection.x,
178
+ y=selection.y,
179
+ hovertext=selection.hover_labels if not hide_document_hover else None,
180
+ hoverinfo="text",
181
+ mode='markers+text',
182
+ name="other",
183
+ showlegend=False,
184
+ marker=dict(color='#CFD8DC', size=5, opacity=0.5),
185
+ hoverlabel=dict(align='left')
186
+ )
187
+ )
188
+
189
+ # Selected topics
190
+ for name, topic in zip(names, unique_topics):
191
+ if topic in topics and topic != -1:
192
+ selection = df.loc[df.topic == topic, :]
193
+ selection["text"] = ""
194
+
195
+ if not hide_annotations:
196
+ selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name]
197
+
198
+ fig.add_trace(
199
+ go.Scattergl(
200
+ x=selection.x,
201
+ y=selection.y,
202
+ hovertext=selection.hover_labels if not hide_document_hover else None,
203
+ hoverinfo="text",
204
+ text=selection.text,
205
+ mode='markers+text',
206
+ name=name,
207
+ textfont=dict(
208
+ size=12,
209
+ ),
210
+ marker=dict(size=5, opacity=0.5),
211
+ hoverlabel=dict(align='left')
212
+ ))
213
+
214
+ # Add grid in a 'plus' shape
215
+ x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
216
+ y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
217
+ fig.add_shape(type="line",
218
+ x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
219
+ line=dict(color="#CFD8DC", width=2))
220
+ fig.add_shape(type="line",
221
+ x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
222
+ line=dict(color="#9E9E9E", width=2))
223
+ fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
224
+ fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
225
+
226
+ # Stylize layout
227
+ fig.update_layout(
228
+ template="simple_white",
229
+ title={
230
+ 'text': f"{title}",
231
+ 'x': 0.5,
232
+ 'xanchor': 'center',
233
+ 'yanchor': 'top',
234
+ 'font': dict(
235
+ size=22,
236
+ color="Black")
237
+ },
238
+ hoverlabel_align = 'left',
239
+ width=width,
240
+ height=height
241
+ )
242
+
243
+ fig.update_xaxes(visible=False)
244
+ fig.update_yaxes(visible=False)
245
+ return fig
funcs/representation_model.py CHANGED
@@ -168,7 +168,7 @@ def create_representation_model(create_llm_topic_labels, llm_config, hf_model_na
168
  # All representation models
169
  representation_model = {
170
  "KeyBERT": keybert,
171
- "Mistral": llm_model
172
  }
173
 
174
  elif create_llm_topic_labels == "No":
 
168
  # All representation models
169
  representation_model = {
170
  "KeyBERT": keybert,
171
+ "Phi": llm_model
172
  }
173
 
174
  elif create_llm_topic_labels == "No":