amielle commited on
Commit
cb09dc9
1 Parent(s): 34f51c7

feat: Build class for summarizer pipeline

Browse files
Files changed (1) hide show
  1. util/summarizer.py +61 -64
util/summarizer.py CHANGED
@@ -1,74 +1,71 @@
1
- from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
 
2
 
3
- # by default encoder-attention is `block_sparse` with num_random_blocks=3, block_size=64
4
-
5
- # TODO: add pre-trained summarizer models
6
- # Placeholder text for testing input
7
- test_text = """
8
- Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.
9
- """
10
  summary_options = ["Abstract", "Background", "Claims"]
11
-
12
- def get_word_index(s, idx):
13
- words = re.findall(r'\s*\S+\s*', s)
14
- return sum(map(len, words[:idx])) + len(words[idx]) - len(words[idx].lstrip())
15
-
16
-
 
 
 
 
 
 
17
  class PatentSummarizer():
18
- def __init__(self, base_model_name="google/bigbird-pegasus-large-bigpatent"):
19
- # Possible to tweak other summaries with different models in the future
20
- self.model = dict()
21
- self.tokenizer = dict()
22
-
23
- self.base_model = BigBirdPegasusForConditionalGeneration.from_pretrained(base_model_name)
24
- self.base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
25
-
26
  self.max_word_input = 1000
27
 
28
 
29
- def pipeline(patent_information, summaries_generated, min_char_abs, min_char_bg, min_char_claims):
30
- # TODO: add checker if valid patent info, return None if invalid
31
- # TODO: add scraper to get document
32
 
33
- # TODO: add parser to get the following info from the base document:
34
- abstract, background, claims = None, None, None
 
 
35
 
 
36
  summaries = list()
37
- if "Abstract" in summaries_generated:
38
- abstract_summary = summarizer.generate_abs_summary(abstract, min_char_abs)
39
- summaries.append(abstract_summary)
40
- else:
41
- summaries.append(None)
42
-
43
- if "Background" in summaries_generated:
44
- background_summary = summarizer.generate_bg_summary(background, min_char_bg)
45
- summaries.append(background_summary)
46
- else:
47
- summaries.append(None)
48
-
49
- if "Claims" in summaries_generated:
50
- claims_summary = summarizer.generate_claims_summary(claims, min_char_claims)
51
- summaries.append(claims_summary)
52
- else:
53
- summaries.append(None)
54
-
55
- return summaries
56
-
57
-
58
- def generate_abs_summary(abstract, min_char_abs):
59
- return "Abstract" + test_text
60
-
61
-
62
- def generate_bg_summary(background, min_char_bg):
63
- stop_idx = get_word_index(background, self.max_word_input)
64
- inputs = self.base_tokenizer(background[0:stop_idx],
65
- return_tensors='pt')
66
- prediction = self.base_model.generate(**inputs)
67
- bg_summary = self.base_tokenizer.batch_decode(prediction)
68
- bg_summary = textproc.clean_text(bg_summary[0])
69
-
70
- return bg_summary
71
-
72
 
73
- def generate_claims_summary(claims, min_char_claims):
74
- return "Claims" + test_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from util import textproc
3
 
 
 
 
 
 
 
 
4
  summary_options = ["Abstract", "Background", "Claims"]
5
+ model_names = ["huggingface/google/bigbird-pegasus-large-bigpatent",
6
+ "huggingface/cnicu/t5-small-booksum",
7
+ "huggingface/sshleifer/distilbart-cnn-6-6",
8
+ "huggingface/google/pegasus-xsum"]
9
+
10
+ def init_models():
11
+ model = dict()
12
+ for name in model_names:
13
+ model[name] = gr.Interface.load(name)
14
+ return model
15
+
16
+
17
  class PatentSummarizer():
18
+ def __init__(self, model_collection):
19
+ self.model = model_collection
 
 
 
 
 
 
20
  self.max_word_input = 1000
21
 
22
 
23
+ def pipeline(self, patent_information, summaries_generated, abstract_model, \
24
+ background_model, claims_model, collate_claims, word_limit):
 
25
 
26
+ parsed_info = textproc.retrieve_parsed_doc(patent_information,
27
+ summaries_generated)
28
+ if parsed_info is None:
29
+ return ["[ERROR] Invalid Patent Information.", None, None]
30
 
31
+ abstract, background, claims = parsed_info
32
  summaries = list()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ try:
35
+ if "Abstract" in summaries_generated and abstract is not None:
36
+ abstract = abstract[0: textproc.get_word_index(abstract, word_limit)]
37
+
38
+ abstract_summary = self.model[abstract_model](abstract)
39
+ abstract_summary = textproc.post_process(abstract_summary)
40
+ summaries.append(abstract_summary)
41
+ else:
42
+ summaries.append(None)
43
+
44
+ if "Background" in summaries_generated and background is not None:
45
+ background = background[0: textproc.get_word_index(background, word_limit)]
46
+
47
+ background_summary = self.model[background_model](background)
48
+ background_summary = textproc.post_process(background_summary)
49
+ summaries.append(background_summary)
50
+ else:
51
+ summaries.append(None)
52
+
53
+ if "Claims" in summaries_generated and claims is not None:
54
+ if collate_claims:
55
+ claims = ' '.join(claims)
56
+ print(len(claims))
57
+ claims = claims[0: textproc.get_word_index(claims, word_limit)]
58
+ print(len(claims))
59
+ claims_summary = self.model[claims_model](claims)
60
+ else:
61
+ claims_summary = ''
62
+ for claim in claims:
63
+ claims_summary += self.model[claims_model](claim)
64
+ claims_summary = textproc.post_process(claims_summary)
65
+ summaries.append(claims_summary)
66
+ else:
67
+ summaries.append(None)
68
+
69
+ return summaries
70
+ except Exception as e:
71
+ return [f'[ERROR] {e}'] + [None]*(len(summaries_generated) - 1)