Sonnyjim commited on
Commit
b4510a6
1 Parent(s): 1f1a1c7

Lots of general fixes. New visualisations, fixed hierarchical vis for zero shot. Added calc all probabilities.

Browse files
Topic modeller to do.txt DELETED
@@ -1,13 +0,0 @@
1
- Need to add option to anonymise - done
2
-
3
- Need to add option to deduplicate
4
-
5
- Need option to sample for X number of rows with specific seed
6
-
7
- Add plotly visualisation - done
8
-
9
- Add zero shot topic list support
10
-
11
- Add topic renaming with LLMs - done
12
-
13
- Option to predict topics on a new dataset - done (kind of - just save model to file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,4 +1,8 @@
1
  import os
 
 
 
 
2
  import gradio as gr
3
  from datetime import datetime
4
  import pandas as pd
@@ -7,8 +11,6 @@ import time
7
 
8
  from sentence_transformers import SentenceTransformer
9
  from sklearn.feature_extraction.text import CountVectorizer
10
- from transformers import AutoModel, AutoTokenizer
11
- from transformers.pipelines import pipeline
12
  from sklearn.pipeline import make_pipeline
13
  from sklearn.decomposition import TruncatedSVD
14
  from sklearn.feature_extraction.text import TfidfVectorizer
@@ -17,9 +19,13 @@ from umap import UMAP
17
 
18
  from torch import cuda, backends, version
19
 
 
20
  random_seed = 42
21
 
22
  # Check for torch cuda
 
 
 
23
  print("Is CUDA enabled? ", cuda.is_available())
24
  print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
25
  if cuda.is_available():
@@ -33,25 +39,19 @@ else:
33
 
34
  print("Device used is: ", torch_device)
35
 
36
- #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
37
 
38
- from bertopic import BERTopic
39
- #from sentence_transformers import SentenceTransformer
40
- #from bertopic.backend._hftransformers import HFTransformerBackend
41
 
42
- #from cuml.manifold import UMAP
43
 
44
- #umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
45
 
46
  today = datetime.now().strftime("%d%m%Y")
47
  today_rev = datetime.now().strftime("%Y%m%d")
48
 
49
- from funcs.helper_functions import dummy_function, put_columns_in_df, read_file, get_file_path_end, zip_folder, delete_files_in_folder
50
  #from funcs.representation_model import representation_model
51
  from funcs.embeddings import make_or_load_embeddings
52
 
53
  # Log terminal output: https://github.com/gradio-app/gradio/issues/2362
54
-
55
  import sys
56
 
57
  class Logger:
@@ -78,89 +78,42 @@ def read_logs():
78
  return f.read()
79
 
80
  # Load embeddings
 
81
 
 
82
  # Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
83
  # Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
84
- embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"
85
  # local_embeddings_location = "model/jina/"
86
  #revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
87
  #revision_choice = "69d43700292701b06c24f43b96560566a4e5ad1f"
88
 
89
  # Model used for representing topics
90
- hf_model_name = 'second-state/stablelm-2-zephyr-1.6b-GGUF' #'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1.9-GGUF' # 'second-state/stablelm-2-zephyr-1.6b-GGUF'
91
- hf_model_file = 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf' # 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf' # 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf'
92
-
93
- def save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model, progress=gr.Progress()):
94
- topic_dets = topic_model.get_topic_info()
95
-
96
- if topic_dets.shape[0] == 1:
97
- topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
98
- topic_dets.to_csv(topic_det_output_name)
99
- output_list.append(topic_det_output_name)
100
-
101
- return output_list, "No topics found, original file returned"
102
-
103
-
104
- progress(0.8, desc= "Saving output")
105
-
106
- topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
107
- topic_dets.to_csv(topic_det_output_name)
108
- output_list.append(topic_det_output_name)
109
-
110
- doc_det_output_name = "doc_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
111
- doc_dets = topic_model.get_document_info(docs)[["Document", "Topic", "Name", "Representative_document"]] # "Probability",
112
- doc_dets.to_csv(doc_det_output_name)
113
- output_list.append(doc_det_output_name)
114
-
115
- topics_text_out_str = str(topic_dets["Name"])
116
- output_text = "Topics: " + topics_text_out_str
117
-
118
- # Save topic model to file
119
- if save_topic_model == "Yes":
120
- topic_model_save_name_pkl = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev + ".pkl"# + ".safetensors"
121
- topic_model_save_name_zip = topic_model_save_name_pkl + ".zip"
122
-
123
- # Clear folder before replacing files
124
- #delete_files_in_folder(topic_model_save_name_pkl)
125
-
126
- topic_model.save(topic_model_save_name_pkl, serialization='pickle', save_embedding_model=False, save_ctfidf=False)
127
 
128
- # Zip file example
129
-
130
- #zip_folder(topic_model_save_name_pkl, topic_model_save_name_zip)
131
- output_list.append(topic_model_save_name_pkl)
132
-
133
- return output_list, output_text
134
-
135
- def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, progress=gr.Progress()):
136
 
137
  progress(0, desc= "Loading data")
138
 
139
- if not in_colnames or not in_label:
140
- error_message = "Please enter one column name for the topics and another for the labelling."
 
 
 
 
 
 
141
  print(error_message)
142
- return error_message, None, None, embeddings_out
143
 
144
  all_tic = time.perf_counter()
145
 
146
  output_list = []
147
  file_list = [string.name for string in in_files]
148
 
149
- data_file_names = [string.lower() for string in file_list if "tokenised" not in string and "npz" not in string.lower() and "gz" not in string.lower()]
150
- data_file_name = data_file_names[0]
151
- data_file_name_no_ext = get_file_path_end(data_file_name)
152
-
153
  in_colnames_list_first = in_colnames[0]
154
 
155
- if in_label:
156
- in_label_list_first = in_label[0]
157
- else:
158
- in_label_list_first = in_colnames_list_first
159
-
160
- # Make sure format of input series is good
161
- data[in_colnames_list_first] = data[in_colnames_list_first].fillna('').astype(str)
162
- data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
163
- label_list = list(data[in_label_list_first])
164
 
165
  if anonymise_drop == "Yes":
166
  progress(0.1, desc= "Anonymising data")
@@ -172,12 +125,11 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
172
  data.to_csv(anonymise_data_name)
173
  output_list.append(anonymise_data_name)
174
 
 
 
175
  anon_toc = time.perf_counter()
176
  time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
177
 
178
- docs = list(data[in_colnames_list_first].str.lower())
179
-
180
-
181
  # Check if embeddings are being loaded in
182
  progress(0.2, desc= "Loading/creating embeddings")
183
 
@@ -185,10 +137,10 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
185
 
186
  if low_resource_mode == "No":
187
  print("Using high resource BGE transformer model")
188
-
189
-
190
 
191
  embedding_model = SentenceTransformer(embeddings_name)
 
 
192
  #try:
193
  #embedding_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True,device_map="auto") # For Jina
194
  #except:
@@ -210,11 +162,15 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
210
 
211
  umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
212
 
213
-
214
-
215
- embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
216
 
217
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
 
 
 
 
 
 
218
 
219
  progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
220
 
@@ -225,17 +181,18 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
225
  umap_model=umap_model,
226
  min_topic_size = min_docs_slider,
227
  nr_topics = max_topics_slider,
 
 
228
  verbose = True)
229
 
230
- topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
231
 
232
- if not topics_text:
233
- # Handle the empty array case
234
 
235
- return "No topics found.", data_file_name, None, embeddings_out, data_file_name_no_ext, topic_model, docs, label_list
236
-
237
- else:
238
- print("Topic model created.")
239
 
240
 
241
  # Do this if you have pre-defined topics
@@ -244,11 +201,13 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
244
  error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
245
  print(error_message)
246
 
247
- return error_message, output_list, None, embeddings_out, data_file_name_no_ext, None, docs, label_list
248
 
249
  zero_shot_topics = read_file(candidate_topics.name)
250
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
251
 
 
 
252
  topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, # for Jina
253
  vectorizer_model=vectoriser_model,
254
  umap_model=umap_model,
@@ -256,19 +215,51 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
256
  nr_topics = max_topics_slider,
257
  zeroshot_topic_list = zero_shot_topics_lower,
258
  zeroshot_min_similarity = zero_shot_similarity, # 0.7
 
 
259
  verbose = True)
260
 
261
- topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
262
 
263
- # print(topics_text)
 
 
 
264
 
265
- if topics_text.size == 0:
266
- # Handle the empty array case
267
 
268
- return "No topics found.", data_file_name, None, embeddings_out, data_file_name_no_ext, topic_model, docs, label_list
269
-
270
- else:
271
- print("Topic model created.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  # Outputs
274
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
@@ -292,37 +283,40 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
292
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
293
  print(time_out)
294
 
295
- return output_text, output_list, None, embeddings_out, data_file_name_no_ext, topic_model, docs, label_list
296
 
297
- def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, create_llm_topic_labels, save_topic_model, progress=gr.Progress()):
298
- #from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
299
- from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
300
 
301
  output_list = []
302
 
303
  all_tic = time.perf_counter()
304
 
305
- vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
306
 
307
- topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
 
308
 
309
  #progress(0.2, desc= "Loading in representation model")
310
  #print("Create LLM topic labels:", create_llm_topic_labels)
 
311
  #representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
312
 
313
  # Reduce outliers if required, then update representation
314
  progress(0.2, desc= "Reducing outliers")
315
  print("Reducing outliers.")
316
  # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
317
- topics_text = topic_model.reduce_outliers(docs, topics_text, strategy="embeddings")
318
  # Then, update the topics to the ones that considered the new data
319
 
320
  print("Finished reducing outliers.")
321
 
322
- progress(0.5, desc= "Creating topic representations")
323
- print("Create LLM topic labels:", "No")
324
- representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
325
- topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
 
326
 
327
  topic_dets = topic_model.get_topic_info()
328
 
@@ -334,15 +328,16 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, lo
334
  topic_model.set_topic_labels(list(topic_dets["Name"]))
335
 
336
  # Outputs
 
337
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
338
 
339
  all_toc = time.perf_counter()
340
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
341
  print(time_out)
342
 
343
- return output_text, output_list, embeddings_out
344
 
345
- def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, save_topic_model, progress=gr.Progress()):
346
  #from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
347
  from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
348
 
@@ -352,48 +347,76 @@ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, l
352
 
353
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
354
 
355
- topics_text, probs = topic_model.fit_transform(docs, embeddings_out)
356
 
357
  topic_dets = topic_model.get_topic_info()
358
 
359
- progress(0.2, desc= "Creating topic representations")
360
  print("Create LLM topic labels:", "Yes")
361
  representation_model = create_representation_model("Yes", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
362
 
363
- topic_model.update_topics(docs, topics=topics_text, vectorizer_model=vectoriser_model, representation_model=representation_model)
364
 
365
  # Replace original labels with LLM labels
366
  if "LLM" in topic_model.get_topic_info().columns:
367
  llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
368
  topic_model.set_topic_labels(llm_labels)
369
 
370
- with open('llm_topic_list.csv', 'w') as file:
371
- for item in llm_labels:
372
- file.write(f"{item}\n")
373
- output_list.append('llm_topic_list.csv')
 
 
 
 
 
374
  else:
375
  topic_model.set_topic_labels(list(topic_dets["Name"]))
376
 
377
-
378
-
379
- # Outputs
380
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
381
 
382
  all_toc = time.perf_counter()
383
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
384
  print(time_out)
385
 
386
- return output_text, output_list, embeddings_out
 
 
 
 
387
 
388
- def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode, embeddings_out, label_list, sample_prop, visualisation_type_radio, progress=gr.Progress()):
389
  output_list = []
390
  vis_tic = time.perf_counter()
391
 
392
- from funcs.bertopic_vis_documents import visualize_documents_custom
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
  topic_dets = topic_model.get_topic_info()
395
 
396
- # Replace original labels with LLM labels
397
  if "LLM" in topic_model.get_topic_info().columns:
398
  llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
399
  topic_model.set_topic_labels(llm_labels)
@@ -414,16 +437,37 @@ def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
414
  # "Topic document graph", "Hierarchical view"
415
 
416
  if visualisation_type_radio == "Topic document graph":
417
- topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop)
418
 
419
- topics_vis_name = data_file_name_no_ext + '_' + 'visualisation_' + today_rev + '.html'
420
  topics_vis.write_html(topics_vis_name)
421
  output_list.append(topics_vis_name)
422
 
 
 
 
 
 
 
423
  elif visualisation_type_radio == "Hierarchical view":
 
 
 
 
 
 
 
424
  hierarchical_topics = topic_model.hierarchical_topics(docs)
425
- topics_vis = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop)
426
- topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
 
 
 
 
 
 
 
 
427
 
428
  topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
429
  topics_vis.write_html(topics_vis_name)
@@ -433,24 +477,22 @@ def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
433
  topics_vis_2.write_html(topics_vis_2_name)
434
  output_list.append(topics_vis_2_name)
435
 
436
- # Save new hierarchical topic model to file
437
- import pandas as pd
438
- hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics' + today_rev + '.csv'
439
- hierarchical_topics.to_csv(hierarchical_topics_name)
440
- output_list.append(hierarchical_topics_name)
441
- #output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
442
-
443
-
444
-
445
  all_toc = time.perf_counter()
446
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
447
  print(time_out)
448
 
449
- return time_out, output_list, topics_vis, embeddings_out
 
 
 
 
 
 
 
450
 
451
- def save_as_pytorch_model(topic_model, docs, data_file_name_no_ext , progress=gr.Progress()):
452
  output_list = []
453
 
 
454
  topic_model_save_name_folder = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev# + ".safetensors"
455
  topic_model_save_name_zip = topic_model_save_name_folder + ".zip"
456
 
@@ -464,6 +506,8 @@ def save_as_pytorch_model(topic_model, docs, data_file_name_no_ext , progress=gr
464
  zip_folder(topic_model_save_name_folder, topic_model_save_name_zip)
465
  output_list.append(topic_model_save_name_zip)
466
 
 
 
467
  # Gradio app
468
 
469
  block = gr.Blocks(theme = gr.themes.Base())
@@ -475,7 +519,7 @@ with block:
475
  topic_model_state = gr.State()
476
  docs_state = gr.State()
477
  data_file_name_no_ext_state = gr.State()
478
- label_list_state = gr.State()
479
 
480
  gr.Markdown(
481
  """
@@ -489,8 +533,7 @@ with block:
489
  with gr.Accordion("Load data file", open = True):
490
  in_files = gr.File(label="Input text from file", file_count="multiple")
491
  with gr.Row():
492
- in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
493
- in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in the output visualisation.")
494
 
495
  with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
496
  candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
@@ -511,41 +554,48 @@ with block:
511
  with gr.Row():
512
  reduce_outliers_btn = gr.Button("Reduce outliers")
513
  represent_llm_btn = gr.Button("Generate topic labels with LLMs")
 
514
 
515
  #logs = gr.Textbox(label="Processing logs.")
516
-
517
-
518
 
519
  with gr.Tab("Visualise"):
520
-
521
- sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisation.")
522
- visualisation_type_radio = gr.Radio(choices=["Topic document graph", "Hierarchical view"])
 
523
  plot_btn = gr.Button("Visualise topic model")
524
- out_plot_file = gr.File(label="Output plots to file", file_count="multiple")
525
- plot = gr.Plot(label="Visualise your topics here. Go to the 'Options' tab to enable.")
 
 
 
 
526
 
527
  with gr.Tab("Options"):
528
  with gr.Accordion("Data load and processing options", open = True):
529
  with gr.Row():
530
  anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Anonymise data on file load. Names and other details are replaced with tags e.g. '<person>'.")
531
  embedding_super_compress = gr.Dropdown(label = "Round embeddings to three dp for smaller files with less accuracy.", value="No", choices=["Yes", "No"])
532
- #create_llm_topic_labels = gr.Dropdown(label = "Create topic labels based on LLMs.", value="No", choices=["Yes", "No"])
 
533
  with gr.Row():
534
  low_resource_mode_opt = gr.Dropdown(label = "Use low resource embeddings and processing.", value="No", choices=["Yes", "No"])
535
- return_intermediate_files = gr.Dropdown(label = "Return intermediate processing files from file preparation. Files can be loaded in to save processing time in future.", value="Yes", choices=["Yes", "No"])
536
  save_topic_model = gr.Dropdown(label = "Save topic model to file.", value="Yes", choices=["Yes", "No"])
537
 
538
  # Update column names dropdown when file uploaded
539
- in_files.upload(fn=put_columns_in_df, inputs=[in_files], outputs=[in_colnames, in_label, data_state, embeddings_state, output_single_text, topic_model_state])
540
  in_colnames.change(dummy_function, in_colnames, None)
541
 
542
- topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, zero_shot_similarity], outputs=[output_single_text, output_file, plot, embeddings_state, data_file_name_no_ext_state, topic_model_state, docs_state, label_list_state], api_name="topics")
 
 
543
 
544
- reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, low_resource_mode_opt], outputs=[output_single_text, output_file, embeddings_state], api_name="reduce_outliers")
545
 
546
- represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, low_resource_mode_opt], outputs=[output_single_text, output_file, embeddings_state], api_name="represent_llm")
547
 
548
- plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, docs_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, label_list_state, sample_slide, visualisation_type_radio], outputs=[output_single_text, out_plot_file, plot], api_name="plot")
549
 
550
  #block.load(read_logs, None, logs, every=5)
551
 
 
1
  import os
2
+
3
+ # Dendrograms will not work with the latest version of scipy (1.12.0), so installing the version prior to be safe
4
+ os.system("pip install scipy==1.11.4")
5
+
6
  import gradio as gr
7
  from datetime import datetime
8
  import pandas as pd
 
11
 
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.feature_extraction.text import CountVectorizer
 
 
14
  from sklearn.pipeline import make_pipeline
15
  from sklearn.decomposition import TruncatedSVD
16
  from sklearn.feature_extraction.text import TfidfVectorizer
 
19
 
20
  from torch import cuda, backends, version
21
 
22
+ # Default seed, can be changed in number selection on options page
23
  random_seed = 42
24
 
25
  # Check for torch cuda
26
+ # If you want to disable cuda for testing purposes
27
+ #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
28
+
29
  print("Is CUDA enabled? ", cuda.is_available())
30
  print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
31
  if cuda.is_available():
 
39
 
40
  print("Device used is: ", torch_device)
41
 
 
42
 
 
 
 
43
 
44
+ from bertopic import BERTopic
45
 
 
46
 
47
  today = datetime.now().strftime("%d%m%Y")
48
  today_rev = datetime.now().strftime("%Y%m%d")
49
 
50
+ from funcs.helper_functions import dummy_function, initial_file_load, read_file, zip_folder, delete_files_in_folder, save_topic_outputs
51
  #from funcs.representation_model import representation_model
52
  from funcs.embeddings import make_or_load_embeddings
53
 
54
  # Log terminal output: https://github.com/gradio-app/gradio/issues/2362
 
55
  import sys
56
 
57
  class Logger:
 
78
  return f.read()
79
 
80
  # Load embeddings
81
+ embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"
82
 
83
+ # Use of Jina deprecated - kept here for posterity
84
  # Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
85
  # Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
 
86
  # local_embeddings_location = "model/jina/"
87
  #revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
88
  #revision_choice = "69d43700292701b06c24f43b96560566a4e5ad1f"
89
 
90
  # Model used for representing topics
91
+ hf_model_name = 'second-state/stablelm-2-zephyr-1.6b-GGUF' #'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1.9-GGUF'
92
+ hf_model_file = 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf' # 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext, custom_labels_df, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, random_seed, calc_probs, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
95
 
96
  progress(0, desc= "Loading data")
97
 
98
+ if calc_probs == "No":
99
+ calc_probs = False
100
+ elif calc_probs == "Yes":
101
+ print("Calculating all probabilities.")
102
+ calc_probs == True
103
+
104
+ if not in_colnames:
105
+ error_message = "Please enter one column name to use to find topics."
106
  print(error_message)
107
+ return error_message, None, embeddings_out, data_file_name_no_ext, None, None
108
 
109
  all_tic = time.perf_counter()
110
 
111
  output_list = []
112
  file_list = [string.name for string in in_files]
113
 
 
 
 
 
114
  in_colnames_list_first = in_colnames[0]
115
 
116
+ docs = list(data[in_colnames_list_first].str.lower())
 
 
 
 
 
 
 
 
117
 
118
  if anonymise_drop == "Yes":
119
  progress(0.1, desc= "Anonymising data")
 
125
  data.to_csv(anonymise_data_name)
126
  output_list.append(anonymise_data_name)
127
 
128
+ print(anonymisation_success)
129
+
130
  anon_toc = time.perf_counter()
131
  time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
132
 
 
 
 
133
  # Check if embeddings are being loaded in
134
  progress(0.2, desc= "Loading/creating embeddings")
135
 
 
137
 
138
  if low_resource_mode == "No":
139
  print("Using high resource BGE transformer model")
 
 
140
 
141
  embedding_model = SentenceTransformer(embeddings_name)
142
+
143
+ # Use of Jina now superseded by BGE, keeping this code just in case I consider reverting one day
144
  #try:
145
  #embedding_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True,device_map="auto") # For Jina
146
  #except:
 
162
 
163
  umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
164
 
165
+ embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
 
 
166
 
167
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
168
+
169
+ # Representation model not currently used in this function
170
+ #print("Create Keybert-like topic representations by default")
171
+ #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
172
+ #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
173
+
174
 
175
  progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
176
 
 
181
  umap_model=umap_model,
182
  min_topic_size = min_docs_slider,
183
  nr_topics = max_topics_slider,
184
+ calculate_probabilities=calc_probs,
185
+ #representation_model=representation_model,
186
  verbose = True)
187
 
188
+ assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
189
 
190
+ #print(assigned_topics)
 
191
 
192
+ # Replace original labels with Keybert labels
193
+ #if "KeyBERT" in topic_model.get_topic_info().columns:
194
+ # keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
195
+ # topic_model.set_topic_labels(keybert_labels)
196
 
197
 
198
  # Do this if you have pre-defined topics
 
201
  error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
202
  print(error_message)
203
 
204
+ return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs
205
 
206
  zero_shot_topics = read_file(candidate_topics.name)
207
  zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
208
 
209
+
210
+
211
  topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, # for Jina
212
  vectorizer_model=vectoriser_model,
213
  umap_model=umap_model,
 
215
  nr_topics = max_topics_slider,
216
  zeroshot_topic_list = zero_shot_topics_lower,
217
  zeroshot_min_similarity = zero_shot_similarity, # 0.7
218
+ calculate_probabilities=calc_probs,
219
+ #representation_model=representation_model,
220
  verbose = True)
221
 
222
+ assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
223
 
224
+ # For some reason, zero topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
225
+ if isinstance(assigned_topics, np.ndarray):
226
+ assigned_topics = assigned_topics.tolist()
227
+ #print(assigned_topics.tolist())
228
 
229
+ # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot, we are merging two models that have exactly the same set of documents, the vocubulary should be the same, and so recreating the cf_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700
 
230
 
231
+ doc_dets = topic_model.get_document_info(docs)
232
+
233
+ documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
234
+
235
+ # Assign CountVectorizer to merged model
236
+
237
+ topic_model.vectorizer_model = vectoriser_model
238
+
239
+ # Re-calculate c-TF-IDF
240
+ c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
241
+ topic_model.c_tf_idf_ = c_tf_idf
242
+
243
+ # Replace original labels with Keybert labels
244
+ #if "KeyBERT" in topic_model.get_topic_info().columns:
245
+ # print(topic_model.get_topics(full=True)["KeyBERT"].values())
246
+ # keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
247
+ # topic_model.set_topic_labels(keybert_labels)
248
+
249
+ if not assigned_topics:
250
+ # Handle the empty array case
251
+ return "No topics found.", output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
252
+
253
+ else:
254
+ print("Topic model created.")
255
+
256
+ if not custom_labels_df.empty:
257
+ #print(custom_labels_df.shape)
258
+
259
+ #topic_dets = topic_model.get_topic_info()
260
+ #print(topic_dets.shape)
261
+
262
+ topic_model.set_topic_labels(list(custom_labels_df.iloc[:,0]))
263
 
264
  # Outputs
265
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
 
283
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
284
  print(time_out)
285
 
286
+ return output_text, output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
287
 
288
+ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):
289
+
290
+ progress(0, desc= "Preparing data")
291
 
292
  output_list = []
293
 
294
  all_tic = time.perf_counter()
295
 
296
+ assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
297
 
298
+ if isinstance(assigned_topics, np.ndarray):
299
+ assigned_topics = assigned_topics.tolist()
300
 
301
  #progress(0.2, desc= "Loading in representation model")
302
  #print("Create LLM topic labels:", create_llm_topic_labels)
303
+ #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
304
  #representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
305
 
306
  # Reduce outliers if required, then update representation
307
  progress(0.2, desc= "Reducing outliers")
308
  print("Reducing outliers.")
309
  # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
310
+ assigned_topics = topic_model.reduce_outliers(docs, assigned_topics, strategy="embeddings")
311
  # Then, update the topics to the ones that considered the new data
312
 
313
  print("Finished reducing outliers.")
314
 
315
+ progress(0.7, desc= "Replacing topic names with LLMs if necessary")
316
+ #print("Create LLM topic labels:", "No")
317
+ #vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
318
+ #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
319
+ #topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
320
 
321
  topic_dets = topic_model.get_topic_info()
322
 
 
328
  topic_model.set_topic_labels(list(topic_dets["Name"]))
329
 
330
  # Outputs
331
+ progress(0.9, desc= "Saving to file")
332
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
333
 
334
  all_toc = time.perf_counter()
335
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
336
  print(time_out)
337
 
338
+ return output_text, output_list, topic_model
339
 
340
+ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, save_topic_model, progress=gr.Progress(track_tqdm=True)):
341
  #from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
342
  from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
343
 
 
347
 
348
  vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
349
 
350
+ assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
351
 
352
  topic_dets = topic_model.get_topic_info()
353
 
354
+ progress(0.1, desc= "Loading LLM model")
355
  print("Create LLM topic labels:", "Yes")
356
  representation_model = create_representation_model("Yes", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
357
 
358
+ topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
359
 
360
  # Replace original labels with LLM labels
361
  if "LLM" in topic_model.get_topic_info().columns:
362
  llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
363
  topic_model.set_topic_labels(llm_labels)
364
 
365
+ label_list_file_name = data_file_name_no_ext + '_llm_topic_list_' + today_rev + '.csv'
366
+
367
+ llm_labels_df = pd.DataFrame(data={"Label":llm_labels})
368
+ llm_labels_df.to_csv(label_list_file_name, index=None)
369
+ #with open(label_list_file_name, 'w') as file:
370
+ # file.write(f"Label\n")
371
+ # for item in llm_labels:
372
+ # file.write(f"{item}\n")
373
+ output_list.append(label_list_file_name)
374
  else:
375
  topic_model.set_topic_labels(list(topic_dets["Name"]))
376
 
377
+ # Outputs
378
+ progress(0.8, desc= "Saving outputs")
 
379
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
380
 
381
  all_toc = time.perf_counter()
382
  time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
383
  print(time_out)
384
 
385
+ return output_text, output_list, topic_model
386
+
387
+ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
388
+
389
+ progress(0, desc= "Preparing data for visualisation")
390
 
 
391
  output_list = []
392
  vis_tic = time.perf_counter()
393
 
394
+ from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, visualize_barchart_custom
395
+
396
+ if not visualisation_type_radio:
397
+ return "Please choose a visualisation type above.", output_list, None, None
398
+
399
+ # Get topic labels
400
+ if in_label:
401
+ in_label_list_first = in_label[0]
402
+ else:
403
+ return "Label column not found. Please enter this above.", output_list, None, None
404
+
405
+ # Get docs
406
+ if in_colnames:
407
+ in_colnames_list_first = in_colnames[0]
408
+ else:
409
+ return "Label column not found. Please enter this on the data load tab.", output_list, None, None
410
+
411
+ docs = list(data[in_colnames_list_first].str.lower())
412
+
413
+ # Make sure format of input series is good
414
+ data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
415
+ label_list = list(data[in_label_list_first])
416
 
417
  topic_dets = topic_model.get_topic_info()
418
 
419
+ # Replace original labels with LLM labels if they exist, or go with the 'Name' column
420
  if "LLM" in topic_model.get_topic_info().columns:
421
  llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
422
  topic_model.set_topic_labels(llm_labels)
 
437
  # "Topic document graph", "Hierarchical view"
438
 
439
  if visualisation_type_radio == "Topic document graph":
440
+ topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop, width= 1200, height = 750)
441
 
442
+ topics_vis_name = data_file_name_no_ext + '_' + 'vis_topic_docs_' + today_rev + '.html'
443
  topics_vis.write_html(topics_vis_name)
444
  output_list.append(topics_vis_name)
445
 
446
+ topics_vis_2 = visualize_barchart_custom(topic_model, top_n_topics = 12, custom_labels=True, width= 300, height = 250)
447
+
448
+ topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_barchart_' + today_rev + '.html'
449
+ topics_vis_2.write_html(topics_vis_2_name)
450
+ output_list.append(topics_vis_2_name)
451
+
452
  elif visualisation_type_radio == "Hierarchical view":
453
+
454
+ # Check that original topics are retained
455
+ #new_topic_dets = topic_model.get_topic_info()
456
+ #new_topic_dets.to_csv("new_topic_dets.csv")
457
+
458
+ #from funcs.bertopic_hierarchical_topics_mod import hierarchical_topics_mod
459
+
460
  hierarchical_topics = topic_model.hierarchical_topics(docs)
461
+
462
+ # Save new hierarchical topic model to file
463
+ hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_' + today_rev + '.csv'
464
+ hierarchical_topics.to_csv(hierarchical_topics_name)
465
+ output_list.append(hierarchical_topics_name)
466
+
467
+ #hierarchical_topics = hierarchical_topics_mod(topic_model, docs)
468
+ topics_vis = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
469
+ #topics_vis = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
470
+ topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
471
 
472
  topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
473
  topics_vis.write_html(topics_vis_name)
 
477
  topics_vis_2.write_html(topics_vis_2_name)
478
  output_list.append(topics_vis_2_name)
479
 
 
 
 
 
 
 
 
 
 
480
  all_toc = time.perf_counter()
481
  time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
482
  print(time_out)
483
 
484
+ return time_out, output_list, topics_vis, topics_vis_2
485
+
486
+ def save_as_pytorch_model(topic_model, data_file_name_no_ext , progress=gr.Progress()):
487
+
488
+ if not topic_model:
489
+ return "No Pytorch model found.", None
490
+
491
+ progress(0, desc= "Saving topic model in Pytorch format")
492
 
 
493
  output_list = []
494
 
495
+
496
  topic_model_save_name_folder = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev# + ".safetensors"
497
  topic_model_save_name_zip = topic_model_save_name_folder + ".zip"
498
 
 
506
  zip_folder(topic_model_save_name_folder, topic_model_save_name_zip)
507
  output_list.append(topic_model_save_name_zip)
508
 
509
+ return "Model saved in Pytorch format.", output_list
510
+
511
  # Gradio app
512
 
513
  block = gr.Blocks(theme = gr.themes.Base())
 
519
  topic_model_state = gr.State()
520
  docs_state = gr.State()
521
  data_file_name_no_ext_state = gr.State()
522
+ label_list_state = gr.State(pd.DataFrame())
523
 
524
  gr.Markdown(
525
  """
 
533
  with gr.Accordion("Load data file", open = True):
534
  in_files = gr.File(label="Input text from file", file_count="multiple")
535
  with gr.Row():
536
+ in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
 
537
 
538
  with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
539
  candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
 
554
  with gr.Row():
555
  reduce_outliers_btn = gr.Button("Reduce outliers")
556
  represent_llm_btn = gr.Button("Generate topic labels with LLMs")
557
+ save_pytorch_btn = gr.Button("Save model in Pytorch format")
558
 
559
  #logs = gr.Textbox(label="Processing logs.")
 
 
560
 
561
  with gr.Tab("Visualise"):
562
+ with gr.Row():
563
+ in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
564
+ visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
565
+ sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
566
  plot_btn = gr.Button("Visualise topic model")
567
+ with gr.Row():
568
+ vis_output_single_text = gr.Textbox(label="Visualisation output text")
569
+ out_plot_file = gr.File(label="Output plots to file", file_count="multiple")
570
+ plot = gr.Plot(label="Visualise your topics here.")
571
+ plot_2 = gr.Plot(label="Visualise your topics here.")
572
+
573
 
574
  with gr.Tab("Options"):
575
  with gr.Accordion("Data load and processing options", open = True):
576
  with gr.Row():
577
  anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Anonymise data on file load. Names and other details are replaced with tags e.g. '<person>'.")
578
  embedding_super_compress = gr.Dropdown(label = "Round embeddings to three dp for smaller files with less accuracy.", value="No", choices=["Yes", "No"])
579
+ seed_number = gr.Number(label="Random seed to use for dimensionality reduction.", minimum=0, step=1, value=42, precision=0)
580
+ calc_probs = gr.Dropdown(label="Calculate all topic probabilities (i.e. a separate document prob. value for each topic)", value="No", choices=["Yes", "No"])
581
  with gr.Row():
582
  low_resource_mode_opt = gr.Dropdown(label = "Use low resource embeddings and processing.", value="No", choices=["Yes", "No"])
583
+ return_intermediate_files = gr.Dropdown(label = "Return intermediate processing files from file preparation.", value="Yes", choices=["Yes", "No"])
584
  save_topic_model = gr.Dropdown(label = "Save topic model to file.", value="Yes", choices=["Yes", "No"])
585
 
586
  # Update column names dropdown when file uploaded
587
+ in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
588
  in_colnames.change(dummy_function, in_colnames, None)
589
 
590
+ topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, anonymise_drop, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, zero_shot_similarity, seed_number, calc_probs], outputs=[output_single_text, output_file, embeddings_state, data_file_name_no_ext_state, topic_model_state, docs_state], api_name="topics")
591
+
592
+ reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
593
 
594
+ represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")
595
 
596
+ save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
597
 
598
+ plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
599
 
600
  #block.load(read_logs, None, logs, every=5)
601
 
funcs/anonymiser.py CHANGED
@@ -1,7 +1,6 @@
1
  from spacy.cli import download
2
  import spacy
3
  spacy.prefer_gpu()
4
- import os
5
 
6
  def spacy_model_installed(model_name):
7
  try:
 
1
  from spacy.cli import download
2
  import spacy
3
  spacy.prefer_gpu()
 
4
 
5
  def spacy_model_installed(model_name):
6
  try:
funcs/bertopic_vis_documents.py CHANGED
@@ -1,10 +1,14 @@
1
  import numpy as np
2
  import pandas as pd
3
  import plotly.graph_objects as go
 
4
 
5
  from umap import UMAP
6
  from typing import List, Union
7
 
 
 
 
8
  # Shamelessly taken and adapted from Bertopic original implementation here (Maarten Grootendorst): https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_documents.py
9
 
10
  def visualize_documents_custom(topic_model,
@@ -243,3 +247,469 @@ def visualize_documents_custom(topic_model,
243
  fig.update_xaxes(visible=False)
244
  fig.update_yaxes(visible=False)
245
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
 
6
  from umap import UMAP
7
  from typing import List, Union
8
 
9
+ import itertools
10
+ import numpy as np
11
+
12
  # Shamelessly taken and adapted from Bertopic original implementation here (Maarten Grootendorst): https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_documents.py
13
 
14
  def visualize_documents_custom(topic_model,
 
247
  fig.update_xaxes(visible=False)
248
  fig.update_yaxes(visible=False)
249
  return fig
250
+
251
+ def visualize_hierarchical_documents_custom(topic_model,
252
+ docs: List[str],
253
+ hover_labels: List[str],
254
+ hierarchical_topics: pd.DataFrame,
255
+ topics: List[int] = None,
256
+ embeddings: np.ndarray = None,
257
+ reduced_embeddings: np.ndarray = None,
258
+ sample: Union[float, int] = None,
259
+ hide_annotations: bool = False,
260
+ hide_document_hover: bool = True,
261
+ nr_levels: int = 10,
262
+ level_scale: str = 'linear',
263
+ custom_labels: Union[bool, str] = False,
264
+ title: str = "<b>Hierarchical Documents and Topics</b>",
265
+ width: int = 1200,
266
+ height: int = 750) -> go.Figure:
267
+ """ Visualize documents and their topics in 2D at different levels of hierarchy
268
+
269
+ Arguments:
270
+ docs: The documents you used when calling either `fit` or `fit_transform`
271
+ hierarchical_topics: A dataframe that contains a hierarchy of topics
272
+ represented by their parents and their children
273
+ topics: A selection of topics to visualize.
274
+ Not to be confused with the topics that you get from `.fit_transform`.
275
+ For example, if you want to visualize only topics 1 through 5:
276
+ `topics = [1, 2, 3, 4, 5]`.
277
+ embeddings: The embeddings of all documents in `docs`.
278
+ reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
279
+ sample: The percentage of documents in each topic that you would like to keep.
280
+ Value can be between 0 and 1. Setting this value to, for example,
281
+ 0.1 (10% of documents in each topic) makes it easier to visualize
282
+ millions of documents as a subset is chosen.
283
+ hide_annotations: Hide the names of the traces on top of each cluster.
284
+ hide_document_hover: Hide the content of the documents when hovering over
285
+ specific points. Helps to speed up generation of visualizations.
286
+ nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
287
+ in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
288
+ Then, for each list of distances, the merged topics are selected that have a
289
+ distance less or equal to the maximum distance of the selected list of distances.
290
+ NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
291
+ the length of `hierarchical_topics`.
292
+ level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance
293
+ vector. Linear scaling will perform an equal number of merges at each level
294
+ while logarithmic scaling will perform more mergers in earlier levels to
295
+ provide more resolution at higher levels (this can be used for when the number
296
+ of topics is large).
297
+ custom_labels: If bool, whether to use custom topic labels that were defined using
298
+ `topic_model.set_topic_labels`.
299
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
300
+ NOTE: Custom labels are only generated for the original
301
+ un-merged topics.
302
+ title: Title of the plot.
303
+ width: The width of the figure.
304
+ height: The height of the figure.
305
+
306
+ Examples:
307
+
308
+ To visualize the topics simply run:
309
+
310
+ ```python
311
+ topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
312
+ ```
313
+
314
+ Do note that this re-calculates the embeddings and reduces them to 2D.
315
+ The advised and prefered pipeline for using this function is as follows:
316
+
317
+ ```python
318
+ from sklearn.datasets import fetch_20newsgroups
319
+ from sentence_transformers import SentenceTransformer
320
+ from bertopic import BERTopic
321
+ from umap import UMAP
322
+
323
+ # Prepare embeddings
324
+ docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
325
+ sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
326
+ embeddings = sentence_model.encode(docs, show_progress_bar=False)
327
+
328
+ # Train BERTopic and extract hierarchical topics
329
+ topic_model = BERTopic().fit(docs, embeddings)
330
+ hierarchical_topics = topic_model.hierarchical_topics(docs)
331
+
332
+ # Reduce dimensionality of embeddings, this step is optional
333
+ # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
334
+
335
+ # Run the visualization with the original embeddings
336
+ topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
337
+
338
+ # Or, if you have reduced the original embeddings already:
339
+ topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
340
+ ```
341
+
342
+ Or if you want to save the resulting figure:
343
+
344
+ ```python
345
+ fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
346
+ fig.write_html("path/to/file.html")
347
+ ```
348
+
349
+ NOTE:
350
+ This visualization was inspired by the scatter plot representation of Doc2Map:
351
+ https://github.com/louisgeisler/Doc2Map
352
+
353
+ <iframe src="../../getting_started/visualization/hierarchical_documents.html"
354
+ style="width:1000px; height: 770px; border: 0px;""></iframe>
355
+ """
356
+ topic_per_doc = topic_model.topics_
357
+
358
+ # Add <br> tags to hover labels to get them to appear on multiple lines
359
+ def wrap_by_word(s, n):
360
+ '''returns a string up to 300 words where \\n is inserted between every n words'''
361
+ a = s.split()[:300]
362
+ ret = ''
363
+ for i in range(0, len(a), n):
364
+ ret += ' '.join(a[i:i+n]) + '<br>'
365
+ return ret
366
+
367
+ # Apply the function to every element in the list
368
+ hover_labels = [wrap_by_word(s, n=20) for s in hover_labels]
369
+
370
+ # Sample the data to optimize for visualization and dimensionality reduction
371
+ if sample is None or sample > 1:
372
+ sample = 1
373
+
374
+ indices = []
375
+ for topic in set(topic_per_doc):
376
+ s = np.where(np.array(topic_per_doc) == topic)[0]
377
+ size = len(s) if len(s) < 100 else int(len(s)*sample)
378
+ indices.extend(np.random.choice(s, size=size, replace=False))
379
+ indices = np.array(indices)
380
+
381
+
382
+
383
+ df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
384
+ df["doc"] = [docs[index] for index in indices]
385
+ df["hover_labels"] = [hover_labels[index] for index in indices]
386
+ df["topic"] = [topic_per_doc[index] for index in indices]
387
+
388
+ # Extract embeddings if not already done
389
+ if sample is None:
390
+ if embeddings is None and reduced_embeddings is None:
391
+ embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
392
+ else:
393
+ embeddings_to_reduce = embeddings
394
+ else:
395
+ if embeddings is not None:
396
+ embeddings_to_reduce = embeddings[indices]
397
+ elif embeddings is None and reduced_embeddings is None:
398
+ embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
399
+
400
+ # Reduce input embeddings
401
+ if reduced_embeddings is None:
402
+ umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
403
+ embeddings_2d = umap_model.embedding_
404
+ elif sample is not None and reduced_embeddings is not None:
405
+ embeddings_2d = reduced_embeddings[indices]
406
+ elif sample is None and reduced_embeddings is not None:
407
+ embeddings_2d = reduced_embeddings
408
+
409
+ # Combine data
410
+ df["x"] = embeddings_2d[:, 0]
411
+ df["y"] = embeddings_2d[:, 1]
412
+
413
+ # Create topic list for each level, levels are created by calculating the distance
414
+ distances = hierarchical_topics.Distance.to_list()
415
+ if level_scale == 'log' or level_scale == 'logarithmic':
416
+ log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
417
+ log_indices.reverse()
418
+ max_distances = [distances[i] for i in log_indices]
419
+ elif level_scale == 'lin' or level_scale == 'linear':
420
+ max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
421
+ else:
422
+ raise ValueError("level_scale needs to be one of 'log' or 'linear'")
423
+
424
+ for index, max_distance in enumerate(max_distances):
425
+
426
+ # Get topics below `max_distance`
427
+ mapping = {topic: topic for topic in df.topic.unique()}
428
+ selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
429
+ selection.Parent_ID = selection.Parent_ID.astype(int)
430
+ selection = selection.sort_values("Parent_ID")
431
+
432
+ for row in selection.iterrows():
433
+ for topic in row[1].Topics:
434
+ mapping[topic] = row[1].Parent_ID
435
+
436
+ # Make sure the mappings are mapped 1:1
437
+ mappings = [True for _ in mapping]
438
+ while any(mappings):
439
+ for i, (key, value) in enumerate(mapping.items()):
440
+ if value in mapping.keys() and key != value:
441
+ mapping[key] = mapping[value]
442
+ else:
443
+ mappings[i] = False
444
+
445
+ # Create new column
446
+ df[f"level_{index+1}"] = df.topic.map(mapping)
447
+ df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
448
+
449
+ # Prepare topic names of original and merged topics
450
+ trace_names = []
451
+ topic_names = {}
452
+ for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
453
+ if topic < hierarchical_topics.Parent_ID.astype(int).min():
454
+ if topic_model.get_topic(topic):
455
+ if isinstance(custom_labels, str):
456
+ trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
457
+ elif topic_model.custom_labels_ is not None and custom_labels:
458
+ trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
459
+ else:
460
+ trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
461
+ topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
462
+ trace_names.append(trace_name)
463
+ else:
464
+ trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
465
+ plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
466
+ topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
467
+ trace_names.append(trace_name)
468
+
469
+ # Prepare traces
470
+ all_traces = []
471
+ for level in range(len(max_distances)):
472
+ traces = []
473
+
474
+ # Outliers
475
+ if topic_model._outliers:
476
+ traces.append(
477
+ go.Scattergl(
478
+ x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
479
+ y=df.loc[df[f"level_{level+1}"] == -1, "y"],
480
+ mode='markers+text',
481
+ name="other",
482
+ hoverinfo="text",
483
+ hovertext=df.loc[(df[f"level_{level+1}"] == -1), "hover_labels"] if not hide_document_hover else None,
484
+ showlegend=False,
485
+ marker=dict(color='#CFD8DC', size=5, opacity=0.5),
486
+ hoverlabel=dict(align='left')
487
+ )
488
+ )
489
+
490
+ # Selected topics
491
+ if topics:
492
+ selection = df.loc[(df.topic.isin(topics)), :]
493
+ unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
494
+ else:
495
+ unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
496
+
497
+ for topic in unique_topics:
498
+ if topic != -1:
499
+ if topics:
500
+ selection = df.loc[(df[f"level_{level+1}"] == topic) &
501
+ (df.topic.isin(topics)), :]
502
+ else:
503
+ selection = df.loc[df[f"level_{level+1}"] == topic, :]
504
+
505
+ if not hide_annotations:
506
+ selection.loc[len(selection), :] = None
507
+ selection["text"] = ""
508
+ selection.loc[len(selection) - 1, "x"] = selection.x.mean()
509
+ selection.loc[len(selection) - 1, "y"] = selection.y.mean()
510
+ selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
511
+
512
+ traces.append(
513
+ go.Scattergl(
514
+ x=selection.x,
515
+ y=selection.y,
516
+ text=selection.text if not hide_annotations else None,
517
+ hovertext=selection.hover_labels if not hide_document_hover else None,
518
+ hoverinfo="text",
519
+ name=topic_names[int(topic)]["trace_name"],
520
+ mode='markers+text',
521
+ marker=dict(size=5, opacity=0.5),
522
+ hoverlabel=dict(align='left')
523
+ )
524
+ )
525
+
526
+ all_traces.append(traces)
527
+
528
+ # Track and count traces
529
+ nr_traces_per_set = [len(traces) for traces in all_traces]
530
+ trace_indices = [(0, nr_traces_per_set[0])]
531
+ for index, nr_traces in enumerate(nr_traces_per_set[1:]):
532
+ start = trace_indices[index][1]
533
+ end = nr_traces + start
534
+ trace_indices.append((start, end))
535
+
536
+ # Visualization
537
+ fig = go.Figure()
538
+ for traces in all_traces:
539
+ for trace in traces:
540
+ fig.add_trace(trace)
541
+
542
+ for index in range(len(fig.data)):
543
+ if index >= nr_traces_per_set[0]:
544
+ fig.data[index].visible = False
545
+
546
+ # Create and add slider
547
+ steps = []
548
+ for index, indices in enumerate(trace_indices):
549
+ step = dict(
550
+ method="update",
551
+ label=str(index),
552
+ args=[{"visible": [False] * len(fig.data)}]
553
+ )
554
+ for index in range(indices[1]-indices[0]):
555
+ step["args"][0]["visible"][index+indices[0]] = True
556
+ steps.append(step)
557
+
558
+ sliders = [dict(
559
+ currentvalue={"prefix": "Level: "},
560
+ pad={"t": 20},
561
+ steps=steps
562
+ )]
563
+
564
+ # Add grid in a 'plus' shape
565
+ x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
566
+ y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
567
+ fig.add_shape(type="line",
568
+ x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
569
+ line=dict(color="#CFD8DC", width=2))
570
+ fig.add_shape(type="line",
571
+ x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
572
+ line=dict(color="#9E9E9E", width=2))
573
+ fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
574
+ fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
575
+
576
+ # Stylize layout
577
+ fig.update_layout(
578
+ sliders=sliders,
579
+ template="simple_white",
580
+ title={
581
+ 'text': f"{title}",
582
+ 'x': 0.5,
583
+ 'xanchor': 'center',
584
+ 'yanchor': 'top',
585
+ 'font': dict(
586
+ size=22,
587
+ color="Black")
588
+ },
589
+ width=width,
590
+ height=height,
591
+ )
592
+
593
+ fig.update_xaxes(visible=False)
594
+ fig.update_yaxes(visible=False)
595
+ return fig
596
+
597
+ def visualize_barchart_custom(topic_model,
598
+ topics: List[int] = None,
599
+ top_n_topics: int = 8,
600
+ n_words: int = 5,
601
+ custom_labels: Union[bool, str] = False,
602
+ title: str = "<b>Topic Word Scores</b>",
603
+ width: int = 250,
604
+ height: int = 250) -> go.Figure:
605
+ """ Visualize a barchart of selected topics
606
+
607
+ Arguments:
608
+ topic_model: A fitted BERTopic instance.
609
+ topics: A selection of topics to visualize.
610
+ top_n_topics: Only select the top n most frequent topics.
611
+ n_words: Number of words to show in a topic
612
+ custom_labels: If bool, whether to use custom topic labels that were defined using
613
+ `topic_model.set_topic_labels`.
614
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
615
+ title: Title of the plot.
616
+ width: The width of each figure.
617
+ height: The height of each figure.
618
+
619
+ Returns:
620
+ fig: A plotly figure
621
+
622
+ Examples:
623
+
624
+ To visualize the barchart of selected topics
625
+ simply run:
626
+
627
+ ```python
628
+ topic_model.visualize_barchart()
629
+ ```
630
+
631
+ Or if you want to save the resulting figure:
632
+
633
+ ```python
634
+ fig = topic_model.visualize_barchart()
635
+ fig.write_html("path/to/file.html")
636
+ ```
637
+ <iframe src="../../getting_started/visualization/bar_chart.html"
638
+ style="width:1100px; height: 660px; border: 0px;""></iframe>
639
+ """
640
+ colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"])
641
+
642
+ # Select topics based on top_n and topics args
643
+ freq_df = topic_model.get_topic_freq()
644
+ freq_df = freq_df.loc[freq_df.Topic != -1, :]
645
+ if topics is not None:
646
+ topics = list(topics)
647
+ elif top_n_topics is not None:
648
+ topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
649
+ else:
650
+ topics = sorted(freq_df.Topic.to_list()[0:6])
651
+
652
+ # Initialize figure
653
+ if isinstance(custom_labels, str):
654
+ subplot_titles = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
655
+ subplot_titles = ["_".join([label[0] for label in labels[:4]]) for labels in subplot_titles]
656
+ subplot_titles = [label if len(label) < 30 else label[:27] + "..." for label in subplot_titles]
657
+ elif topic_model.custom_labels_ is not None and custom_labels:
658
+ subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
659
+ else:
660
+ subplot_titles = [f"Topic {topic}" for topic in topics]
661
+ columns = 4
662
+ rows = int(np.ceil(len(topics) / columns))
663
+ fig = make_subplots(rows=rows,
664
+ cols=columns,
665
+ shared_xaxes=False,
666
+ horizontal_spacing=.1,
667
+ vertical_spacing=.4 / rows if rows > 1 else 0,
668
+ subplot_titles=subplot_titles)
669
+
670
+ # Add barchart for each topic
671
+ row = 1
672
+ column = 1
673
+ for topic in topics:
674
+ words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][::-1]
675
+ scores = [score for _, score in topic_model.get_topic(topic)][:n_words][::-1]
676
+
677
+ fig.add_trace(
678
+ go.Bar(x=scores,
679
+ y=words,
680
+ orientation='h',
681
+ marker_color=next(colors)),
682
+ row=row, col=column)
683
+
684
+ if column == columns:
685
+ column = 1
686
+ row += 1
687
+ else:
688
+ column += 1
689
+
690
+ # Stylize graph
691
+ fig.update_layout(
692
+ template="plotly_white",
693
+ showlegend=False,
694
+ title={
695
+ 'text': f"{title}",
696
+ 'x': .5,
697
+ 'xanchor': 'center',
698
+ 'yanchor': 'top',
699
+ 'font': dict(
700
+ size=16,
701
+ color="Black")
702
+ },
703
+ width=width*4,
704
+ height=height*rows if rows > 1 else height * 1.3,
705
+ hoverlabel=dict(
706
+ bgcolor="white",
707
+ font_size=16,
708
+ font_family="Rockwell"
709
+ ),
710
+ )
711
+
712
+ fig.update_xaxes(showgrid=True)
713
+ fig.update_yaxes(showgrid=True)
714
+
715
+ return fig
funcs/embeddings.py CHANGED
@@ -4,7 +4,6 @@ from torch import cuda
4
  from sklearn.pipeline import make_pipeline
5
  from sklearn.decomposition import TruncatedSVD
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
- from umap import UMAP
8
 
9
  random_seed = 42
10
 
@@ -20,13 +19,14 @@ def make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, em
20
  print("Embeddings not found. Loading or generating new ones.")
21
 
22
  embeddings_file_names = [string.lower() for string in file_list if "embedding" in string.lower()]
23
-
24
  if embeddings_file_names:
 
25
  print("Loading embeddings from file.")
26
- embeddings_out = np.load(embeddings_file_names[0])['arr_0']
27
 
28
  # If embedding files have 'super_compress' in the title, they have been multiplied by 100 before save
29
- if "compress" in embeddings_file_names[0]:
30
  embeddings_out /= 100
31
 
32
  if not embeddings_file_names:
@@ -66,9 +66,9 @@ def make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, em
66
  embeddings_out = np.round(embeddings_out, 3)
67
  embeddings_out *= 100
68
 
69
- return embeddings_out, None
70
 
71
  else:
72
  print("Found pre-loaded embeddings.")
73
 
74
- return embeddings_out, None
 
4
  from sklearn.pipeline import make_pipeline
5
  from sklearn.decomposition import TruncatedSVD
6
  from sklearn.feature_extraction.text import TfidfVectorizer
 
7
 
8
  random_seed = 42
9
 
 
19
  print("Embeddings not found. Loading or generating new ones.")
20
 
21
  embeddings_file_names = [string.lower() for string in file_list if "embedding" in string.lower()]
22
+
23
  if embeddings_file_names:
24
+ embeddings_file_name = embeddings_file_names[0]
25
  print("Loading embeddings from file.")
26
+ embeddings_out = np.load(embeddings_file_name)['arr_0']
27
 
28
  # If embedding files have 'super_compress' in the title, they have been multiplied by 100 before save
29
+ if "compress" in embeddings_file_name:
30
  embeddings_out /= 100
31
 
32
  if not embeddings_file_names:
 
66
  embeddings_out = np.round(embeddings_out, 3)
67
  embeddings_out *= 100
68
 
69
+ return embeddings_out
70
 
71
  else:
72
  print("Found pre-loaded embeddings.")
73
 
74
+ return embeddings_out
funcs/helper_functions.py CHANGED
@@ -6,6 +6,11 @@ import gradio as gr
6
  import gzip
7
  import pickle
8
  import numpy as np
 
 
 
 
 
9
 
10
 
11
  def detect_file_type(filename):
@@ -20,6 +25,8 @@ def detect_file_type(filename):
20
  return 'pkl.gz'
21
  elif filename.endswith('.pkl'):
22
  return 'pkl'
 
 
23
  else:
24
  raise ValueError("Unsupported file type.")
25
 
@@ -30,35 +37,45 @@ def read_file(filename):
30
  print("Loading in file")
31
 
32
  if file_type == 'csv':
33
- file = pd.read_csv(filename, low_memory=False).reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
34
  elif file_type == 'xlsx':
35
- file = pd.read_excel(filename).reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
36
  elif file_type == 'parquet':
37
- file = pd.read_parquet(filename).reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
38
  elif file_type == 'pkl.gz':
39
  with gzip.open(filename, 'rb') as file:
40
  file = pickle.load(file)
41
  #file = pd.read_pickle(filename)
42
  elif file_type == 'pkl':
43
- file = pickle.load(file)
 
 
 
 
 
 
44
 
45
  print("File load complete")
46
 
47
  return file
48
 
49
- def put_columns_in_df(in_file, in_bm25_column):
50
  '''
51
  When file is loaded, update the column dropdown choices and write to relevant data states.
52
  '''
53
  new_choices = []
54
  concat_choices = []
 
 
 
55
 
56
  file_list = [string.name for string in in_file]
57
 
58
- data_file_names = [string.lower() for string in file_list if "npz" not in string.lower() and "pkl" not in string.lower()]
59
  if data_file_names:
60
  data_file_name = data_file_names[0]
61
  df = read_file(data_file_name)
 
62
 
63
  new_choices = list(df.columns)
64
  concat_choices.extend(new_choices)
@@ -72,13 +89,23 @@ def put_columns_in_df(in_file, in_bm25_column):
72
  if model_file_names:
73
  model_file_name = model_file_names[0]
74
  topic_model = read_file(model_file_name)
75
- output_text = "Bertopic model loaded in"
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
-
78
- return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, np.array([]), output_text, topic_model
79
-
80
  #The np.array([]) at the end is for clearing the embedding state when a new file is loaded
81
- return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, np.array([]), output_text, None
82
 
83
  def get_file_path_end(file_path):
84
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
@@ -134,4 +161,51 @@ def delete_files_in_folder(folder_path):
134
  else:
135
  print(f"Skipping {file_path} as it is a directory")
136
  except Exception as e:
137
- print(f"Failed to delete {file_path}. Reason: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gzip
7
  import pickle
8
  import numpy as np
9
+ from bertopic import BERTopic
10
+ from datetime import datetime
11
+
12
+ today = datetime.now().strftime("%d%m%Y")
13
+ today_rev = datetime.now().strftime("%Y%m%d")
14
 
15
 
16
  def detect_file_type(filename):
 
25
  return 'pkl.gz'
26
  elif filename.endswith('.pkl'):
27
  return 'pkl'
28
+ elif filename.endswith('.npz'):
29
+ return 'npz'
30
  else:
31
  raise ValueError("Unsupported file type.")
32
 
 
37
  print("Loading in file")
38
 
39
  if file_type == 'csv':
40
+ file = pd.read_csv(filename, low_memory=False)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
41
  elif file_type == 'xlsx':
42
+ file = pd.read_excel(filename)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
43
  elif file_type == 'parquet':
44
+ file = pd.read_parquet(filename)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
45
  elif file_type == 'pkl.gz':
46
  with gzip.open(filename, 'rb') as file:
47
  file = pickle.load(file)
48
  #file = pd.read_pickle(filename)
49
  elif file_type == 'pkl':
50
+ file = BERTopic.load(filename)
51
+ elif file_type == 'npz':
52
+ file = np.load(filename)['arr_0']
53
+
54
+ # If embedding files have 'super_compress' in the title, they have been multiplied by 100 before save
55
+ if "compress" in filename:
56
+ file /= 100
57
 
58
  print("File load complete")
59
 
60
  return file
61
 
62
+ def initial_file_load(in_file):
63
  '''
64
  When file is loaded, update the column dropdown choices and write to relevant data states.
65
  '''
66
  new_choices = []
67
  concat_choices = []
68
+ custom_labels = pd.DataFrame()
69
+ topic_model = None
70
+ embeddings = np.array([])
71
 
72
  file_list = [string.name for string in in_file]
73
 
74
+ data_file_names = [string.lower() for string in file_list if "npz" not in string.lower() and "pkl" not in string.lower() and "topic_list.csv" not in string.lower()]
75
  if data_file_names:
76
  data_file_name = data_file_names[0]
77
  df = read_file(data_file_name)
78
+ data_file_name_no_ext = get_file_path_end(data_file_name)
79
 
80
  new_choices = list(df.columns)
81
  concat_choices.extend(new_choices)
 
89
  if model_file_names:
90
  model_file_name = model_file_names[0]
91
  topic_model = read_file(model_file_name)
92
+ output_text = "Bertopic model loaded."
93
+
94
+ embedding_file_names = [string.lower() for string in file_list if "npz" in string.lower()]
95
+ if embedding_file_names:
96
+ embedding_file_name = embedding_file_names[0]
97
+ embeddings = read_file(embedding_file_name)
98
+ output_text = "Embeddings loaded."
99
+
100
+ label_file_names = [string.lower() for string in file_list if "topic_list" in string.lower()]
101
+ if label_file_names:
102
+ label_file_name = label_file_names[0]
103
+ custom_labels = read_file(label_file_name)
104
+ output_text = "Labels loaded."
105
+
106
 
 
 
 
107
  #The np.array([]) at the end is for clearing the embedding state when a new file is loaded
108
+ return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, output_text, topic_model, embeddings, data_file_name_no_ext, custom_labels
109
 
110
  def get_file_path_end(file_path):
111
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
 
161
  else:
162
  print(f"Skipping {file_path} as it is a directory")
163
  except Exception as e:
164
+ print(f"Failed to delete {file_path}. Reason: {e}")
165
+
166
+
167
+ def save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model, progress=gr.Progress()):
168
+
169
+ progress(0.7, desc= "Checking data")
170
+
171
+ topic_dets = topic_model.get_topic_info()
172
+
173
+ if topic_dets.shape[0] == 1:
174
+ topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
175
+ topic_dets.to_csv(topic_det_output_name)
176
+ output_list.append(topic_det_output_name)
177
+
178
+ return output_list, "No topics found, original file returned"
179
+
180
+
181
+ progress(0.8, desc= "Saving output")
182
+
183
+ topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
184
+ topic_dets.to_csv(topic_det_output_name)
185
+ output_list.append(topic_det_output_name)
186
+
187
+ doc_det_output_name = "doc_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
188
+ doc_dets = topic_model.get_document_info(docs)[["Document", "Topic", "Name", "Probability", "Representative_document"]]
189
+ doc_dets.to_csv(doc_det_output_name)
190
+ output_list.append(doc_det_output_name)
191
+
192
+ topics_text_out_str = str(topic_dets["Name"])
193
+ output_text = "Topics: " + topics_text_out_str
194
+
195
+ # Save topic model to file
196
+ if save_topic_model == "Yes":
197
+ print("Saving BERTopic model in .pkl format.")
198
+ topic_model_save_name_pkl = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev + ".pkl"# + ".safetensors"
199
+ topic_model_save_name_zip = topic_model_save_name_pkl + ".zip"
200
+
201
+ # Clear folder before replacing files
202
+ #delete_files_in_folder(topic_model_save_name_pkl)
203
+
204
+ topic_model.save(topic_model_save_name_pkl, serialization='pickle', save_embedding_model=False, save_ctfidf=False)
205
+
206
+ # Zip file example
207
+
208
+ #zip_folder(topic_model_save_name_pkl, topic_model_save_name_zip)
209
+ output_list.append(topic_model_save_name_pkl)
210
+
211
+ return output_list, output_text
funcs/representation_model.py CHANGED
@@ -28,7 +28,7 @@ else:
28
  low_resource_mode = "Yes"
29
  n_gpu_layers = 0
30
 
31
- low_resource_mode = "No" # Override for testing
32
 
33
  #print("Running on device:", torch_device)
34
  n_threads = torch.get_num_threads()
 
28
  low_resource_mode = "Yes"
29
  n_gpu_layers = 0
30
 
31
+ #low_resource_mode = "No" # Override for testing
32
 
33
  #print("Running on device:", torch_device)
34
  n_threads = torch.get_num_threads()
requirements.txt CHANGED
@@ -1,11 +1,12 @@
1
  gradio==3.50.0
2
- transformers
3
- accelerate
4
- torch
5
- llama-cpp-python
6
- bertopic
7
- spacy
8
- pyarrow
9
- faker
10
- presidio_analyzer
11
- presidio_anonymizer
 
 
1
  gradio==3.50.0
2
+ transformers==4.37.1
3
+ accelerate==0.26.1
4
+ torch==2.1.2
5
+ llama-cpp-python==0.2.33
6
+ bertopic==0.16.0
7
+ spacy==3.7.2
8
+ pyarrow==14.0.2
9
+ Faker==22.2.0
10
+ presidio_analyzer==2.2.351
11
+ presidio_anonymizer==2.2.351
12
+ scipy==1.11.4