Spaces:
Running
Running
Allowed for loading in external topic labels. A few visualisation modifications.
Browse files- app.py +19 -13
- funcs/bertopic_vis_documents.py +9 -3
app.py
CHANGED
@@ -253,13 +253,16 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
|
|
253 |
else:
|
254 |
print("Topic model created.")
|
255 |
|
|
|
256 |
if not custom_labels_df.empty:
|
257 |
-
#
|
|
|
258 |
|
259 |
-
|
260 |
-
#
|
261 |
|
262 |
-
|
|
|
263 |
|
264 |
# Outputs
|
265 |
output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
|
@@ -384,7 +387,7 @@ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, l
|
|
384 |
|
385 |
return output_text, output_list, topic_model
|
386 |
|
387 |
-
def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
|
388 |
|
389 |
progress(0, desc= "Preparing data for visualisation")
|
390 |
|
@@ -416,12 +419,13 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
|
|
416 |
|
417 |
topic_dets = topic_model.get_topic_info()
|
418 |
|
419 |
-
# Replace original labels with
|
420 |
-
if
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
|
|
425 |
|
426 |
# Pre-reduce embeddings for visualisation purposes
|
427 |
if low_resource_mode == "No":
|
@@ -560,9 +564,11 @@ with block:
|
|
560 |
|
561 |
with gr.Tab("Visualise"):
|
562 |
with gr.Row():
|
563 |
-
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
|
564 |
visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
|
|
|
565 |
sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
|
|
|
|
|
566 |
plot_btn = gr.Button("Visualise topic model")
|
567 |
with gr.Row():
|
568 |
vis_output_single_text = gr.Textbox(label="Visualisation output text")
|
@@ -595,7 +601,7 @@ with block:
|
|
595 |
|
596 |
save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
|
597 |
|
598 |
-
plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
|
599 |
|
600 |
#block.load(read_logs, None, logs, every=5)
|
601 |
|
|
|
253 |
else:
|
254 |
print("Topic model created.")
|
255 |
|
256 |
+
# Replace current topic labels if new ones loaded in
|
257 |
if not custom_labels_df.empty:
|
258 |
+
#custom_label_list = list(custom_labels_df.iloc[:,0])
|
259 |
+
custom_label_list = [label.replace("\n", "") for label in custom_labels_df.iloc[:,0]]
|
260 |
|
261 |
+
topic_model.set_topic_labels(custom_label_list)
|
262 |
+
#topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model)
|
263 |
|
264 |
+
|
265 |
+
print("Custom topics: ", topic_model.custom_labels_)
|
266 |
|
267 |
# Outputs
|
268 |
output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
|
|
|
387 |
|
388 |
return output_text, output_list, topic_model
|
389 |
|
390 |
+
def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, legend_label, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
|
391 |
|
392 |
progress(0, desc= "Preparing data for visualisation")
|
393 |
|
|
|
419 |
|
420 |
topic_dets = topic_model.get_topic_info()
|
421 |
|
422 |
+
# Replace original labels with another representation if specified
|
423 |
+
if legend_label:
|
424 |
+
topic_dets = topic_model.get_topics(full=True)
|
425 |
+
if legend_label in topic_dets:
|
426 |
+
labels = [topic_dets[legend_label].values()]
|
427 |
+
labels = [str(v) for v in labels]
|
428 |
+
topic_model.set_topic_labels(labels)
|
429 |
|
430 |
# Pre-reduce embeddings for visualisation purposes
|
431 |
if low_resource_mode == "No":
|
|
|
564 |
|
565 |
with gr.Tab("Visualise"):
|
566 |
with gr.Row():
|
|
|
567 |
visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
|
568 |
+
in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
|
569 |
sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
|
570 |
+
legend_label = gr.Textbox(label="Custom legend column (optional, any column from the topic details output)", visible=False)
|
571 |
+
|
572 |
plot_btn = gr.Button("Visualise topic model")
|
573 |
with gr.Row():
|
574 |
vis_output_single_text = gr.Textbox(label="Visualisation output text")
|
|
|
601 |
|
602 |
save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
|
603 |
|
604 |
+
plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
|
605 |
|
606 |
#block.load(read_logs, None, logs, every=5)
|
607 |
|
funcs/bertopic_vis_documents.py
CHANGED
@@ -160,10 +160,14 @@ def visualize_documents_custom(topic_model,
|
|
160 |
names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
|
161 |
names = [label if len(label) < 30 else label[:27] + "..." for label in names]
|
162 |
elif topic_model.custom_labels_ is not None and custom_labels:
|
|
|
163 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
164 |
else:
|
|
|
165 |
names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
|
166 |
|
|
|
|
|
167 |
# Visualize
|
168 |
fig = go.Figure()
|
169 |
|
@@ -192,6 +196,8 @@ def visualize_documents_custom(topic_model,
|
|
192 |
|
193 |
# Selected topics
|
194 |
for name, topic in zip(names, unique_topics):
|
|
|
|
|
195 |
if topic in topics and topic != -1:
|
196 |
selection = df.loc[df.topic == topic, :]
|
197 |
selection["text"] = ""
|
@@ -658,7 +664,7 @@ def visualize_barchart_custom(topic_model,
|
|
658 |
subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
|
659 |
else:
|
660 |
subplot_titles = [f"Topic {topic}" for topic in topics]
|
661 |
-
columns =
|
662 |
rows = int(np.ceil(len(topics) / columns))
|
663 |
fig = make_subplots(rows=rows,
|
664 |
cols=columns,
|
@@ -697,14 +703,14 @@ def visualize_barchart_custom(topic_model,
|
|
697 |
'xanchor': 'center',
|
698 |
'yanchor': 'top',
|
699 |
'font': dict(
|
700 |
-
size=
|
701 |
color="Black")
|
702 |
},
|
703 |
width=width*4,
|
704 |
height=height*rows if rows > 1 else height * 1.3,
|
705 |
hoverlabel=dict(
|
706 |
bgcolor="white",
|
707 |
-
font_size=
|
708 |
font_family="Rockwell"
|
709 |
),
|
710 |
)
|
|
|
160 |
names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
|
161 |
names = [label if len(label) < 30 else label[:27] + "..." for label in names]
|
162 |
elif topic_model.custom_labels_ is not None and custom_labels:
|
163 |
+
print("Using custom labels: ", topic_model.custom_labels_)
|
164 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
165 |
else:
|
166 |
+
print("Not using custom labels")
|
167 |
names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
|
168 |
|
169 |
+
print(names)
|
170 |
+
|
171 |
# Visualize
|
172 |
fig = go.Figure()
|
173 |
|
|
|
196 |
|
197 |
# Selected topics
|
198 |
for name, topic in zip(names, unique_topics):
|
199 |
+
#print(name)
|
200 |
+
#print(topic)
|
201 |
if topic in topics and topic != -1:
|
202 |
selection = df.loc[df.topic == topic, :]
|
203 |
selection["text"] = ""
|
|
|
664 |
subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
|
665 |
else:
|
666 |
subplot_titles = [f"Topic {topic}" for topic in topics]
|
667 |
+
columns = 3
|
668 |
rows = int(np.ceil(len(topics) / columns))
|
669 |
fig = make_subplots(rows=rows,
|
670 |
cols=columns,
|
|
|
703 |
'xanchor': 'center',
|
704 |
'yanchor': 'top',
|
705 |
'font': dict(
|
706 |
+
size=14,
|
707 |
color="Black")
|
708 |
},
|
709 |
width=width*4,
|
710 |
height=height*rows if rows > 1 else height * 1.3,
|
711 |
hoverlabel=dict(
|
712 |
bgcolor="white",
|
713 |
+
font_size=14,
|
714 |
font_family="Rockwell"
|
715 |
),
|
716 |
)
|