File size: 4,153 Bytes
0c969fd
 
 
 
 
 
 
 
5dd9751
0c969fd
 
 
581d930
 
0c969fd
581d930
0c969fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from bertopic import BERTopic
from scipy.cluster import hierarchy as sch
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic
# from wordcloud import WordCloud
import matplotlib.pyplot as plt
from wordcloud_fa import WordCloudFa
import os

import utils

# Embedding model identifier handed to BERTopic by default; this is None
# when the EMBED_MODEL environment variable is not set.
embed_model = os.getenv("EMBED_MODEL")

class TopicModeling:
    def __init__(self, stopwords_path='./assets/stopwords.txt', specific_stopwords_path='./assets/shahrara_stopwords.txt', embedding_model= embed_model) -> None:      
        stopwords = open(stopwords_path).read().splitlines()
        specific_stopwords = open(specific_stopwords_path).read().splitlines()
        stopwords = stopwords + specific_stopwords
        vectorizer_model = CountVectorizer(stop_words=stopwords)
        self.topic_model = BERTopic(embedding_model=embedding_model, vectorizer_model=vectorizer_model, verbose=True)


    def add_data(self,  df):
        print('add data')
        # df = df.dropna()
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df.apply(lambda x: '. '.join(x), axis=1)
        df['FINAL_CONCATED_TEXT_FOR_TOPIC'] = df['FINAL_CONCATED_TEXT_FOR_TOPIC'].apply(utils.normalize)
        docs = list(set(df['FINAL_CONCATED_TEXT_FOR_TOPIC'].tolist()))
        docs = [d for d in docs if d and type(d) == str and len(d.split())>3]
        print('len docs ', len(docs))
        return docs


    def fit(self, docs):
        print('self docs : ', len(docs))
        print(docs[:5])
        self.topics, self.probs = self.topic_model.fit_transform(docs)

    def get_barchart(self):
        return self.topic_model.visualize_barchart()


    def get_vis_topics(self):
        return self.topic_model.visualize_topics()


    def get_h_topics(self):
        linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
        hierarchical_topics = self.topic_model.hierarchical_topics(self.docs, linkage_function=linkage_function)
        return self.topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)

    def topic_over_tome(self):
        # # Create topics over time
        # model = BERTopic(verbose=True)
        topics_over_time = self.topic_model.topics_over_time(self.docs, self.timestamps, datetime_format="%m-%d")
        return self.topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=5)
      
    
    def visualize_documents(self, docs):
        self.topic_model.visualize_documents(docs, embeddings=embeddings)
        reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
        topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)


    def get_topic_info(self):
        return self.topic_model.get_topic_info()


    def get_wordcloud(self):
        all_plts = []
        topic_counts = len(self.topic_model.get_topic_info())
        if topic_counts > 30:
            topic_counts = 30
        print('topic count ', topic_counts)
        for topic_index in range(topic_counts):
            print(topic_index)
            top_n_words = self.topic_model.get_topic(topic_index)
            if type(top_n_words) != bool:
                text = {word: value for word, value in  top_n_words}
                wc = WordCloudFa(background_color="white", max_words=1000, no_reshape=True)
                wc.generate_from_frequencies(text)
                plt.imshow(wc, interpolation="bilinear")
                plt.axis("off")
                fig = plt.figure()
                all_plts.append(fig)
                # plt.show()
        return all_plts
    
    def get_wordcloud_by_topic(self, topic_index):
        top_n_words = self.topic_model.get_topic(topic_index)
        if type(top_n_words) != bool:
            text = {word: value for word, value in  top_n_words}
            wc = WordCloudFa(background_color="white", max_words=1000, no_reshape=True)
            wc.generate_from_frequencies(text)
            plt.imshow(wc, interpolation="bilinear")
            plt.axis("off")
            fig = plt.figure()
            return fig
        return None