Sonnyjim commited on
Commit
b27bab2
1 Parent(s): 356791c

Allowed for loading in external topic labels. A few visualisation modifications.

Browse files
Files changed (2) hide show
  1. app.py +19 -13
  2. funcs/bertopic_vis_documents.py +9 -3
app.py CHANGED
@@ -253,13 +253,16 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
253
  else:
254
  print("Topic model created.")
255
 
 
256
  if not custom_labels_df.empty:
257
- #print(custom_labels_df.shape)
 
258
 
259
- #topic_dets = topic_model.get_topic_info()
260
- #print(topic_dets.shape)
261
 
262
- topic_model.set_topic_labels(list(custom_labels_df.iloc[:,0]))
 
263
 
264
  # Outputs
265
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
@@ -384,7 +387,7 @@ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, l
384
 
385
  return output_text, output_list, topic_model
386
 
387
- def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
388
 
389
  progress(0, desc= "Preparing data for visualisation")
390
 
@@ -416,12 +419,13 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
416
 
417
  topic_dets = topic_model.get_topic_info()
418
 
419
- # Replace original labels with LLM labels if they exist, or go with the 'Name' column
420
- if "LLM" in topic_model.get_topic_info().columns:
421
- llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
422
- topic_model.set_topic_labels(llm_labels)
423
- else:
424
- topic_model.set_topic_labels(list(topic_dets["Name"]))
 
425
 
426
  # Pre-reduce embeddings for visualisation purposes
427
  if low_resource_mode == "No":
@@ -560,9 +564,11 @@ with block:
560
 
561
  with gr.Tab("Visualise"):
562
  with gr.Row():
563
- in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
564
  visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
 
565
  sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
 
 
566
  plot_btn = gr.Button("Visualise topic model")
567
  with gr.Row():
568
  vis_output_single_text = gr.Textbox(label="Visualisation output text")
@@ -595,7 +601,7 @@ with block:
595
 
596
  save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
597
 
598
- plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
599
 
600
  #block.load(read_logs, None, logs, every=5)
601
 
 
253
  else:
254
  print("Topic model created.")
255
 
256
+ # Replace current topic labels if new ones loaded in
257
  if not custom_labels_df.empty:
258
+ #custom_label_list = list(custom_labels_df.iloc[:,0])
259
+ custom_label_list = [label.replace("\n", "") for label in custom_labels_df.iloc[:,0]]
260
 
261
+ topic_model.set_topic_labels(custom_label_list)
262
+ #topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model)
263
 
264
+
265
+ print("Custom topics: ", topic_model.custom_labels_)
266
 
267
  # Outputs
268
  output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
 
387
 
388
  return output_text, output_list, topic_model
389
 
390
+ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, legend_label, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
391
 
392
  progress(0, desc= "Preparing data for visualisation")
393
 
 
419
 
420
  topic_dets = topic_model.get_topic_info()
421
 
422
+ # Replace original labels with another representation if specified
423
+ if legend_label:
424
+ topic_dets = topic_model.get_topics(full=True)
425
+ if legend_label in topic_dets:
426
+ labels = [topic_dets[legend_label].values()]
427
+ labels = [str(v) for v in labels]
428
+ topic_model.set_topic_labels(labels)
429
 
430
  # Pre-reduce embeddings for visualisation purposes
431
  if low_resource_mode == "No":
 
564
 
565
  with gr.Tab("Visualise"):
566
  with gr.Row():
 
567
  visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
568
+ in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
569
  sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
570
+ legend_label = gr.Textbox(label="Custom legend column (optional, any column from the topic details output)", visible=False)
571
+
572
  plot_btn = gr.Button("Visualise topic model")
573
  with gr.Row():
574
  vis_output_single_text = gr.Textbox(label="Visualisation output text")
 
601
 
602
  save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
603
 
604
+ plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
605
 
606
  #block.load(read_logs, None, logs, every=5)
607
 
funcs/bertopic_vis_documents.py CHANGED
@@ -160,10 +160,14 @@ def visualize_documents_custom(topic_model,
160
  names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
161
  names = [label if len(label) < 30 else label[:27] + "..." for label in names]
162
  elif topic_model.custom_labels_ is not None and custom_labels:
 
163
  names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
164
  else:
 
165
  names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
166
 
 
 
167
  # Visualize
168
  fig = go.Figure()
169
 
@@ -192,6 +196,8 @@ def visualize_documents_custom(topic_model,
192
 
193
  # Selected topics
194
  for name, topic in zip(names, unique_topics):
 
 
195
  if topic in topics and topic != -1:
196
  selection = df.loc[df.topic == topic, :]
197
  selection["text"] = ""
@@ -658,7 +664,7 @@ def visualize_barchart_custom(topic_model,
658
  subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
659
  else:
660
  subplot_titles = [f"Topic {topic}" for topic in topics]
661
- columns = 4
662
  rows = int(np.ceil(len(topics) / columns))
663
  fig = make_subplots(rows=rows,
664
  cols=columns,
@@ -697,14 +703,14 @@ def visualize_barchart_custom(topic_model,
697
  'xanchor': 'center',
698
  'yanchor': 'top',
699
  'font': dict(
700
- size=16,
701
  color="Black")
702
  },
703
  width=width*4,
704
  height=height*rows if rows > 1 else height * 1.3,
705
  hoverlabel=dict(
706
  bgcolor="white",
707
- font_size=16,
708
  font_family="Rockwell"
709
  ),
710
  )
 
160
  names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
161
  names = [label if len(label) < 30 else label[:27] + "..." for label in names]
162
  elif topic_model.custom_labels_ is not None and custom_labels:
163
+ print("Using custom labels: ", topic_model.custom_labels_)
164
  names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
165
  else:
166
+ print("Not using custom labels")
167
  names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
168
 
169
+ print(names)
170
+
171
  # Visualize
172
  fig = go.Figure()
173
 
 
196
 
197
  # Selected topics
198
  for name, topic in zip(names, unique_topics):
199
+ #print(name)
200
+ #print(topic)
201
  if topic in topics and topic != -1:
202
  selection = df.loc[df.topic == topic, :]
203
  selection["text"] = ""
 
664
  subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
665
  else:
666
  subplot_titles = [f"Topic {topic}" for topic in topics]
667
+ columns = 3
668
  rows = int(np.ceil(len(topics) / columns))
669
  fig = make_subplots(rows=rows,
670
  cols=columns,
 
703
  'xanchor': 'center',
704
  'yanchor': 'top',
705
  'font': dict(
706
+ size=14,
707
  color="Black")
708
  },
709
  width=width*4,
710
  height=height*rows if rows > 1 else height * 1.3,
711
  hoverlabel=dict(
712
  bgcolor="white",
713
+ font_size=14,
714
  font_family="Rockwell"
715
  ),
716
  )