darveen committed on
Commit
1fb8ec8
1 Parent(s): 562f935

Create new file

Browse files
Files changed (1) hide show
  1. model.py +49 -0
model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import ssl
3
+ import re
4
# NOTE(review): this swaps the ssl module's default HTTPS context factory for
# the unverified one, so certificate verification is turned off for every
# caller in the process that relies on that default -- not only for the
# downloads below. Presumably a workaround for NLTK download failures on
# machines with a broken certificate store; confirm it is still needed.
if hasattr(ssl, "_create_unverified_context"):
    _create_unverified_https_context = ssl._create_unverified_context
    ssl._create_default_https_context = _create_unverified_https_context

# Fetch the NLTK resources at import time.
nltk.download('punkt')
nltk.download('stopwords')
12
+ from transformers import BartTokenizer, PegasusTokenizer
13
+ from transformers import BartForConditionalGeneration, PegasusForConditionalGeneration
14
+ from tqdm.notebook import tqdm
15
+
16
+
17
+
18
class Abstractive_Summarization_Model:
    """Abstractive text summarizer built on a pretrained BRIO checkpoint.

    Depending on ``is_cnndm`` the instance loads either a
    ``BartForConditionalGeneration`` model from ``CNNDM_CHECKPOINT`` or a
    ``PegasusForConditionalGeneration`` model from ``XSUM_CHECKPOINT``,
    together with the matching tokenizer. Both are loaded once, in the
    constructor, and reused by every ``summarize`` call.
    """

    # Checkpoint names handed to ``from_pretrained``.
    CNNDM_CHECKPOINT = 'Yale-LILY/brio-cnndm-uncased'
    XSUM_CHECKPOINT = 'Yale-LILY/brio-xsum-cased'

    def __init__(self, *, is_cnndm: bool = True, lower: bool = False) -> None:
        """Configure the summarizer and load its model and tokenizer.

        Args:
            is_cnndm: Use the CNNDM checkpoint when true, the XSum
                checkpoint otherwise. The default matches the value that
                used to be hard-coded.
            lower: Lowercase the article before tokenizing it. The default
                matches the value that used to be hard-coded.
        """
        # Not read anywhere in this class; kept so existing attribute
        # access on instances keeps working.
        self.text = None
        self.IS_CNNDM = is_cnndm
        self.LOWER = lower
        # Token budget for the input article; the tokenizer truncates
        # anything longer (see ``summarize``).
        self.max_length = 1024 if self.IS_CNNDM else 512
        self.model, self.tokenizer = self.load_model()

    def load_model(self):
        """Load the model and tokenizer selected by ``self.IS_CNNDM``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        print('[INFO]: Loading model ...')
        if self.IS_CNNDM:
            model = BartForConditionalGeneration.from_pretrained(self.CNNDM_CHECKPOINT)
            tokenizer = BartTokenizer.from_pretrained(self.CNNDM_CHECKPOINT)
        else:
            model = PegasusForConditionalGeneration.from_pretrained(self.XSUM_CHECKPOINT)
            tokenizer = PegasusTokenizer.from_pretrained(self.XSUM_CHECKPOINT)
        print('[INFO]: Model Successfully Loaded :)')
        return model, tokenizer

    def summarize(self, text: str) -> str:
        """Return an abstractive summary of ``text``.

        Args:
            text: Article to summarize. It is lowercased first when
                ``self.LOWER`` is set and truncated to ``self.max_length``
                tokens by the tokenizer.

        Returns:
            The decoded summary with special tokens removed, or an empty
            string when ``text`` is empty or contains only whitespace.
        """
        # An empty article has nothing to summarize; skip generation instead
        # of returning whatever the model emits for empty input.
        if isinstance(text, str) and not text.strip():
            return ''

        article = text.lower() if self.LOWER else text
        inputs = self.tokenizer(
            [article],
            max_length=self.max_length,
            return_tensors="pt",
            truncation=True,
        )
        summary_ids = self.model.generate(inputs["input_ids"])
        decoded = self.tokenizer.batch_decode(
            summary_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        # A single article was submitted, so only the first entry is relevant.
        return decoded[0]