import nltk
import re
import ssl

# Work around SSL certificate errors that can block NLTK downloads
# (e.g. on macOS or behind a restrictive proxy).
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Download the NLTK tokenizer and stopword data.
nltk.download('punkt')
nltk.download('stopwords')

from transformers import BartTokenizer, PegasusTokenizer
from transformers import BartForConditionalGeneration, PegasusForConditionalGeneration
from tqdm.notebook import tqdm



class Abstractive_Summarization_Model:
    """Abstractive summarizer built on the BRIO checkpoints: BART fine-tuned on
    CNN/DailyMail, or Pegasus fine-tuned on XSum."""

    def __init__(self):
        self.text = None
        self.IS_CNNDM = True   # True: CNN/DailyMail checkpoint (BART); False: XSum checkpoint (Pegasus)
        self.LOWER = False     # lowercase input text before tokenization
        self.max_length = 1024 if self.IS_CNNDM else 512  # encoder input limit for each checkpoint
        self.model, self.tokenizer = self.load_model()
        
    def load_model(self):
        # Load the pretrained BRIO checkpoint and its matching tokenizer.
        print('[INFO]: Loading model ...')
        if self.IS_CNNDM:
            model = BartForConditionalGeneration.from_pretrained('Yale-LILY/brio-cnndm-uncased')
            tokenizer = BartTokenizer.from_pretrained('Yale-LILY/brio-cnndm-uncased')
        else:
            model = PegasusForConditionalGeneration.from_pretrained('Yale-LILY/brio-xsum-cased')
            tokenizer = PegasusTokenizer.from_pretrained('Yale-LILY/brio-xsum-cased')
        print('[INFO]: Model Successfully Loaded :)')
        return model, tokenizer
        
        
    def summarize(self, text):
        # Optionally lowercase the article (the CNN/DailyMail checkpoint is uncased).
        article = text.lower() if self.LOWER else text
        # Tokenize, truncating to the encoder's maximum input length.
        inputs = self.tokenizer([article], max_length=self.max_length, return_tensors="pt", truncation=True)
        # Generate the summary; pass the attention mask explicitly rather than relying on defaults.
        summary_ids = self.model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        # Decode the generated token ids back into a plain string.
        return self.tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
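

# Usage sketch (illustrative, not part of the original file): build the model
# once, then call summarize() on raw article text. The sample article below is
# only a placeholder.
if __name__ == "__main__":
    summarizer = Abstractive_Summarization_Model()
    sample_article = (
        "The tower is 324 metres tall, about the same height as an 81-storey "
        "building, and was the tallest man-made structure in the world until 1930."
    )
    print(summarizer.summarize(sample_article))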