import itertools
import numpy as np
from typing import List

import plotly.graph_objects as go
from plotly.subplots import make_subplots


def visualize_barchart(topic_model,
                       topics: List[int] = None,
                       top_n_topics: int = 8,
                       n_words: int = 5,
                       custom_labels: bool = False,
                       title: str = "Kata Kunci tiap Topic",
                       width: int = 250,
                       height: int = 250) -> go.Figure:
    """Visualize a barchart of the top words for selected topics.

    Each selected topic gets its own horizontal bar subplot on a grid with
    three columns, filled row by row. Each subplot shows the topic's first
    ``n_words`` (word, score) pairs from ``topic_model.get_topic``, drawn so
    that the first pair appears at the top.

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics to visualize. Takes precedence over
                ``top_n_topics`` when given.
        top_n_topics: Only select the top n most frequent topics. If both
                      this and ``topics`` are None, the first 6 topics listed
                      by ``topic_model.get_topic_freq()`` are used.
                      Topic -1 is excluded from automatic selection unless it
                      is the only row in the frequency table.
        n_words: Number of words to show in a topic.
        custom_labels: Whether to use custom topic labels that were defined
                       using `topic_model.set_topic_labels`.
        title: Title of the plot.
        width: The width of each subplot column; the whole figure is
               ``width * 3`` wide.
        height: The height of each subplot row; the whole figure is
                ``height * rows`` tall (``height * 1.3`` for a single row).

    Returns:
        fig: A plotly figure

    Raises:
        ValueError: If the selection results in no topics to visualize.

    Examples:

    To visualize the barchart of selected topics
    simply run:

    ```python
    topic_model.visualize_barchart()
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_barchart()
    fig.write_html("path/to/file.html")
    ```
    """
    colors = itertools.cycle(['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                              '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'])

    # Select topics based on the `topics` and `top_n_topics` arguments.
    freq_df = topic_model.get_topic_freq()
    if len(freq_df) > 1:
        # Drop topic -1 unless it is the only row available.
        freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list()[0:6])

    if not topics:
        # An empty selection would yield a zero-row grid, which make_subplots
        # rejects with an unhelpful message; fail with a clear one instead.
        raise ValueError("No topics selected to visualize: pass a non-empty "
                         "`topics` list or a positive `top_n_topics`.")

    # Subplot titles. Custom labels are looked up at `topic + _outliers`;
    # NOTE(review): the offset presumably accounts for the -1 outlier topic
    # occupying index 0 of `custom_labels_` -- confirm against BERTopic.
    if topic_model.custom_labels_ is not None and custom_labels:
        subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers]
                          for topic in topics]
    else:
        subplot_titles = [f"Topic {topic}" for topic in topics]

    # Initialize the subplot grid.
    columns = 3
    rows = int(np.ceil(len(topics) / columns))
    fig = make_subplots(rows=rows,
                        cols=columns,
                        shared_xaxes=False,
                        horizontal_spacing=.1,
                        vertical_spacing=.4 / rows if rows > 1 else 0,
                        subplot_titles=subplot_titles)

    # Add one horizontal barchart per topic, filling the grid row by row.
    for index, topic in enumerate(topics):
        row, column = divmod(index, columns)
        # Fetch the topic once; reverse so the first (word, score) pair is
        # drawn at the top of the horizontal bar chart.
        top_words = list(topic_model.get_topic(topic))[:n_words][::-1]
        # Trailing space presumably pads the tick labels away from the bars.
        words = [word + " " for word, _ in top_words]
        scores = [score for _, score in top_words]
        fig.add_trace(
            go.Bar(x=scores, y=words, orientation='h', marker_color=next(colors)),
            row=row + 1, col=column + 1)

    # Stylize graph
    fig.update_layout(
        showlegend=False,
        title={
            'text': f"{title}",
            # 'xanchor': 'center' only centers the title when x is 0.5;
            # otherwise the template's default title.x (0.05 in plotly's
            # default template) pins the title's center near the left edge.
            'x': .5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width * columns,
        height=height * rows if rows > 1 else height * 1.3,
        hoverlabel=dict(
            bgcolor="white",
            font_size=13,
            font_family="Rockwell"
        ),
        margin=dict(l=40, r=40)
    )

    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)

    return fig