KevlarVK committed on
Commit
9a4b6ed
1 Parent(s): ae59ea3

Added support for title generation

Browse files
Files changed (2) hide show
  1. summarizer.py +7 -3
  2. title_generator.py +14 -0
summarizer.py CHANGED
@@ -5,6 +5,7 @@ import networkx as nx
5
  from nltk.tokenize import sent_tokenize
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  import community
 
8
 
9
 
10
  class BARTSummarizer:
@@ -14,6 +15,7 @@ class BARTSummarizer:
14
  self.tokenizer = BartTokenizer.from_pretrained(model_name)
15
  self.model = TFBartForConditionalGeneration.from_pretrained(model_name)
16
  self.max_length = self.model.config.max_position_embeddings
 
17
 
18
  def summarize(self, text: str, auto: bool = False):
19
  encoded_input = self.tokenizer.encode(text, max_length=self.max_length, return_tensors='tf', truncation=True)
@@ -82,12 +84,14 @@ class BARTSummarizer:
82
  clustered_sentences.append(" ".join(sentences_to_print))
83
 
84
  # Summarize each cluster
85
- summaries = []
86
  for cluster in clustered_sentences:
87
- summaries.append(self.chunk_summarize(cluster, auto=True))
 
 
88
 
89
  # Combine the summaries to get the final summary for the entire input
90
- final_summary = "\n\n".join(summaries)
91
 
92
  return final_summary
93
 
 
5
  from nltk.tokenize import sent_tokenize
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  import community
8
+ from title_generator import T5Summarizer
9
 
10
 
11
  class BARTSummarizer:
 
15
  self.tokenizer = BartTokenizer.from_pretrained(model_name)
16
  self.model = TFBartForConditionalGeneration.from_pretrained(model_name)
17
  self.max_length = self.model.config.max_position_embeddings
18
+ self.title_model = T5Summarizer()
19
 
20
  def summarize(self, text: str, auto: bool = False):
21
  encoded_input = self.tokenizer.encode(text, max_length=self.max_length, return_tensors='tf', truncation=True)
 
84
  clustered_sentences.append(" ".join(sentences_to_print))
85
 
86
  # Summarize each cluster
87
+ summaries_with_title = []
88
  for cluster in clustered_sentences:
89
+ summary = self.chunk_summarize(cluster, auto=True)
90
+ summary_with_title = "#### " + self.title_model.summarize(summary) + "\n" + summary
91
+ summaries_with_title.append(summary_with_title)
92
 
93
  # Combine the summaries to get the final summary for the entire input
94
+ final_summary = "\n\n".join(summaries_with_title)
95
 
96
  return final_summary
97
 
title_generator.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM


class T5Summarizer:
    """Generate short titles for input text with a fine-tuned T5 model.

    Despite the generic name, the default checkpoint is a title-generation
    model, so ``summarize`` effectively returns a headline for the text.
    """

    def __init__(self, model_name: str = "fabiochiu/t5-small-medium-title-generation"):
        # Load the pretrained tokenizer and TF seq2seq model once at
        # construction so repeated summarize() calls reuse them.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    def summarize(self, text: str):
        """Return a short (at most 10 generated tokens) title for *text*.

        The input is truncated to the tokenizer's maximum model length, so
        arbitrarily long text is accepted.
        """
        # T5 is a multi-task model: the task prefix selects summarization.
        prompt = "summarize: " + text
        encoded = self.tokenizer(
            [prompt],
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="tf",
        )
        # Beam search combined with sampling, capped at a title-sized output.
        generated = self.model.generate(
            **encoded,
            num_beams=8,
            do_sample=True,
            min_length=1,
            max_length=10,
            early_stopping=True,
        )
        # Decode the single batch element back to plain text.
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]