Spaces:
Runtime error
Runtime error
Media support, Code cleanup for summarization, Support for chunk and auto chapters summarize
Browse files- Utils.py +87 -21
- app.py +111 -9
- process_media.py +72 -0
- summarize.py +0 -69
- summarizer.py +88 -0
Utils.py
CHANGED
@@ -1,7 +1,13 @@
|
|
1 |
import requests
|
2 |
from bs4 import BeautifulSoup
|
3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
4 |
|
|
|
5 |
def fetch_article_text(url: str):
|
6 |
|
7 |
r = requests.get(url)
|
@@ -9,27 +15,87 @@ def fetch_article_text(url: str):
|
|
9 |
results = soup.find_all(["h1", "p"])
|
10 |
text = [result.text for result in results]
|
11 |
ARTICLE = " ".join(text)
|
12 |
-
ARTICLE = ARTICLE.replace(".", ".<eos>")
|
13 |
-
ARTICLE = ARTICLE.replace("!", "!<eos>")
|
14 |
-
ARTICLE = ARTICLE.replace("?", "?<eos>")
|
15 |
-
sentences = ARTICLE.split("<eos>")
|
16 |
-
current_chunk = 0
|
17 |
-
chunks = []
|
18 |
-
for sentence in sentences:
|
19 |
-
if len(chunks) == current_chunk + 1:
|
20 |
-
if len(chunks[current_chunk]) + len(sentence.split(" ")) <= 500:
|
21 |
-
chunks[current_chunk].extend(sentence.split(" "))
|
22 |
-
else:
|
23 |
-
current_chunk += 1
|
24 |
-
chunks.append(sentence.split(" "))
|
25 |
-
else:
|
26 |
-
print(current_chunk)
|
27 |
-
chunks.append(sentence.split(" "))
|
28 |
-
|
29 |
-
for chunk_id in range(len(chunks)):
|
30 |
-
chunks[chunk_id] = " ".join(chunks[chunk_id])
|
31 |
|
32 |
-
return ARTICLE
|
33 |
|
34 |
def count_tokens(text: str):
|
35 |
return len(text.split(" "))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
from bs4 import BeautifulSoup
|
3 |
+
from nltk.tokenize import sent_tokenize
|
4 |
+
import nltk
|
5 |
+
import re
|
6 |
+
import streamlit as st
|
7 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
8 |
+
import spacy
|
9 |
|
10 |
+
@st.cache_data
|
11 |
def fetch_article_text(url: str):
|
12 |
|
13 |
r = requests.get(url)
|
|
|
15 |
results = soup.find_all(["h1", "p"])
|
16 |
text = [result.text for result in results]
|
17 |
ARTICLE = " ".join(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
return ARTICLE
|
20 |
|
21 |
def count_tokens(text: str):
|
22 |
return len(text.split(" "))
|
23 |
+
|
24 |
+
@st.cache_data
|
25 |
+
def get_text_from_youtube_url(url: str):
|
26 |
+
|
27 |
+
id = url.split("=")[1]
|
28 |
+
try:
|
29 |
+
transcript = YouTubeTranscriptApi.get_transcript(id)
|
30 |
+
except:
|
31 |
+
transcript = YouTubeTranscriptApi.find_transcript(["en"])
|
32 |
+
script = ""
|
33 |
+
|
34 |
+
for text in transcript:
|
35 |
+
t = text["text"]
|
36 |
+
if t != '[Music]':
|
37 |
+
script += t.lower() + " "
|
38 |
+
|
39 |
+
return add_punctuation(script)
|
40 |
+
|
41 |
+
def add_punctuation(text: str):
|
42 |
+
|
43 |
+
# try:
|
44 |
+
nlp = spacy.load("en_core_web_sm")
|
45 |
+
# except:
|
46 |
+
# import spacy.cli
|
47 |
+
# spacy.cli.download("en_core_web_sm")
|
48 |
+
# nlp = spacy.load("en_core_web_sm")
|
49 |
+
|
50 |
+
doc = nlp(text)
|
51 |
+
punctuation = [".", ",", ";", ":", "?", "!"]
|
52 |
+
|
53 |
+
sentences = []
|
54 |
+
for sentence in doc.sents:
|
55 |
+
|
56 |
+
last_token = sentence[-1]
|
57 |
+
if last_token.text in punctuation:
|
58 |
+
sentence = sentence[:-1]
|
59 |
+
|
60 |
+
last_word = sentence[-1]
|
61 |
+
if last_word.pos_ == "NOUN":
|
62 |
+
sentence = sentence.text + "."
|
63 |
+
elif last_word.pos_ == "VERB":
|
64 |
+
sentence = sentence.text + "?"
|
65 |
+
else:
|
66 |
+
sentence = sentence.text + "."
|
67 |
+
|
68 |
+
sentence = sentence[0].upper() + sentence[1:]
|
69 |
+
sentences.append(sentence)
|
70 |
+
|
71 |
+
text_with_punctuation = " ".join(sentences)
|
72 |
+
|
73 |
+
return text_with_punctuation
|
74 |
+
|
75 |
+
|
76 |
+
def get_input_chunks(text: str, max_length: int = 500):
|
77 |
+
try:
|
78 |
+
sentences = sent_tokenize(text)
|
79 |
+
except:
|
80 |
+
nltk.download('punkt')
|
81 |
+
sentences = sent_tokenize(text)
|
82 |
+
|
83 |
+
sentences = [re.sub(r'\[[0-9]*\]', ' ', sentence) for sentence in sentences if len(sentence.strip()) > 0 and count_tokens(sentence) > 4]
|
84 |
+
|
85 |
+
input_chunks = []
|
86 |
+
temp_sentences = ""
|
87 |
+
tokens = 0
|
88 |
+
|
89 |
+
for sentence in sentences:
|
90 |
+
if tokens + count_tokens(sentence) < max_length:
|
91 |
+
temp_sentences += sentence
|
92 |
+
tokens += count_tokens(sentence)
|
93 |
+
else:
|
94 |
+
input_chunks.append(temp_sentences)
|
95 |
+
tokens = count_tokens(sentence)
|
96 |
+
temp_sentences = sentence
|
97 |
+
|
98 |
+
if len(temp_sentences) > 0:
|
99 |
+
input_chunks.append(temp_sentences)
|
100 |
+
|
101 |
+
return input_chunks
|
app.py
CHANGED
@@ -1,13 +1,115 @@
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
-
from
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import time
|
3 |
+
import wave
|
4 |
+
from process_media import MediaProcessor
|
5 |
import streamlit as st
|
6 |
+
from summarizer import BARTSummarizer
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from Utils import fetch_article_text, get_text_from_youtube_url
|
9 |
|
10 |
+
st.markdown(
|
11 |
+
"""
|
12 |
+
<style>
|
13 |
+
section[data-testid="stSidebar"] div[role="radiogroup"] label {
|
14 |
+
padding: 0px 0px 20px 20px;
|
15 |
+
}
|
16 |
+
section[data-testid="stSidebar"] h2 {
|
17 |
+
margin: 10px;
|
18 |
+
}
|
19 |
+
section.main div[role="radiogroup"] label {
|
20 |
+
padding: 10px 10px 10px 0px;
|
21 |
+
}
|
22 |
+
</style>
|
23 |
+
""",
|
24 |
+
unsafe_allow_html=True,
|
25 |
+
)
|
26 |
|
27 |
+
with st.sidebar:
|
28 |
+
st.header("CHOOSE INPUT TYPE")
|
29 |
+
input_type = st.radio("", ["Text", "Media"], label_visibility = "hidden")
|
30 |
|
31 |
+
|
32 |
+
text_to_summarize = None
|
33 |
+
|
34 |
+
if input_type == "Text":
|
35 |
+
|
36 |
+
st.header("Summarize from text or URL")
|
37 |
+
|
38 |
+
text_type = st.radio("", ["Raw Text", "URL", "Document"], key="text_type", horizontal=True, label_visibility = "hidden")
|
39 |
+
|
40 |
+
if text_type == "Raw Text":
|
41 |
+
text = st.text_area("Enter raw text here", height=240, max_chars=10000, placeholder="Enter a paragraph to summarize")
|
42 |
+
if text:
|
43 |
+
text_to_summarize = text
|
44 |
+
|
45 |
+
elif text_type == "URL":
|
46 |
+
url = st.text_input("Enter URL here", placeholder="Enter URL to an article, blog post, etc.")
|
47 |
+
if url:
|
48 |
+
article_text = fetch_article_text(url)
|
49 |
+
if article_text:
|
50 |
+
st.markdown("#### Text from url:")
|
51 |
+
st.write(article_text)
|
52 |
+
text_to_summarize = article_text
|
53 |
+
else:
|
54 |
+
## TODO: Add file upload option
|
55 |
+
pass
|
56 |
+
|
57 |
+
elif input_type == "Media":
|
58 |
+
|
59 |
+
st.header("Summarize from file or YouTube URL")
|
60 |
+
|
61 |
+
media_type = st.radio("", ["Audio file", "Video file", "Youtube video link"], key="media_type", horizontal=True, label_visibility = "hidden")
|
62 |
+
|
63 |
+
if media_type == "Audio file":
|
64 |
+
audio_file = st.file_uploader("Upload an audio file", type=["mp3", "wav"], label_visibility="visible")
|
65 |
+
if audio_file is not None:
|
66 |
+
with st.spinner("Fetching text from audio..."):
|
67 |
+
# print(audio_file.read())
|
68 |
+
wav_bytes = None
|
69 |
+
media_processor = MediaProcessor()
|
70 |
+
if audio_file.type == "audio/mpeg":
|
71 |
+
wav_bytes = media_processor.get_wav_from_audio(audio_file.read())
|
72 |
+
else:
|
73 |
+
wav_bytes = audio_file.read()
|
74 |
+
text = media_processor.process_audio(wav_bytes)
|
75 |
+
st.markdown("#### Text from audio:")
|
76 |
+
st.write(text)
|
77 |
+
elif media_type == "Video file":
|
78 |
+
video_file = st.file_uploader("Upload a video file", type=["mp4"], label_visibility="visible")
|
79 |
+
if video_file is not None:
|
80 |
+
with st.spinner("Fetching text from video..."):
|
81 |
+
media_processor = MediaProcessor()
|
82 |
+
text = media_processor.process_video(video_file.read())
|
83 |
+
st.markdown("#### Text from video:")
|
84 |
+
st.write(text)
|
85 |
+
else:
|
86 |
+
youtube_url = st.text_input("Enter YouTube URL here", placeholder="Enter URL to an YouTube video", label_visibility="visible")
|
87 |
+
if youtube_url:
|
88 |
+
with st.spinner("Fetching text from video..."):
|
89 |
+
try:
|
90 |
+
text_to_summarize = get_text_from_youtube_url(youtube_url)
|
91 |
+
st.markdown("#### Text from video:")
|
92 |
+
st.markdown('<div style="height: 300px; overflow: auto; margin-bottom: 20px;">' + text_to_summarize + '</div>', unsafe_allow_html=True)
|
93 |
+
except:
|
94 |
+
st.error("Unable to fetch text from video. Please try a different video.")
|
95 |
+
text_to_summarize = None
|
96 |
+
|
97 |
+
if text_to_summarize is not None:
|
98 |
+
overall_summary = st.button("Overall summary")
|
99 |
+
auto_chapters_summary = st.button("Auto Chapters summary")
|
100 |
+
if overall_summary:
|
101 |
+
with st.spinner("Summarizing..."):
|
102 |
+
# time.sleep(2)
|
103 |
+
# st.write(text_to_summarize)
|
104 |
+
summarizer = BARTSummarizer()
|
105 |
+
summary = summarizer.chunk_summarize(text_to_summarize)
|
106 |
+
st.markdown("#### Summary:")
|
107 |
+
st.write(summary)
|
108 |
+
elif auto_chapters_summary:
|
109 |
+
with st.spinner("Summarizing..."):
|
110 |
+
# time.sleep(2)
|
111 |
+
# st.write(text_to_summarize)
|
112 |
+
summarizer = BARTSummarizer()
|
113 |
+
summary = summarizer.auto_chapters_summarize(text_to_summarize)
|
114 |
+
st.markdown("#### Summary:")
|
115 |
+
st.write(summary)
|
process_media.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import wave
|
3 |
+
import tensorflow as tf
|
4 |
+
import tensorflow_io as tfio
|
5 |
+
import moviepy.editor as mp
|
6 |
+
import numpy as np
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from transformers import AutoProcessor, TFWhisperForConditionalGeneration
|
9 |
+
from moviepy.video.io.VideoFileClip import VideoFileClip
|
10 |
+
|
11 |
+
# tf.config.run_functions_eagerly(True)
|
12 |
+
|
13 |
+
class MediaProcessor:
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
self.processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
17 |
+
self.model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
18 |
+
|
19 |
+
def load_wav_16k_mono(self, file_bytes):
|
20 |
+
""" Load a WAV file, convert it to a float tensor, resample to 16 kHz single-channel audio. """
|
21 |
+
wav, sample_rate = tf.audio.decode_wav(
|
22 |
+
file_bytes,
|
23 |
+
desired_channels=1)
|
24 |
+
wav = tf.squeeze(wav, axis=-1)
|
25 |
+
sample_rate = tf.cast(sample_rate, dtype=tf.int64)
|
26 |
+
wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)
|
27 |
+
return wav.numpy()
|
28 |
+
|
29 |
+
def get_text_from_audio(self, resampled_audio_data):
|
30 |
+
# Split the resampled audio data into 30-second chunks
|
31 |
+
chunk_size = 30 * 16000
|
32 |
+
audio_chunks = [resampled_audio_data[i:i+chunk_size] for i in range(0, len(resampled_audio_data), chunk_size)]
|
33 |
+
|
34 |
+
text = []
|
35 |
+
for chunk in audio_chunks:
|
36 |
+
inputs = self.processor(chunk, sampling_rate=16000, return_tensors="tf").input_features
|
37 |
+
predicted_ids = self.model.generate(inputs, max_new_tokens=500)
|
38 |
+
transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
39 |
+
text.append(transcription[0])
|
40 |
+
|
41 |
+
return " ".join(text)
|
42 |
+
|
43 |
+
def get_audio_from_video(self, video_buffer):
|
44 |
+
buffer = io.BytesIO(video_buffer)
|
45 |
+
video_file = AudioSegment.from_file(buffer)
|
46 |
+
audio = video_file.set_channels(1)
|
47 |
+
with io.BytesIO() as wav_buffer:
|
48 |
+
audio.export(wav_buffer, format="wav")
|
49 |
+
wav_bytes = wav_buffer.getvalue()
|
50 |
+
return wav_bytes
|
51 |
+
|
52 |
+
def get_wav_from_audio(self, audio_buffer):
|
53 |
+
buffer = io.BytesIO(audio_buffer)
|
54 |
+
audio_file = AudioSegment.from_mp3(buffer)
|
55 |
+
raw_data = audio_file.raw_data
|
56 |
+
with io.BytesIO() as wav_buffer:
|
57 |
+
with wave.open(wav_buffer, "wb") as wav_file:
|
58 |
+
wav_file.setnchannels(audio_file.channels)
|
59 |
+
wav_file.setsampwidth(audio_file.sample_width)
|
60 |
+
wav_file.setframerate(audio_file.frame_rate)
|
61 |
+
wav_file.writeframes(raw_data)
|
62 |
+
wav_bytes = wav_buffer.getvalue()
|
63 |
+
return wav_bytes
|
64 |
+
|
65 |
+
def process_audio(self, audio_bytes):
|
66 |
+
resampled_audio_data = self.load_wav_16k_mono(audio_bytes)
|
67 |
+
return self.get_text_from_audio(resampled_audio_data)
|
68 |
+
|
69 |
+
def process_video(self, buffer):
|
70 |
+
audio_bytes = self.get_audio_from_video(buffer)
|
71 |
+
return self.process_audio(audio_bytes)
|
72 |
+
|
summarize.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
from datetime import datetime
|
2 |
-
import multiprocessing
|
3 |
-
from transformers import BartTokenizer, TFBartForConditionalGeneration, pipeline
|
4 |
-
from Utils import fetch_article_text, count_tokens
|
5 |
-
import re
|
6 |
-
from nltk.tokenize import sent_tokenize
|
7 |
-
import nltk
|
8 |
-
import threading
|
9 |
-
|
10 |
-
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
|
11 |
-
model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
|
12 |
-
max_length = model.config.max_position_embeddings
|
13 |
-
|
14 |
-
summaries = []
|
15 |
-
|
16 |
-
def generate_summary(text: str):
|
17 |
-
encoded_input = tokenizer.encode(text, max_length=max_length, return_tensors='tf')
|
18 |
-
|
19 |
-
# generate summary for the input chunk
|
20 |
-
summary_ids = model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
|
21 |
-
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
22 |
-
|
23 |
-
# add the summary to the list of summaries
|
24 |
-
summaries.append(summary)
|
25 |
-
|
26 |
-
def bart_summarize(text: str):
|
27 |
-
|
28 |
-
try:
|
29 |
-
sentences = sent_tokenize(text)
|
30 |
-
except:
|
31 |
-
nltk.download('punkt')
|
32 |
-
sentences = sent_tokenize(text)
|
33 |
-
sentences = [sentence for sentence in sentences if len(sentence.strip()) > 0 and len(sentence.split(" ")) > 4]
|
34 |
-
|
35 |
-
input_chunks = []
|
36 |
-
temp_sentences = ""
|
37 |
-
tokens = 0
|
38 |
-
|
39 |
-
for sentence in sentences:
|
40 |
-
if tokens + count_tokens(sentence) < max_length:
|
41 |
-
temp_sentences += sentence
|
42 |
-
tokens += count_tokens(sentence)
|
43 |
-
else:
|
44 |
-
input_chunks.append(temp_sentences)
|
45 |
-
tokens = count_tokens(sentence)
|
46 |
-
temp_sentences = sentence
|
47 |
-
|
48 |
-
if len(temp_sentences) > 0:
|
49 |
-
input_chunks.append(temp_sentences)
|
50 |
-
|
51 |
-
# summarize each input chunk separately
|
52 |
-
results = []
|
53 |
-
print(datetime.now().strftime("%H:%M:%S"))
|
54 |
-
for chunk in input_chunks:
|
55 |
-
result_t = multiprocessing.Process(target=generate_summary, args=(chunk,))
|
56 |
-
results.append(result_t)
|
57 |
-
|
58 |
-
for result in results:
|
59 |
-
result.start()
|
60 |
-
|
61 |
-
for result in results:
|
62 |
-
result.join()
|
63 |
-
|
64 |
-
# # combine the summaries to get the final summary for the entire input
|
65 |
-
final_summary = " ".join(summaries)
|
66 |
-
|
67 |
-
print(datetime.now().strftime("%H:%M:%S"))
|
68 |
-
|
69 |
-
return final_summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
summarizer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from transformers import BartTokenizer, TFBartForConditionalGeneration
|
3 |
+
from Utils import get_input_chunks
|
4 |
+
import networkx as nx
|
5 |
+
from nltk.tokenize import sent_tokenize
|
6 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
+
import community
|
8 |
+
|
9 |
+
|
10 |
+
class BARTSummarizer:
|
11 |
+
|
12 |
+
def __init__(self, model_name: str = 'facebook/bart-large-cnn'):
|
13 |
+
self.model_name = model_name
|
14 |
+
self.tokenizer = BartTokenizer.from_pretrained(model_name)
|
15 |
+
self.model = TFBartForConditionalGeneration.from_pretrained(model_name)
|
16 |
+
self.max_length = self.model.config.max_position_embeddings
|
17 |
+
|
18 |
+
def summarize(self, text: str):
|
19 |
+
encoded_input = self.tokenizer.encode(text, max_length=self.max_length, return_tensors='tf', truncation=True)
|
20 |
+
summary_ids = self.model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
|
21 |
+
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
22 |
+
return summary
|
23 |
+
|
24 |
+
def chunk_summarize(self, text: str):
|
25 |
+
|
26 |
+
# split the input into chunks
|
27 |
+
summaries = []
|
28 |
+
input_chunks = get_input_chunks(text, self.max_length)
|
29 |
+
|
30 |
+
# summarize each input chunk separately
|
31 |
+
print(datetime.now().strftime("%H:%M:%S"))
|
32 |
+
for chunk in input_chunks:
|
33 |
+
summaries.append(self.summarize(chunk))
|
34 |
+
|
35 |
+
# # combine the summaries to get the final summary for the entire input
|
36 |
+
final_summary = " ".join(summaries)
|
37 |
+
|
38 |
+
print(datetime.now().strftime("%H:%M:%S"))
|
39 |
+
|
40 |
+
return final_summary
|
41 |
+
|
42 |
+
def preprocess_for_auto_chapters(self, text: str):
|
43 |
+
|
44 |
+
# Tokenize the text into sentences
|
45 |
+
sentences = sent_tokenize(text)
|
46 |
+
|
47 |
+
# Filter out empty sentences and sentences with less than 5 words
|
48 |
+
sentences = [sentence for sentence in sentences if len(sentence.strip()) > 0 and len(sentence.split(" ")) > 4]
|
49 |
+
|
50 |
+
# Combine every 5 sentences into a single sentence
|
51 |
+
sentences = [' '.join(sentences[i:i + 5]) for i in range(0, len(sentences), 5)]
|
52 |
+
|
53 |
+
return sentences
|
54 |
+
|
55 |
+
def auto_chapters_summarize(self, text: str):
|
56 |
+
|
57 |
+
sentences = self.preprocess_for_auto_chapters(text)
|
58 |
+
|
59 |
+
vectorizer = TfidfVectorizer(stop_words='english')
|
60 |
+
X = vectorizer.fit_transform(sentences)
|
61 |
+
|
62 |
+
# Compute the similarity matrix using cosine similarity
|
63 |
+
similarity_matrix = X * X.T
|
64 |
+
|
65 |
+
# Convert the similarity matrix to a graph
|
66 |
+
graph = nx.from_scipy_sparse_array(similarity_matrix)
|
67 |
+
|
68 |
+
# Apply the Louvain algorithm to identify communities
|
69 |
+
partition = community.best_partition(graph, resolution=0.7, random_state=42)
|
70 |
+
|
71 |
+
# Cluster the sentences
|
72 |
+
clustered_sentences = []
|
73 |
+
for cluster in set(partition.values()):
|
74 |
+
sentences_to_print = []
|
75 |
+
for i, sentence in enumerate(sentences):
|
76 |
+
if partition[i] == cluster:
|
77 |
+
sentences_to_print.append(sentence)
|
78 |
+
if len(sentences_to_print) > 1:
|
79 |
+
clustered_sentences.append(" ".join(sentences_to_print))
|
80 |
+
|
81 |
+
# Summarize each cluster
|
82 |
+
summaries = []
|
83 |
+
for cluster in clustered_sentences:
|
84 |
+
summaries.append(self.chunk_summarize(cluster))
|
85 |
+
|
86 |
+
# Combine the summaries to get the final summary for the entire input
|
87 |
+
final_summary = "\n\n".join(summaries)
|
88 |
+
|