# Persian-Topic-Modeling / topic_modeling.py
from bertopic import BERTopic
from scipy.cluster import hierarchy as sch
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP
import matplotlib.pyplot as plt
from wordcloud_fa import WordCloudFa
import os
import utils
# Name of the embedding model, read from the environment (e.g. a sentence-transformers model id).
embed_model = os.environ.get("EMBED_MODEL")
class TopicModeling:
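    """Topic modeling for Persian text: BERTopic with Persian stopword
    filtering and visualization helpers (word clouds via wordcloud_fa)."""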
    def __init__(self, stopwords_path='./assets/stopwords.txt',
                 specific_stopwords_path='./assets/shahrara_stopwords.txt',
                 embedding_model=embed_model) -> None:
        # The stopword files contain Persian text, so read them as UTF-8.
        with open(stopwords_path, encoding='utf-8') as f:
            stopwords = f.read().splitlines()
        with open(specific_stopwords_path, encoding='utf-8') as f:
            specific_stopwords = f.read().splitlines()
        stopwords = stopwords + specific_stopwords
        vectorizer_model = CountVectorizer(stop_words=stopwords)
        self.topic_model = BERTopic(embedding_model=embedding_model,
                                    vectorizer_model=vectorizer_model, verbose=True)
    def add_data(self, df):
        print('add data')
        # Drop missing values so the row-wise join below cannot hit NaN.
        df = df.dropna()
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df.apply(lambda x: '. '.join(x), axis=1)
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df['FINAL_CONCATED_TEXT_FOR_TOPIC'].apply(utils.normalize)
        # Deduplicate documents and keep only non-trivial strings (more than 3 tokens).
        docs = list(set(df['FINAL_CONCATED_TEXT_FOR_TOPIC'].tolist()))
        docs = [d for d in docs if d and isinstance(d, str) and len(d.split()) > 3]
        print('len docs ', len(docs))
        return docs
    def fit(self, docs):
        print('self docs : ', len(docs))
        print(docs[:5])
        # Keep the documents around; hierarchical_topics and topics_over_time need them.
        self.docs = docs
        self.topics, self.probs = self.topic_model.fit_transform(docs)
def get_barchart(self):
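        """Bar chart of the top words per topic."""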
return self.topic_model.visualize_barchart()
def get_vis_topics(self):
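        """Intertopic distance map of the fitted topics."""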
return self.topic_model.visualize_topics()
def get_h_topics(self):
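        """Hierarchical topic visualization; requires fit() to have stored self.docs."""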
linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
hierarchical_topics = self.topic_model.hierarchical_topics(self.docs, linkage_function=linkage_function)
return self.topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
    def get_topics_over_time(self, timestamps):
        # Timestamps must be supplied by the caller, one per document passed to fit().
        topics_over_time = self.topic_model.topics_over_time(self.docs, timestamps, datetime_format="%m-%d")
        return self.topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=5)
    def visualize_documents(self, docs, embeddings):
        # Reduce the precomputed document embeddings to 2D with UMAP before plotting.
        reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
        return self.topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
def get_topic_info(self):
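        """Summary table of discovered topics (id, size, representative words)."""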
return self.topic_model.get_topic_info()
    def get_wordcloud(self):
        all_plts = []
        # Cap the number of word clouds at 30 topics.
        topic_counts = min(len(self.topic_model.get_topic_info()), 30)
        print('topic count ', topic_counts)
        for topic_index in range(topic_counts):
            print(topic_index)
            top_n_words = self.topic_model.get_topic(topic_index)
            # get_topic returns False for a non-existent topic id.
            if not isinstance(top_n_words, bool):
                text = dict(top_n_words)
                wc = WordCloudFa(background_color="white", max_words=1000, no_reshape=True)
                wc.generate_from_frequencies(text)
                # Create the figure before drawing, so the image lands on it.
                fig = plt.figure()
                plt.imshow(wc, interpolation="bilinear")
                plt.axis("off")
                all_plts.append(fig)
        return all_plts
    def get_wordcloud_by_topic(self, topic_index):
        top_n_words = self.topic_model.get_topic(topic_index)
        if not isinstance(top_n_words, bool):
            text = dict(top_n_words)
            wc = WordCloudFa(background_color="white", max_words=1000, no_reshape=True)
            wc.generate_from_frequencies(text)
            # Create the figure before drawing, so the image lands on it.
            fig = plt.figure()
            plt.imshow(wc, interpolation="bilinear")
            plt.axis("off")
            return fig
        return None
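

# Minimal usage sketch (assumptions: EMBED_MODEL is set to a model BERTopic can
# load, and 'comments.csv' is a hypothetical file whose columns are all text).
# This block is illustrative and was not part of the original module.
if __name__ == "__main__":
    import pandas as pd

    df = pd.read_csv('comments.csv')  # hypothetical input file
    tm = TopicModeling()
    docs = tm.add_data(df)            # clean, normalize, deduplicate
    tm.fit(docs)                      # embed and cluster the documents
    print(tm.get_topic_info())        # summary table of discovered topics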