import itertools
import numpy as np
from typing import List
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def visualize_barchart(topic_model,
                       topics: List[int] = None,
                       top_n_topics: int = 8,
                       n_words: int = 5,
                       custom_labels: bool = False,
                       title: str = "Kata Kunci tiap Topic",
                       width: int = 250,
                       height: int = 250,
                       columns: int = 3) -> go.Figure:
    """Visualize a barchart of the top words of selected topics.

    Each selected topic gets its own horizontal bar-chart subplot showing its
    ``n_words`` highest-scoring words. Subplots are laid out in a grid that is
    ``columns`` subplots wide.

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics to visualize. They are shown in the
                order given (not re-sorted).
        top_n_topics: Only select the top n most frequent topics. Ignored
                      when ``topics`` is given. The selected topic ids are
                      sorted ascending. If both ``topics`` and
                      ``top_n_topics`` are None, the first 6 topics from
                      ``topic_model.get_topic_freq()`` are shown.
        n_words: Number of words to show in a topic.
        custom_labels: Whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
        title: Title of the plot.
        width: The width of each subplot; the figure is ``width * columns`` wide.
        height: The height of each subplot row.
        columns: Number of subplots per grid row (default 3).

    Returns:
        fig: A plotly figure

    Raises:
        ValueError: If ``columns`` is smaller than 1 or no topics end up
                    selected for plotting.

    Examples:
        To visualize the barchart of selected topics
        simply run:
        ```python
        topic_model.visualize_barchart()
        ```
        Or if you want to save the resulting figure:
        ```python
        fig = topic_model.visualize_barchart()
        fig.write_html("path/to/file.html")
        ```
    """
    if columns < 1:
        raise ValueError(f"`columns` must be at least 1, got {columns}.")

    # One color per subplot, repeating once the palette is exhausted.
    colors = itertools.cycle(['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                              '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'])

    # Select topics based on top_n and topics args
    freq_df = topic_model.get_topic_freq()
    if len(freq_df) > 1:
        # Drop topic -1 (the outlier topic in BERTopic) unless it is the only one.
        freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list()[0:6])

    # make_subplots cannot build a 0-row grid; fail early with a clear message.
    if not topics:
        raise ValueError("No topics selected to visualize.")

    # Subplot titles: custom labels only when requested and available.
    if custom_labels and topic_model.custom_labels_ is not None:
        # Labels are offset by `_outliers`; presumably 1 when the model has a
        # -1 topic occupying label index 0 -- verify against BERTopic.
        subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers]
                          for topic in topics]
    else:
        subplot_titles = [f"Topic {topic}" for topic in topics]

    # Initialize figure
    rows = int(np.ceil(len(topics) / columns))
    fig = make_subplots(rows=rows,
                        cols=columns,
                        shared_xaxes=False,
                        horizontal_spacing=.1,
                        vertical_spacing=.4 / rows if rows > 1 else 0,
                        subplot_titles=subplot_titles)

    # Add barchart for each topic, filling the grid row by row.
    for idx, topic in enumerate(topics):
        row, column = divmod(idx, columns)
        # Keep the top n_words (word, score) pairs. Reverse them so that in a
        # horizontal bar chart (drawn bottom-up) the best word is on top.
        word_scores = list(topic_model.get_topic(topic))[:n_words][::-1]
        # Trailing space pads the tick label away from the bars.
        words = [word + " " for word, _ in word_scores]
        scores = [score for _, score in word_scores]
        fig.add_trace(
            go.Bar(x=scores,
                   y=words,
                   orientation='h',
                   marker_color=next(colors)),
            row=row + 1, col=column + 1)

    # Stylize graph
    fig.update_layout(
        showlegend=False,
        title={
            'text': f"{title}",
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width * columns,
        # A single row gets extra height to leave room for the figure title.
        height=height * rows if rows > 1 else height * 1.3,
        hoverlabel=dict(
            bgcolor="white",
            font_size=13,
            font_family="Rockwell"
        ),
        margin=dict(l=40, r=40)
    )
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)
    return fig