KevlarVK committed on
Commit
9a24169
1 Parent(s): 9d81456

Included support for chunked summarization

Files changed (1)
  1. summarize.py +23 -12
summarize.py CHANGED
@@ -1,16 +1,29 @@
 from datetime import datetime
+import multiprocessing
 from transformers import BartTokenizer, TFBartForConditionalGeneration, pipeline
 from Utils import fetch_article_text, count_tokens
 import re
 from nltk.tokenize import sent_tokenize
 import nltk
+import threading
 
 tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
 model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
+max_length = model.config.max_position_embeddings
 
-def bart_summarize(text: str):
-    max_length = model.config.max_position_embeddings
+summaries = []
+
+def generate_summary(text: str):
+    encoded_input = tokenizer.encode(text, max_length=max_length, return_tensors='tf')
+
+    # generate summary for the input chunk
+    summary_ids = model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    # add the summary to the list of summaries
+    summaries.append(summary)
+
+def bart_summarize(text: str):
 
     try:
         sentences = sent_tokenize(text)
@@ -36,19 +49,17 @@ def bart_summarize(text: str):
             input_chunks.append(temp_sentences)
 
         # summarize each input chunk separately
-        summaries = []
+        results = []
         print(datetime.now().strftime("%H:%M:%S"))
         for chunk in input_chunks:
-            # encode the input chunk
-            encoded_input = tokenizer.encode(chunk, max_length=max_length, return_tensors='tf')
-
-            # generate summary for the input chunk
-            summary_ids = model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
-            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-            # add the summary to the list of summaries
-            summaries.append(summary)
+            result_t = multiprocessing.Process(target=generate_summary, args=(chunk,))
+            results.append(result_t)
+
+        for result in results:
+            result.start()
+
+        for result in results:
+            result.join()
 
         # # combine the summaries to get the final summary for the entire input
         final_summary = " ".join(summaries)
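A note on the new parallel path: generate_summary appends to the module-level summaries list, but each multiprocessing.Process runs in its own address space, so appends made inside the workers never reach the parent list that " ".join(summaries) reads; under the spawn start method (the default on Windows and macOS) each worker would also re-import the module and re-load the BART model. Below is a minimal sketch of one way to keep the results in the parent, reusing tokenizer, model, max_length, and input_chunks from summarize.py; the thread pool and the return-based helper generate_summary_returning are illustrative substitutes, not part of this commit.

from concurrent.futures import ThreadPoolExecutor

def generate_summary_returning(chunk: str) -> str:
    # same generation settings as the commit, but the summary is
    # returned to the caller instead of appended to a shared global
    encoded_input = tokenizer.encode(chunk, max_length=max_length, return_tensors='tf')
    summary_ids = model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# threads share the parent's memory, so tokenizer and model are loaded once;
# executor.map preserves input order, keeping the joined summary in document order
with ThreadPoolExecutor(max_workers=4) as executor:
    summaries = list(executor.map(generate_summary_returning, input_chunks))

final_summary = " ".join(summaries)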
 
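If separate processes are still preferred, for example to sidestep the GIL entirely, the per-chunk results have to cross the process boundary explicitly (return values from multiprocessing.Pool.map, or a multiprocessing.Queue); a plain Python list cannot act as shared state between processes, and each worker pays the cost of loading BART-large again. The unused import threading in this diff points at the thread-based alternative, under which the shared summaries list works as written because threads share the parent's memory, with the caveat that concurrent generate calls only overlap to the extent TensorFlow releases the GIL internally.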