KevlarVK committed on
Commit
3bb118d
1 Parent(s): 7a1d9d9

Media support, Code cleanup for summarization, Support for chunk and auto chapters summarize

Browse files
Files changed (5) hide show
  1. Utils.py +87 -21
  2. app.py +111 -9
  3. process_media.py +72 -0
  4. summarize.py +0 -69
  5. summarizer.py +88 -0
Utils.py CHANGED
@@ -1,7 +1,13 @@
1
  import requests
2
  from bs4 import BeautifulSoup
3
- import string
 
 
 
 
 
4
 
 
5
  def fetch_article_text(url: str):
6
 
7
  r = requests.get(url)
@@ -9,27 +15,87 @@ def fetch_article_text(url: str):
9
  results = soup.find_all(["h1", "p"])
10
  text = [result.text for result in results]
11
  ARTICLE = " ".join(text)
12
- ARTICLE = ARTICLE.replace(".", ".<eos>")
13
- ARTICLE = ARTICLE.replace("!", "!<eos>")
14
- ARTICLE = ARTICLE.replace("?", "?<eos>")
15
- sentences = ARTICLE.split("<eos>")
16
- current_chunk = 0
17
- chunks = []
18
- for sentence in sentences:
19
- if len(chunks) == current_chunk + 1:
20
- if len(chunks[current_chunk]) + len(sentence.split(" ")) <= 500:
21
- chunks[current_chunk].extend(sentence.split(" "))
22
- else:
23
- current_chunk += 1
24
- chunks.append(sentence.split(" "))
25
- else:
26
- print(current_chunk)
27
- chunks.append(sentence.split(" "))
28
-
29
- for chunk_id in range(len(chunks)):
30
- chunks[chunk_id] = " ".join(chunks[chunk_id])
31
 
32
- return ARTICLE, chunks
33
 
34
  def count_tokens(text: str):
35
  return len(text.split(" "))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  from bs4 import BeautifulSoup
3
+ from nltk.tokenize import sent_tokenize
4
+ import nltk
5
+ import re
6
+ import streamlit as st
7
+ from youtube_transcript_api import YouTubeTranscriptApi
8
+ import spacy
9
 
10
+ @st.cache_data
11
  def fetch_article_text(url: str):
12
 
13
  r = requests.get(url)
 
15
  results = soup.find_all(["h1", "p"])
16
  text = [result.text for result in results]
17
  ARTICLE = " ".join(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ return ARTICLE
20
 
21
def count_tokens(text: str):
    """Return a rough token count for ``text``.

    The count is the number of pieces produced by splitting on a single
    space character, so consecutive spaces yield empty pieces that are
    counted too, and an empty string counts as one piece.
    """
    # Splitting on one separator always yields (occurrences + 1) pieces.
    return text.count(" ") + 1
23
+
24
@st.cache_data
def get_text_from_youtube_url(url: str):
    """Return the English transcript of a YouTube video as punctuated text.

    Args:
        url: A YouTube link. Supported shapes are ``watch?v=<id>`` (with or
            without further query parameters), ``youtu.be/<id>`` and
            ``/shorts/<id>``, ``/embed/<id>``, ``/live/<id>`` paths.

    Returns:
        The transcript lower-cased, with ``[Music]`` markers dropped, passed
        through ``add_punctuation``.

    Raises:
        ValueError: If no video id can be found in ``url``.
        Exception: Whatever ``youtube_transcript_api`` raises when neither
            transcript lookup succeeds (e.g. transcripts disabled).
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(url.strip())
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]
    query_ids = parse_qs(parsed.query).get("v")

    if query_ids:
        # Standard watch URL; parse_qs ignores trailing "&t=..." and friends,
        # which the old url.split("=")[1] glued onto the id.
        video_id = query_ids[0]
    elif host.endswith("youtu.be") and path_parts:
        video_id = path_parts[0]
    elif len(path_parts) >= 2 and path_parts[0] in ("shorts", "embed", "live"):
        video_id = path_parts[1]
    else:
        raise ValueError(f"Could not find a video id in URL: {url!r}")

    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
    except Exception:
        # The default lookup only asks for "en"; retry through the transcript
        # list so regional English variants are accepted as well.
        transcript = (
            YouTubeTranscriptApi.list_transcripts(video_id)
            .find_transcript(["en", "en-US", "en-GB"])
            .fetch()
        )

    # Each kept caption is followed by a single space, as before.
    script = "".join(
        entry["text"].lower() + " "
        for entry in transcript
        if entry["text"] != '[Music]'
    )

    return add_punctuation(script)
40
+
41
def add_punctuation(text: str):
    """Re-punctuate and capitalise unpunctuated text, one spaCy sentence at a time.

    The text is split into sentences by the ``en_core_web_sm`` pipeline. For
    each sentence a trailing punctuation token is dropped, then a terminator
    is appended: ``?`` when the last remaining token is tagged ``VERB``,
    otherwise ``.``. The first character of each sentence is upper-cased.

    Args:
        text: Raw text, typically a lower-cased transcript.

    Returns:
        The rebuilt sentences joined by single spaces (``""`` for input that
        yields no sentences).

    Note:
        The ``en_core_web_sm`` model must already be installed; ``spacy.load``
        fails otherwise. The model is loaded on every call.
    """
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(text)
    punctuation = {".", ",", ";", ":", "?", "!"}

    sentences = []
    for sentence in doc.sents:
        # Drop an existing terminator so a fresh one can be chosen below.
        if sentence[-1].text in punctuation:
            sentence = sentence[:-1]

        # A sentence that was only a punctuation token is now empty;
        # indexing it would raise IndexError, so skip it.
        if len(sentence) == 0:
            continue

        # Heuristic kept from the original: a trailing verb is treated as a
        # question, everything else as a statement.
        terminator = "?" if sentence[-1].pos_ == "VERB" else "."
        sentence_text = sentence.text + terminator

        sentences.append(sentence_text[0].upper() + sentence_text[1:])

    return " ".join(sentences)
74
+
75
+
76
def get_input_chunks(text: str, max_length: int = 500):
    """Split ``text`` into chunks of whole sentences under a token budget.

    Sentences come from NLTK's ``sent_tokenize``. Blank sentences and
    sentences of four or fewer space-separated tokens are discarded, and
    bracketed numeric markers such as ``[12]`` are replaced by a space.
    Sentences are then packed greedily: a chunk is closed as soon as adding
    the next sentence would reach ``max_length`` tokens (as measured by
    ``count_tokens``).

    Args:
        text: The text to split.
        max_length: Token budget per chunk. A single sentence that alone
            reaches the budget still becomes its own chunk.

    Returns:
        A list of non-empty chunk strings; sentences inside a chunk are
        separated by one space. Empty list when no sentence survives filtering.
    """
    try:
        sentences = sent_tokenize(text)
    except LookupError:
        # NLTK raises LookupError when the punkt tokenizer data is missing;
        # download it once and retry.
        nltk.download('punkt')
        sentences = sent_tokenize(text)

    sentences = [
        re.sub(r'\[[0-9]*\]', ' ', sentence)
        for sentence in sentences
        if len(sentence.strip()) > 0 and count_tokens(sentence) > 4
    ]

    input_chunks = []
    current = []  # sentences of the chunk being built
    tokens = 0    # token count of `current`

    for sentence in sentences:
        sentence_tokens = count_tokens(sentence)
        # Close the running chunk when this sentence would not fit. The
        # `current` check stops an oversized first sentence from emitting
        # an empty chunk.
        if current and tokens + sentence_tokens >= max_length:
            input_chunks.append(" ".join(current))
            current = []
            tokens = 0
        current.append(sentence)
        tokens += sentence_tokens

    if current:
        input_chunks.append(" ".join(current))

    return input_chunks
app.py CHANGED
@@ -1,13 +1,115 @@
 
 
 
 
1
  import streamlit as st
2
- from summarize import bart_summarize
 
 
3
 
4
- # Create a text field
5
- text = st.text_input("Enter text here")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # Create a button
8
- button = st.button("Click here")
 
9
 
10
- # get text from text field and print it
11
- if button:
12
- summary = bart_summarize(text)
13
- st.write(summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import time
3
+ import wave
4
+ from process_media import MediaProcessor
5
  import streamlit as st
6
+ from summarizer import BARTSummarizer
7
+ from pydub import AudioSegment
8
+ from Utils import fetch_article_text, get_text_from_youtube_url
9
 
10
+ st.markdown(
11
+ """
12
+ <style>
13
+ section[data-testid="stSidebar"] div[role="radiogroup"] label {
14
+ padding: 0px 0px 20px 20px;
15
+ }
16
+ section[data-testid="stSidebar"] h2 {
17
+ margin: 10px;
18
+ }
19
+ section.main div[role="radiogroup"] label {
20
+ padding: 10px 10px 10px 0px;
21
+ }
22
+ </style>
23
+ """,
24
+ unsafe_allow_html=True,
25
+ )
26
 
27
+ with st.sidebar:
28
+ st.header("CHOOSE INPUT TYPE")
29
+ input_type = st.radio("", ["Text", "Media"], label_visibility = "hidden")
30
 
31
+
32
# Gather the text to summarise from whichever input the user picked.
# `text_to_summarize` stays None until an input has produced usable text;
# the summary buttons further down are only shown when it is set.
import html  # escapes transcript text before it is embedded in raw HTML below

text_to_summarize = None

if input_type == "Text":

    st.header("Summarize from text or URL")

    text_type = st.radio("", ["Raw Text", "URL", "Document"], key="text_type", horizontal=True, label_visibility="hidden")

    if text_type == "Raw Text":
        text = st.text_area("Enter raw text here", height=240, max_chars=10000, placeholder="Enter a paragraph to summarize")
        if text:
            text_to_summarize = text

    elif text_type == "URL":
        url = st.text_input("Enter URL here", placeholder="Enter URL to an article, blog post, etc.")
        if url:
            article_text = fetch_article_text(url)
            if article_text:
                st.markdown("#### Text from url:")
                st.write(article_text)
                text_to_summarize = article_text
    else:
        # TODO: Add file upload option
        pass

elif input_type == "Media":

    st.header("Summarize from file or YouTube URL")

    media_type = st.radio("", ["Audio file", "Video file", "Youtube video link"], key="media_type", horizontal=True, label_visibility="hidden")

    if media_type == "Audio file":
        audio_file = st.file_uploader("Upload an audio file", type=["mp3", "wav"], label_visibility="visible")
        if audio_file is not None:
            with st.spinner("Fetching text from audio..."):
                media_processor = MediaProcessor()
                # MP3 uploads are converted to WAV bytes first; WAV uploads
                # are passed through unchanged.
                if audio_file.type == "audio/mpeg":
                    wav_bytes = media_processor.get_wav_from_audio(audio_file.read())
                else:
                    wav_bytes = audio_file.read()
                text = media_processor.process_audio(wav_bytes)
                st.markdown("#### Text from audio:")
                st.write(text)
                # Hand the transcript to the summary step; previously it was
                # displayed but never made available for summarising.
                if text:
                    text_to_summarize = text
    elif media_type == "Video file":
        video_file = st.file_uploader("Upload a video file", type=["mp4"], label_visibility="visible")
        if video_file is not None:
            with st.spinner("Fetching text from video..."):
                media_processor = MediaProcessor()
                text = media_processor.process_video(video_file.read())
                st.markdown("#### Text from video:")
                st.write(text)
                # Same as the audio branch: expose the transcript downstream.
                if text:
                    text_to_summarize = text
    else:
        youtube_url = st.text_input("Enter YouTube URL here", placeholder="Enter URL to an YouTube video", label_visibility="visible")
        if youtube_url:
            with st.spinner("Fetching text from video..."):
                try:
                    text_to_summarize = get_text_from_youtube_url(youtube_url)
                    st.markdown("#### Text from video:")
                    # The transcript is third-party text, so escape it before
                    # placing it inside markup rendered with unsafe_allow_html.
                    st.markdown('<div style="height: 300px; overflow: auto; margin-bottom: 20px;">' + html.escape(text_to_summarize) + '</div>', unsafe_allow_html=True)
                except Exception:
                    # Any failure (bad URL, no transcript, network error) is
                    # reported to the user instead of crashing the app.
                    st.error("Unable to fetch text from video. Please try a different video.")
                    text_to_summarize = None
96
+
97
# Offer the two summary modes once some input text is available.
if text_to_summarize is not None:
    overall_summary = st.button("Overall summary")
    auto_chapters_summary = st.button("Auto Chapters summary")

    if overall_summary or auto_chapters_summary:
        with st.spinner("Summarizing..."):
            summarizer = BARTSummarizer()
            # The overall button takes precedence if both report a click.
            if overall_summary:
                summary = summarizer.chunk_summarize(text_to_summarize)
            else:
                summary = summarizer.auto_chapters_summarize(text_to_summarize)
            st.markdown("#### Summary:")
            st.write(summary)
process_media.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import wave
3
+ import tensorflow as tf
4
+ import tensorflow_io as tfio
5
+ import moviepy.editor as mp
6
+ import numpy as np
7
+ from pydub import AudioSegment
8
+ from transformers import AutoProcessor, TFWhisperForConditionalGeneration
9
+ from moviepy.video.io.VideoFileClip import VideoFileClip
10
+
11
+ # tf.config.run_functions_eagerly(True)
12
+
13
class MediaProcessor:
    """Turn audio or video bytes into text using the ``openai/whisper-tiny.en`` model."""

    def __init__(self):
        """Load the Whisper processor and TensorFlow model."""
        checkpoint = "openai/whisper-tiny.en"
        self.processor = AutoProcessor.from_pretrained(checkpoint)
        self.model = TFWhisperForConditionalGeneration.from_pretrained(checkpoint)

    def load_wav_16k_mono(self, file_bytes):
        """Decode WAV bytes as a single channel, resample to 16 kHz, return a NumPy array."""
        decoded, source_rate = tf.audio.decode_wav(file_bytes, desired_channels=1)
        # decode_wav was asked for one channel; drop that trailing axis.
        mono = tf.squeeze(decoded, axis=-1)
        resampled = tfio.audio.resample(
            mono,
            rate_in=tf.cast(source_rate, dtype=tf.int64),
            rate_out=16000,
        )
        return resampled.numpy()

    def get_text_from_audio(self, resampled_audio_data):
        """Transcribe 16 kHz audio samples in 30-second windows and join the results with spaces."""
        window = 30 * 16000  # samples per 30-second window at 16 kHz
        pieces = []
        for start in range(0, len(resampled_audio_data), window):
            segment = resampled_audio_data[start:start + window]
            features = self.processor(segment, sampling_rate=16000, return_tensors="tf").input_features
            token_ids = self.model.generate(features, max_new_tokens=500)
            decoded = self.processor.batch_decode(token_ids, skip_special_tokens=True)
            pieces.append(decoded[0])
        return " ".join(pieces)

    def get_audio_from_video(self, video_buffer):
        """Extract the audio of a video given as bytes and return it as mono WAV bytes."""
        mono_track = AudioSegment.from_file(io.BytesIO(video_buffer)).set_channels(1)
        with io.BytesIO() as sink:
            mono_track.export(sink, format="wav")
            return sink.getvalue()

    def get_wav_from_audio(self, audio_buffer):
        """Convert MP3 bytes to WAV bytes, keeping channels, sample width and frame rate."""
        segment = AudioSegment.from_mp3(io.BytesIO(audio_buffer))
        with io.BytesIO() as sink:
            with wave.open(sink, "wb") as writer:
                writer.setnchannels(segment.channels)
                writer.setsampwidth(segment.sample_width)
                writer.setframerate(segment.frame_rate)
                writer.writeframes(segment.raw_data)
            # Read the buffer before the outer `with` closes it.
            return sink.getvalue()

    def process_audio(self, audio_bytes):
        """Transcribe WAV bytes: resample to 16 kHz mono, then run speech-to-text."""
        return self.get_text_from_audio(self.load_wav_16k_mono(audio_bytes))

    def process_video(self, buffer):
        """Transcribe video bytes by extracting the audio track and processing it."""
        return self.process_audio(self.get_audio_from_video(buffer))
72
+
summarize.py DELETED
@@ -1,69 +0,0 @@
1
- from datetime import datetime
2
- import multiprocessing
3
- from transformers import BartTokenizer, TFBartForConditionalGeneration, pipeline
4
- from Utils import fetch_article_text, count_tokens
5
- import re
6
- from nltk.tokenize import sent_tokenize
7
- import nltk
8
- import threading
9
-
10
- tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
11
- model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
12
- max_length = model.config.max_position_embeddings
13
-
14
- summaries = []
15
-
16
- def generate_summary(text: str):
17
- encoded_input = tokenizer.encode(text, max_length=max_length, return_tensors='tf')
18
-
19
- # generate summary for the input chunk
20
- summary_ids = model.generate(encoded_input, max_length=300, num_beams=4, early_stopping=True)
21
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
22
-
23
- # add the summary to the list of summaries
24
- summaries.append(summary)
25
-
26
- def bart_summarize(text: str):
27
-
28
- try:
29
- sentences = sent_tokenize(text)
30
- except:
31
- nltk.download('punkt')
32
- sentences = sent_tokenize(text)
33
- sentences = [sentence for sentence in sentences if len(sentence.strip()) > 0 and len(sentence.split(" ")) > 4]
34
-
35
- input_chunks = []
36
- temp_sentences = ""
37
- tokens = 0
38
-
39
- for sentence in sentences:
40
- if tokens + count_tokens(sentence) < max_length:
41
- temp_sentences += sentence
42
- tokens += count_tokens(sentence)
43
- else:
44
- input_chunks.append(temp_sentences)
45
- tokens = count_tokens(sentence)
46
- temp_sentences = sentence
47
-
48
- if len(temp_sentences) > 0:
49
- input_chunks.append(temp_sentences)
50
-
51
- # summarize each input chunk separately
52
- results = []
53
- print(datetime.now().strftime("%H:%M:%S"))
54
- for chunk in input_chunks:
55
- result_t = multiprocessing.Process(target=generate_summary, args=(chunk,))
56
- results.append(result_t)
57
-
58
- for result in results:
59
- result.start()
60
-
61
- for result in results:
62
- result.join()
63
-
64
- # # combine the summaries to get the final summary for the entire input
65
- final_summary = " ".join(summaries)
66
-
67
- print(datetime.now().strftime("%H:%M:%S"))
68
-
69
- return final_summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
summarizer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from transformers import BartTokenizer, TFBartForConditionalGeneration
3
+ from Utils import get_input_chunks
4
+ import networkx as nx
5
+ from nltk.tokenize import sent_tokenize
6
+ from sklearn.feature_extraction.text import TfidfVectorizer
7
+ import community
8
+
9
+
10
+ class BARTSummarizer:
11
+
12
def __init__(self, model_name: str = 'facebook/bart-large-cnn'):
    """Load the BART tokenizer and TensorFlow model named by ``model_name``.

    ``max_length`` is taken from the model config's
    ``max_position_embeddings`` and used as the input size limit elsewhere.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = TFBartForConditionalGeneration.from_pretrained(model_name)

    self.model_name = model_name
    self.tokenizer = tokenizer
    self.model = model
    self.max_length = model.config.max_position_embeddings
17
+
18
def summarize(self, text: str):
    """Summarise one piece of text with the loaded BART model.

    The input is encoded with truncation to ``self.max_length`` tokens, a
    summary of at most 300 tokens is generated with 4-beam search, and the
    first generated sequence is decoded without special tokens.
    """
    input_ids = self.tokenizer.encode(
        text,
        truncation=True,
        max_length=self.max_length,
        return_tensors='tf',
    )
    generated = self.model.generate(
        input_ids,
        max_length=300,
        num_beams=4,
        early_stopping=True,
    )
    return self.tokenizer.decode(generated[0], skip_special_tokens=True)
23
+
24
def chunk_summarize(self, text: str):
    """Summarise long text chunk by chunk and join the partial summaries.

    The text is split with ``get_input_chunks`` using ``self.max_length`` as
    the budget, each chunk is summarised independently, and the results are
    joined with single spaces. The wall-clock time is printed before and
    after the summarisation loop.
    """
    clock_format = "%H:%M:%S"

    # NOTE(review): the budget passed here is the model's position limit,
    # while get_input_chunks appears to count space-separated words; chunks
    # may therefore exceed the model's token limit and be truncated inside
    # summarize() - confirm.
    input_chunks = get_input_chunks(text, self.max_length)

    print(datetime.now().strftime(clock_format))
    partial_summaries = [self.summarize(chunk) for chunk in input_chunks]
    final_summary = " ".join(partial_summaries)
    print(datetime.now().strftime(clock_format))

    return final_summary
41
+
42
def preprocess_for_auto_chapters(self, text: str):
    """Split ``text`` into passages of up to five sentences for clustering.

    Sentences come from NLTK's ``sent_tokenize``. Blank sentences and
    sentences of four or fewer space-separated words are dropped, and the
    remaining ones are joined with spaces in consecutive groups of five.

    Args:
        text: The text to split.

    Returns:
        A list of passage strings; the last one may hold fewer than five
        sentences. Empty list when no sentence survives filtering.
    """
    # Function-scope import: only sent_tokenize is imported at file level here.
    import nltk

    try:
        sentences = sent_tokenize(text)
    except LookupError:
        # Same recovery as get_input_chunks: NLTK raises LookupError when the
        # punkt data is missing, which would otherwise crash this path when it
        # runs before any chunk summarisation has fetched the data.
        nltk.download('punkt')
        sentences = sent_tokenize(text)

    kept = [
        sentence
        for sentence in sentences
        if sentence.strip() and len(sentence.split(" ")) > 4
    ]

    group_size = 5
    return [
        ' '.join(kept[start:start + group_size])
        for start in range(0, len(kept), group_size)
    ]
54
+
55
+ def auto_chapters_summarize(self, text: str):
56
+
57
+ sentences = self.preprocess_for_auto_chapters(text)
58
+
59
+ vectorizer = TfidfVectorizer(stop_words='english')
60
+ X = vectorizer.fit_transform(sentences)
61
+
62
+ # Compute the similarity matrix using cosine similarity
63
+ similarity_matrix = X * X.T
64
+
65
+ # Convert the similarity matrix to a graph
66
+ graph = nx.from_scipy_sparse_array(similarity_matrix)
67
+
68
+ # Apply the Louvain algorithm to identify communities
69
+ partition = community.best_partition(graph, resolution=0.7, random_state=42)
70
+
71
+ # Cluster the sentences
72
+ clustered_sentences = []
73
+ for cluster in set(partition.values()):
74
+ sentences_to_print = []
75
+ for i, sentence in enumerate(sentences):
76
+ if partition[i] == cluster:
77
+ sentences_to_print.append(sentence)
78
+ if len(sentences_to_print) > 1:
79
+ clustered_sentences.append(" ".join(sentences_to_print))
80
+
81
+ # Summarize each cluster
82
+ summaries = []
83
+ for cluster in clustered_sentences:
84
+ summaries.append(self.chunk_summarize(cluster))
85
+
86
+ # Combine the summaries to get the final summary for the entire input
87
+ final_summary = "\n\n".join(summaries)
88
+