| import numpy as np |
| import pandas as pd |
| from typing import Callable, List, Union |
| from scipy.sparse import csr_matrix |
| from scipy.cluster import hierarchy as sch |
| from scipy.spatial.distance import squareform |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| import plotly.graph_objects as go |
| import plotly.figure_factory as ff |
|
|
| from bertopic._utils import validate_distance_matrix |
|
|
def visualize_hierarchy(topic_model,
                        orientation: str = "left",
                        topics: List[int] = None,
                        top_n_topics: int = None,
                        custom_labels: Union[bool, str] = False,
                        title: str = "<b>Hierarchical Clustering</b>",
                        width: int = 1000,
                        height: int = 600,
                        hierarchical_topics: pd.DataFrame = None,
                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
                        color_threshold: int = 1) -> go.Figure:
    """ Visualize a hierarchical structure of the topics

    The topics are clustered hierarchically on their c-TF-IDF representations
    (or on `topic_model.topic_embeddings_` when no c-TF-IDF matrix is available).
    By default a ward linkage is applied to the cosine distance matrix between
    those representations; both steps can be replaced through `linkage_function`
    and `distance_function`.

    Arguments:
        topic_model: A fitted BERTopic instance.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        topics: A selection of topics to visualize. Every topic must be known
                to `topic_model`, otherwise a ValueError is raised.
        top_n_topics: Only select the top n most frequent topics
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.
        title: Title of the plot.
        width: The width of the figure. Only works if orientation is set to 'left'
        height: The height of the figure. Only works if orientation is set to 'bottom'
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             NOTE: The hierarchical topic names are only visualized
                             when all (non-outlier) topics are shown, which is the
                             case when neither `topics` nor `top_n_topics` is set.
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                            You can pass any function that returns either a square matrix of
                            shape (n_samples, n_samples) with zeros on the diagonal and
                            non-negative values or condensed distance matrix of shape
                            (n_samples * (n_samples - 1) / 2,) containing the upper
                            triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        color_threshold: Value at which the separation of clusters will be made which
                         will result in different colors for different clusters.
                         A higher value will typically lead to fewer colored clusters.

    Returns:
        fig: A plotly figure

    Examples:

    To visualize the hierarchical structure of
    topics simply run:

    ```python
    topic_model.visualize_hierarchy()
    ```

    If you also want the labels visualized of hierarchical topics,
    run the following:

    ```python
    # Extract hierarchical topics and their representations
    hierarchical_topics = topic_model.hierarchical_topics(docs)

    # Visualize these representations
    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
    ```

    If you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_hierarchy()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/hierarchy.html"
    style="width:1000px; height: 680px; border: 0px;"></iframe>
    """
    if distance_function is None:
        distance_function = lambda x: 1 - cosine_similarity(x)
    if linkage_function is None:
        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)

    # Select topics based on the `topics` and `top_n_topics` arguments (outlier topic -1 excluded)
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Select the representation of each topic; row i of `embeddings` belongs to topics[i]
    all_topics = sorted(list(topic_model.get_topics().keys()))
    indices = np.array([all_topics.index(topic) for topic in topics])
    if topic_model.c_tf_idf_ is not None:
        embeddings = topic_model.c_tf_idf_[indices]
    else:
        embeddings = np.array(topic_model.topic_embeddings_)[indices]

    # Hover annotations can only be matched against `hierarchical_topics`
    # when every non-outlier topic takes part in the clustering
    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
        annotations = _get_annotations(topic_model=topic_model,
                                       hierarchical_topics=hierarchical_topics,
                                       embeddings=embeddings,
                                       distance_function=distance_function,
                                       linkage_function=linkage_function,
                                       orientation=orientation,
                                       custom_labels=custom_labels)
    else:
        annotations = None

    # Wrap the distance function so plotly receives a validated distance matrix
    distance_function_viz = lambda x: validate_distance_matrix(
        distance_function(x), embeddings.shape[0])

    fig = ff.create_dendrogram(embeddings,
                               orientation=orientation,
                               distfun=distance_function_viz,
                               linkagefun=linkage_function,
                               hovertext=annotations,
                               color_threshold=color_threshold)

    def _short_label(topic, words):
        """Join the topic id with the first three words, shortened to at most 30 characters."""
        label = "_".join([str(topic)] + [word[0] for word in words[:3]])
        return label if len(label) < 30 else label[:27] + "..."

    # The dendrogram's tick texts are row indices into `embeddings` (as strings);
    # map them back to the topic ids they represent before building labels
    label_axis = "yaxis" if orientation == "left" else "xaxis"
    leaf_topics = [topics[int(x)] for x in fig.layout[label_axis]["ticktext"]]
    if isinstance(custom_labels, str):
        aspects = topic_model.topic_aspects_[custom_labels]
        new_labels = [_short_label(topic, aspects[topic]) for topic in leaf_topics]
    elif topic_model.custom_labels_ is not None and custom_labels:
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers]
                      for topic in leaf_topics]
    else:
        new_labels = [_short_label(topic, topic_model.get_topic(topic)) for topic in leaf_topics]

    # Stylize layout
    fig.update_layout(
        plot_bgcolor='#ECEFF1',
        template="plotly_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    # Size the figure to the number of topics and apply the readable tick labels
    if orientation == "left":
        fig.update_layout(height=200 + (15 * len(topics)),
                          width=width,
                          yaxis=dict(tickmode="array",
                                     ticktext=new_labels))

        # Pad the y-range by 5 units beyond the plotted traces on both sides
        y_max = max([trace['y'].max() + 5 for trace in fig['data']])
        y_min = min([trace['y'].min() - 5 for trace in fig['data']])
        fig.update_layout(yaxis=dict(range=[y_min, y_max]))
    else:
        fig.update_layout(width=200 + (15 * len(topics)),
                          height=height,
                          xaxis=dict(tickmode="array",
                                     ticktext=new_labels))

    # Place black markers on the inner nodes so the hierarchical topic names
    # (stored as trace text by `_get_annotations`) can be hovered
    if annotations is not None:
        value_axis = "x" if orientation == "left" else "y"
        for index in (0, 3):
            # Only dendrogram traces carry text; endpoints at value 0 are the leaves
            nodes = [data for data in fig.data if (data["text"] and data[value_axis][index] > 0)]
            fig.add_trace(go.Scatter(x=[data["x"][index] for data in nodes],
                                     y=[data["y"][index] for data in nodes],
                                     marker_color='black',
                                     hovertext=[data["text"][index] for data in nodes],
                                     hoverinfo="text",
                                     mode='markers', showlegend=False))
    return fig
|
|
|
|
def _get_annotations(topic_model,
                     hierarchical_topics: pd.DataFrame,
                     embeddings: csr_matrix,
                     linkage_function: Callable[[csr_matrix], np.ndarray],
                     distance_function: Callable[[csr_matrix], csr_matrix],
                     orientation: str,
                     custom_labels: Union[bool, str] = False) -> List[List[str]]:
    """ Get hover annotations by replicating the dendrogram's linkage calculation in scipy

    Arguments:
        topic_model: A fitted BERTopic instance.
        hierarchical_topics: A dataframe that contains a hierarchy of topics
                             represented by their parents and their children.
                             Its `Topics` column lists the leaf topics under each
                             parent, which is used to name the merged clusters.
        embeddings: The c-TF-IDF matrix (or topic embeddings) on which to model the hierarchy.
                    NOTE(review): leaf indices are used directly as topic ids, which
                    assumes row i of `embeddings` is topic i (the caller passes all
                    non-outlier topics in sorted order) — confirm for custom setups.
        linkage_function: The linkage function to use. Default is:
                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
                          NOTE: Make sure to use the same `linkage_function` as used
                          in `topic_model.hierarchical_topics`.
        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
                           `lambda x: 1 - cosine_similarity(x)`.
                            You can pass any function that returns either a square matrix of
                            shape (n_samples, n_samples) with zeros on the diagonal and
                            non-negative values or condensed distance matrix of shape
                            (n_samples * (n_samples - 1) / 2,) containing the upper
                            triangular of the distance matrix.
                           NOTE: Make sure to use the same `distance_function` as used
                           in `topic_model.hierarchical_topics`.
        orientation: The orientation of the figure.
                     Either 'left' or 'bottom'
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from that aspect in `topic_model.topic_aspects_`.
                       NOTE: Custom labels are only generated for the original
                       un-merged topics.

    Returns:
        text_annotations: One `[first_child_name, "", "", second_child_name]` entry per
                          merge, in the order of scipy's `icoord`, to be passed as
                          `hovertext` to Plotly's `ff.create_dendrogram`.
                          A merged cluster that has no matching row in
                          `hierarchical_topics` gets an empty name.
    """
    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]

    # Recompute the clustering so the merge order matches the plotted dendrogram
    X = distance_function(embeddings)
    X = validate_distance_matrix(X, embeddings.shape[0])
    Z = linkage_function(X)
    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)

    # Leaf positions along the dendrogram's index axis (5, 15, 25, ...) mapped to
    # the list of leaf topics found at that position
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    topic_vals = {position: [leaf] for leaf, position in zip(P['leaves'], x_ticks)}

    # Name of every merged cluster, keyed by the set of leaf topics it contains.
    # Later rows overwrite earlier ones for identical sets.
    parent_names = {frozenset(members): name
                    for members, name in zip(df.Topics, df.Parent_Name)}

    def _leaf_name(topic) -> str:
        """Label a single (un-merged) topic."""
        if isinstance(custom_labels, str):
            aspect = topic_model.topic_aspects_[custom_labels][topic]
            return f"{topic}_" + "_".join([entry[0] for entry in aspect[:3]])
        if topic_model.custom_labels_ is not None and custom_labels:
            return topic_model.custom_labels_[topic + topic_model._outliers]
        return "_".join([word for word, _ in topic_model.get_topic(topic)][:5])

    def _node_name(members) -> str:
        """Label a dendrogram node: a leaf topic or a merged cluster of topics."""
        if len(members) == 1:
            return _leaf_name(members[0])
        # Fall back to an empty name rather than reusing a stale one when the
        # hierarchy does not contain this cluster (e.g. mismatched linkage)
        return parent_names.get(frozenset(members), "")

    text_annotations = []
    for trace in P['icoord']:
        # trace[0] and trace[2] are the positions of the two merged children
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]
        text_annotations.append([_node_name(fst_topic), "", "", _node_name(scnd_topic)])

        # The merged node sits halfway between its children; later merges refer to it by that position
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = fst_topic + scnd_topic

    return text_annotations
|
|