import os

import matplotlib.pyplot as plt
from bertopic import BERTopic
from scipy.cluster import hierarchy as sch
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP
from wordcloud_fa import WordCloudFa

import utils

# The embedding model name is read from the environment so it can be
# swapped without code changes.
embed_model = os.environ.get("EMBED_MODEL")


class TopicModeling:
    def __init__(self,
                 stopwords_path='./assets/stopwords.txt',
                 specific_stopwords_path='./assets/shahrara_stopwords.txt',
                 embedding_model=embed_model) -> None:
        # Merge the general and domain-specific stopword lists.
        with open(stopwords_path, encoding='utf-8') as f:
            stopwords = f.read().splitlines()
        with open(specific_stopwords_path, encoding='utf-8') as f:
            specific_stopwords = f.read().splitlines()
        stopwords = stopwords + specific_stopwords

        vectorizer_model = CountVectorizer(stop_words=stopwords)
        self.topic_model = BERTopic(embedding_model=embedding_model,
                                    vectorizer_model=vectorizer_model,
                                    verbose=True)

    def add_data(self, df):
        """Concatenate each row's text columns, normalize, and deduplicate
        into a list of documents ready for topic modeling."""
        print('add data')
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df.apply(lambda x: '. '.join(x), axis=1)
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df['FINAL_CONCATED_TEXT_FOR_TOPIC'].apply(utils.normalize)
        docs = list(set(df['FINAL_CONCATED_TEXT_FOR_TOPIC'].tolist()))
        # Keep only non-empty strings with more than three tokens.
        docs = [d for d in docs if d and isinstance(d, str) and len(d.split()) > 3]
        print('len docs ', len(docs))
        return docs

    def fit(self, docs):
        print('fit docs: ', len(docs))
        print(docs[:5])
        # Keep the documents for later use (hierarchy, topics over time).
        self.docs = docs
        self.topics, self.probs = self.topic_model.fit_transform(docs)

    def get_barchart(self):
        return self.topic_model.visualize_barchart()

    def get_vis_topics(self):
        return self.topic_model.visualize_topics()

    def get_h_topics(self):
        # Single-linkage clustering with optimal leaf ordering for the
        # topic hierarchy.
        linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
        hierarchical_topics = self.topic_model.hierarchical_topics(self.docs,
                                                                   linkage_function=linkage_function)
        return self.topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)

    def topic_over_time(self, timestamps):
        # Expects one timestamp per fitted document, formatted as "%m-%d".
        topics_over_time = self.topic_model.topics_over_time(self.docs, timestamps,
                                                             datetime_format="%m-%d")
        return self.topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=5)

    def visualize_documents(self, docs, embeddings):
        # Reduce the document embeddings to 2D with UMAP before plotting.
        reduced_embeddings = UMAP(n_neighbors=10, n_components=2,
                                  min_dist=0.0, metric='cosine').fit_transform(embeddings)
        return self.topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)

    def get_topic_info(self):
        return self.topic_model.get_topic_info()

    def get_wordcloud(self):
        all_plts = []
        # Cap the number of rendered clouds at 30 topics.
        topic_counts = min(len(self.topic_model.get_topic_info()), 30)
        print('topic count ', topic_counts)
        for topic_index in range(topic_counts):
            print(topic_index)
            fig = self.get_wordcloud_by_topic(topic_index)
            if fig is not None:
                all_plts.append(fig)
        return all_plts

    def get_wordcloud_by_topic(self, topic_index):
        # BERTopic's get_topic() returns False when the topic does not exist.
        top_n_words = self.topic_model.get_topic(topic_index)
        if top_n_words:
            text = {word: value for word, value in top_n_words}
            wc = WordCloudFa(background_color="white", max_words=1000, no_reshape=True)
            wc.generate_from_frequencies(text)
            # Create the figure before drawing so imshow renders onto it.
            fig = plt.figure()
            plt.imshow(wc, interpolation="bilinear")
            plt.axis("off")
            return fig
        return None
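
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the module proper).
    # It assumes EMBED_MODEL names a model BERTopic can load (e.g. a
    # sentence-transformers checkpoint), the stopword files exist under
    # ./assets/, and 'reports.csv' is a hypothetical CSV whose columns are
    # all free text.
    import pandas as pd

    df = pd.read_csv('reports.csv').dropna().astype(str)

    tm = TopicModeling()
    docs = tm.add_data(df)
    tm.fit(docs)

    print(tm.get_topic_info())
    tm.get_barchart().show()  # Plotly figure; opens in the browser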