import time
from logging import getLogger

from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer

from mapReduceSummarizer import get_map_reduce_chain
from model import get_model
from preprocess import prepare_for_summarize
from refineSummarizer import get_refine_chain

logger = getLogger(__name__)


class Summarizer:
    """Summarizes text with either an OpenAI model or a local Hugging Face pipeline."""

    def __init__(self, model_name, model_type, api_key=None) -> None:
        self.model_type = model_type
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.base_summarizer = get_model(model_type, model_name, api_key)

    def summarize(self, text: str, summarizer_type: str = "map_reduce") -> tuple:
        """Return the summary and the generation time in seconds."""
        text_to_summarize, length_type = prepare_for_summarize(text, self.tokenizer)

        if length_type == "short":
            logger.info("Processing input text of fewer than 12,000 tokens")
            if self.model_type == "openai":
                llm = self.base_summarizer
                prompt = PromptTemplate.from_template(
                    template="""Write a concise and complete summary of the given annual report in bullet points.
                        Important:
                        * The summary should contain all important information and no irrelevant information.
                        * Keep the summary as short as possible, in bullet points, with each point separated by a new line.
                        TEXT: {text}
                        SUMMARY:"""
                )
                llm_chain = prompt | llm
                start = time.time()
                summary = llm_chain.invoke({"text": text_to_summarize})
                end = time.time()
                print(f"Summary generation took {round(end - start, 2)}s.")
                return summary, round(end - start, 2)

            elif self.model_type == "local":
                pipe = self.base_summarizer
                start = time.time()
                summary = pipe(text_to_summarize)[0]["summary_text"]
                end = time.time()
                print(f"Summary generation took {round(end - start, 2)}s.")
                return summary, round(end - start, 2)

            else:
                # Avoid silently returning None for an unsupported model type.
                raise ValueError(f"Unsupported model_type: {self.model_type!r}")
        else:
            if summarizer_type == "refine":
                print("The text is too long; running the refine summarizer.")
                llm_chain = get_refine_chain(self.base_summarizer, self.model_type)
                logger.info("Running refine chain for summarization")
            else:
                print("The text is too long; running the map-reduce summarizer.")
                llm_chain = get_map_reduce_chain(self.base_summarizer, model_type=self.model_type)
                logger.info("Running map-reduce chain for summarization")
            start = time.time()
            summary = llm_chain.invoke(
                {"input_documents": text_to_summarize}, return_only_outputs=True
            )["output_text"]
            end = time.time()
            print(f"Summary generation took {round(end - start, 2)}s.")
            return summary, round(end - start, 2)
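

# Example usage -- a minimal sketch, assuming a local Hugging Face model such as
# "facebook/bart-large-cnn" supported by get_model(); the model name and the
# input file "annual_report.txt" are illustrative placeholders, not part of this module.
if __name__ == "__main__":
    summarizer = Summarizer(model_name="facebook/bart-large-cnn", model_type="local")
    with open("annual_report.txt", "r", encoding="utf-8") as f:
        report_text = f.read()
    # summarizer_type only matters for long inputs; "map_reduce" is the default,
    # "refine" selects the refine chain instead.
    summary, elapsed = summarizer.summarize(report_text, summarizer_type="map_reduce")
    print(summary)
    print(f"Generated in {elapsed}s")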