BERTinsights / helper.py
Yara Kyrychenko
Helper
17e20d0
from river import stream
from river import cluster
class River:
    """Adapter exposing a river online-clustering model through ``partial_fit``.

    river estimators learn and predict one sample at a time. This wrapper
    streams each row of a batch through the wrapped model and stores the
    per-row cluster predictions in ``labels_``. It is presumably used as the
    clustering model for online topic modelling; confirm against the caller.
    """

    def __init__(self, model):
        # Assumed to be a river clustering estimator (e.g. from
        # ``river.cluster``) providing ``learn_one`` and ``predict_one``.
        self.model = model

    def partial_fit(self, umap_embeddings):
        """Update the wrapped model with a batch of embeddings, then label it.

        Args:
            umap_embeddings: Array-like of reduced embeddings, one row per
                document (iterated with ``river.stream.iter_array``).

        Returns:
            ``self``, with ``labels_`` set to a list holding the predicted
            cluster of each row, in input order.
        """
        # river estimators update themselves in place. Since river 0.19,
        # ``learn_one`` returns None, so its result must NOT be assigned back
        # to ``self.model`` (that would discard the model after one sample).
        for umap_embedding, _ in stream.iter_array(umap_embeddings):
            self.model.learn_one(umap_embedding)
        # Predict only after the whole batch has been learned, so every row
        # is labelled by the same model state.
        self.labels_ = [
            self.model.predict_one(umap_embedding)
            for umap_embedding, _ in stream.iter_array(umap_embeddings)
        ]
        return self
import pandas as pd
from typing import List
import plotly.graph_objects as go
from sklearn.preprocessing import normalize
def visualize_topics_over_time(topic_model,
                               topics_over_time: pd.DataFrame,
                               top_n_topics: int = None,
                               topics: List[int] = None,
                               normalize_frequency: bool = False,
                               custom_labels: bool = False,
                               title: str = "<b>Topics over Time</b>",
                               width: int = 860,
                               height: int = 600) -> go.Figure:
    """Plot each topic's frequency over time as a separate line.

    Based on BERTopic's function
    https://github.com/MaartenGr/BERTopic/blob/809414b88ca3f12a46728069d098d82345986489/bertopic/plotting/_topics_over_time.py

    Arguments:
        topic_model: A fitted topic model. This function reads its
            ``get_topic_freq()``, ``topic_labels_``, ``custom_labels_`` and
            ``_outliers``.
        topics_over_time: DataFrame with ``Topic``, ``Timestamp``,
            ``Frequency`` and ``Words`` columns. It is not modified.
        top_n_topics: Plot only the first ``top_n_topics`` topics in the order
            returned by ``get_topic_freq()`` (presumably most frequent first),
            excluding outlier topic -1. Ignored when ``topics`` is given.
        topics: Explicit topic ids to plot.
        normalize_frequency: L2-normalize each topic's frequency series.
        custom_labels: Name traces with ``topic_model.custom_labels_`` when set.
        title: Figure title (may contain HTML such as ``<b>``).
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A plotly ``go.Figure`` with one line trace per selected topic.
    """
    # Select topics based on top_n and topics args; -1 is the outlier topic.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        selected_topics = sorted(freq_df.Topic.to_list())

    # Prepare data
    if topic_model.custom_labels_ is not None and custom_labels:
        # custom_labels_ is indexed by position, shifted by ``_outliers``
        # (presumably 1 when an outlier topic -1 exists).
        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers]
                       for key in topic_model.topic_labels_}
    else:
        # Truncate long labels so the legend stays readable.
        topic_names = {key: value[:30] + "..." if len(value) > 30 else value
                       for key, value in topic_model.topic_labels_.items()}
    # ``assign`` returns a new frame, so the caller's DataFrame is not mutated.
    named = topics_over_time.assign(Name=topics_over_time.Topic.map(topic_names))
    data = named.loc[named.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"])

    # Add traces. ``data`` is already sorted by Topic, so an unsorted groupby
    # visits topics in ascending order in a single pass.
    fig = go.Figure()
    for topic, trace_data in data.groupby("Topic", sort=False):
        topic_name = trace_data.Name.values[0]
        words = trace_data.Words.values
        if normalize_frequency:
            y = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
        else:
            y = trace_data.Frequency
        fig.add_trace(go.Scatter(x=pd.to_datetime(trace_data.Timestamp), y=y,
                                 mode='lines',
                                 hoverinfo="text",
                                 name=topic_name,
                                 hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))

    # Styling of the visualization
    fig.update_layout(
        yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
        title={'text': f'{title}',
               'font': dict(size=22)},
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
        ),
        legend=dict(
            title="<b>Global Topic Representation",
            orientation="h",
            # Horizontal legend placed below the plot area, left-aligned.
            y=-.2,
            x=0,
        ),
    )
    return fig
def visualize_topics_per_class(topic_model,
                               topics_per_class: pd.DataFrame,
                               top_n_topics: int = 10,
                               topics: List[int] = None,
                               normalize_frequency: bool = False,
                               custom_labels: bool = False,
                               title: str = "<b>Topics per Class</b>",
                               width: int = 900,
                               height: int = 900) -> go.Figure:
    """Plot each topic's frequency per class as horizontal bars.

    Based on BERTopic's function
    https://github.com/MaartenGr/BERTopic/blob/809414b88ca3f12a46728069d098d82345986489/bertopic/plotting/_topics_per_class.py

    Only the first plotted topic is visible initially; the others start as
    ``"legendonly"`` and can be toggled from the legend.

    Arguments:
        topic_model: A fitted topic model. This function reads its
            ``get_topic_freq()``, ``topic_labels_``, ``custom_labels_`` and
            ``_outliers``.
        topics_per_class: DataFrame with ``Topic``, ``Class``, ``Frequency``
            and ``Words`` columns. It is not modified.
        top_n_topics: Plot the first ``top_n_topics`` topics in the order
            returned by ``get_topic_freq()`` (presumably most frequent first),
            excluding outlier topic -1. Ignored when ``topics`` is given.
        topics: Explicit topic ids to plot, in the given order. Ids with no
            rows in ``topics_per_class`` are skipped.
        normalize_frequency: L2-normalize each topic's frequency series.
        custom_labels: Name traces with ``topic_model.custom_labels_`` when set.
        title: Figure title (may contain HTML such as ``<b>``).
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        A plotly ``go.Figure`` with one bar trace per plotted topic.
    """
    # Select topics based on top_n and topics args; -1 is the outlier topic.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        # Deliberately unsorted: keep get_topic_freq() order so the first
        # (initially visible) trace is the top topic.
        selected_topics = freq_df.Topic.to_list()[:top_n_topics]
    else:
        selected_topics = sorted(freq_df.Topic.to_list())

    # Prepare data
    if topic_model.custom_labels_ is not None and custom_labels:
        # custom_labels_ is indexed by position, shifted by ``_outliers``
        # (presumably 1 when an outlier topic -1 exists).
        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers]
                       for key in topic_model.topic_labels_}
    else:
        # Truncate long labels so the legend stays readable.
        topic_names = {key: value[:40] + "..." if len(value) > 40 else value
                       for key, value in topic_model.topic_labels_.items()}
    # ``assign`` returns a new frame, so the caller's DataFrame is not mutated.
    named = topics_per_class.assign(Name=topics_per_class.Topic.map(topic_names))
    data = named.loc[named.Topic.isin(selected_topics), :]

    # Add traces
    fig = go.Figure()
    visible = True  # only the first trace actually added starts visible
    for topic in selected_topics:
        trace_data = data.loc[data.Topic == topic, :]
        if trace_data.empty:
            # A requested topic may have no rows; skip it instead of failing
            # on ``Name.values[0]`` below.
            continue
        topic_name = trace_data.Name.values[0]
        words = trace_data.Words.values
        if normalize_frequency:
            x = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
        else:
            x = trace_data.Frequency
        fig.add_trace(go.Bar(y=trace_data.Class,
                             x=x,
                             visible=visible,
                             hoverinfo="text",
                             name=topic_name,
                             orientation="h",
                             hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))
        visible = "legendonly"

    # Styling of the visualization
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)
    fig.update_layout(
        xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
        yaxis_title="Class",
        title={
            'text': f"{title}",
            'font': dict(size=22),
        },
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
        ),
        legend=dict(
            title="<b>Global Topic Representation",
        ),
    )
    return fig