from datetime import datetime
from transformers import BartTokenizer, TFBartForConditionalGeneration
from Utils import get_input_chunks
import networkx as nx
from nltk.tokenize import sent_tokenize
import nltk
from sklearn.feature_extraction.text import TfidfVectorizer
import community  # python-louvain package, provides best_partition()
from title_generator import T5Summarizer


class BARTSummarizer:
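    """Summarize long-form text with facebook/bart-large-cnn.

    Provides plain summarization, chunked summarization for inputs longer
    than the model's context window, and an "auto chapters" mode that
    clusters sentences into topical sections before summarizing each one.
    """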

    def __init__(self, model_name: str = 'facebook/bart-large-cnn'):
        self.model_name = model_name
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = TFBartForConditionalGeneration.from_pretrained(model_name)
        self.max_length = self.model.config.max_position_embeddings
        self.title_model = T5Summarizer()

    def summarize(self, text: str, auto: bool = False):
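        """Summarize a single piece of text that fits in one model input.

        When ``auto`` is True, greedy decoding with a minimum length and no
        repeated bigrams is used (tuned for the auto-chapters path);
        otherwise beam search with early stopping is used.
        """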
        encoded_input = self.tokenizer.encode(text, max_length=self.max_length, return_tensors='tf', truncation=True)
        if auto:
            summary_ids = self.model.generate(encoded_input, max_length=300, num_beams=1, no_repeat_ngram_size=2, min_length=60)
        else:
            summary_ids = self.model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
    
    def chunk_summarize(self, text: str, auto: bool = False):
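        """Split ``text`` into model-sized chunks, summarize each chunk
        separately, and join the partial summaries into one final summary.
        """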

        # split the input into chunks
        summaries = []
        input_chunks = get_input_chunks(text, self.max_length)

        # summarize each input chunk separately
        print(datetime.now().strftime("%H:%M:%S"))
        for chunk in input_chunks:
            summaries.append(self.summarize(chunk, auto))
            
        # combine the summaries to get the final summary for the entire input
        final_summary = " ".join(summaries)

        print(datetime.now().strftime("%H:%M:%S"))

        return final_summary
    
    def preprocess_for_auto_chapters(self, text: str):
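        """Tokenize ``text`` into sentences, drop very short ones, and merge
        consecutive sentences into larger blocks for clustering.
        """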
       
        # Tokenize the text into sentences, downloading the punkt model on
        # first use if it is missing
        try:
            sentences = sent_tokenize(text)
        except LookupError:
            nltk.download('punkt')
            sentences = sent_tokenize(text)

        # Filter out empty sentences and sentences with fewer than 5 words
        sentences = [sentence for sentence in sentences if len(sentence.strip()) > 0 and len(sentence.split(" ")) > 4]

        # Combine every 5 consecutive sentences into a single block
        sentences = [' '.join(sentences[i:i + 5]) for i in range(0, len(sentences), 5)]

        return sentences
    
    def auto_chapters_summarize(self, text: str):
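        """Cluster sentence blocks into topical "chapters" using TF-IDF
        similarity and Louvain community detection, then summarize and
        title each cluster.
        """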

        sentences = self.preprocess_for_auto_chapters(text)

        vectorizer = TfidfVectorizer(stop_words='english')
        X = vectorizer.fit_transform(sentences)

        # Compute the similarity matrix using cosine similarity.
        # TfidfVectorizer L2-normalizes each row by default, so the dot
        # product of the TF-IDF matrix with its transpose gives the cosine
        # similarity between every pair of sentence blocks
        similarity_matrix = X * X.T

        # Convert the similarity matrix to a weighted graph
        graph = nx.from_scipy_sparse_array(similarity_matrix)

        # Apply the Louvain algorithm to identify communities
        partition = community.best_partition(graph, resolution=0.7, random_state=42)

        # Group the sentence blocks by the community they were assigned to,
        # keeping only communities with more than one block
        clustered_sentences = []
        for cluster in set(partition.values()):
            cluster_members = [sentence for i, sentence in enumerate(sentences) if partition[i] == cluster]
            if len(cluster_members) > 1:
                clustered_sentences.append(" ".join(cluster_members))
        
        # Summarize each cluster
        summaries_with_title = []
        for cluster in clustered_sentences:
            title = self.title_model.summarize(cluster)
            summary = self.chunk_summarize(cluster, auto=True)
            summary_with_title = "#### " + title + "\n" + summary
            summaries_with_title.append(summary_with_title)

        # Combine the summaries to get the final summary for the entire input
        final_summary = "\n\n".join(summaries_with_title)

        return final_summary
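

# Minimal usage sketch. ``sample_text`` below is a hypothetical placeholder;
# in practice this class is fed transcripts or articles long enough to need
# chunking. Running it downloads the BART and T5 weights on first use.
if __name__ == "__main__":
    summarizer = BARTSummarizer()

    sample_text = (
        "Your long document goes here. "
        "It can span many paragraphs; chunk_summarize() splits it to fit "
        "the model's context window before summarizing each piece."
    )

    # Plain summary of the whole input
    print(summarizer.chunk_summarize(sample_text))

    # Sectioned summary with a generated title per detected topic
    print(summarizer.auto_chapters_summarize(sample_text))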