Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,322 +1,322 @@
|
|
1 |
-
import itertools
|
2 |
-
import json
|
3 |
-
import re
|
4 |
-
from functools import partial
|
5 |
-
from pathlib import Path
|
6 |
-
|
7 |
-
import pandas as pd
|
8 |
-
import requests
|
9 |
-
import streamlit as st
|
10 |
-
import webvtt
|
11 |
-
from transformers import AutoTokenizer
|
12 |
-
|
13 |
-
from generate_text_api import TextGenerator
|
14 |
-
from model_inferences.utils.chunking import Truncater
|
15 |
-
from model_inferences.utils.files import get_captions_from_vtt, get_transcript
|
16 |
-
|
17 |
-
USE_PARAGRAPHING_MODEL = True
|
18 |
-
|
19 |
-
def get_sublist_by_flattened_index(A, i):
|
20 |
-
current_index = 0
|
21 |
-
for sublist in A:
|
22 |
-
sublist_length = len(sublist)
|
23 |
-
if current_index <= i < current_index + sublist_length:
|
24 |
-
return sublist, A.index(sublist)
|
25 |
-
current_index += sublist_length
|
26 |
-
return None, None
|
27 |
-
|
28 |
-
import requests
|
29 |
-
|
30 |
-
|
31 |
-
def get_talk_metadata(video_id):
|
32 |
-
url = "https://www.ted.com/graphql"
|
33 |
-
|
34 |
-
headers = {
|
35 |
-
"Content-Type": "application/json",
|
36 |
-
"Accept": "application/json",
|
37 |
-
"x-operation-name": "Transcript", # Replace with the actual operation name
|
38 |
-
}
|
39 |
-
|
40 |
-
data = {
|
41 |
-
"query": """
|
42 |
-
query GetTalk($videoId: ID!) {
|
43 |
-
video(id: $videoId) {
|
44 |
-
title,
|
45 |
-
presenterDisplayName,
|
46 |
-
nativeDownloads {medium}
|
47 |
-
}
|
48 |
-
}
|
49 |
-
""",
|
50 |
-
"variables": {
|
51 |
-
"videoId": video_id, # Corrected key to "videoId"
|
52 |
-
},
|
53 |
-
}
|
54 |
-
|
55 |
-
response = requests.post(url, json=data, headers=headers)
|
56 |
-
|
57 |
-
if response.status_code == 200:
|
58 |
-
result = response.json()
|
59 |
-
return result
|
60 |
-
else:
|
61 |
-
print(f"Error: {response.status_code}, {response.text}")
|
62 |
-
|
63 |
-
class OfflineTextSegmenterClient:
|
64 |
-
def __init__(self, host_url):
|
65 |
-
self.host_url = host_url.rstrip("/") + "/segment"
|
66 |
-
|
67 |
-
def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
|
68 |
-
payload = {
|
69 |
-
'text': text,
|
70 |
-
'captions': captions,
|
71 |
-
'generate_titles': generate_titles,
|
72 |
-
"prefix_titles": True,
|
73 |
-
"threshold": threshold,
|
74 |
-
}
|
75 |
-
|
76 |
-
headers = {
|
77 |
-
'Content-Type': 'application/json'
|
78 |
-
}
|
79 |
-
|
80 |
-
response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
|
81 |
-
#segments = response["annotated_segments"] if "annotated_segments" in response else response["segments"]
|
82 |
-
return {'segments':response["segments"], 'titles': response["titles"], 'sentences': response["sentences"]}
|
83 |
-
|
84 |
-
class Toc:
|
85 |
-
|
86 |
-
def __init__(self):
|
87 |
-
self._items = []
|
88 |
-
self._placeholder = None
|
89 |
-
|
90 |
-
def title(self, text):
|
91 |
-
self._markdown(text, "h1")
|
92 |
-
|
93 |
-
def header(self, text):
|
94 |
-
self._markdown(text, "h2", " " * 2)
|
95 |
-
|
96 |
-
def subheader(self, text):
|
97 |
-
self._markdown(text, "h3", " " * 4)
|
98 |
-
|
99 |
-
def placeholder(self, sidebar=False):
|
100 |
-
self._placeholder = st.sidebar.empty() if sidebar else st.empty()
|
101 |
-
|
102 |
-
def generate(self):
|
103 |
-
if self._placeholder:
|
104 |
-
self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)
|
105 |
-
|
106 |
-
def _markdown(self, text, level, space=""):
|
107 |
-
key = re.sub(r'[^\w-]', '', text.replace(" ", "-").replace("'", "-").lower())
|
108 |
-
st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
|
109 |
-
self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
|
110 |
-
|
111 |
-
# custom_css = "<style type='text/css'>" + Path('style.css').read_text() + "</style>"
|
112 |
-
# st.write(custom_css, unsafe_allow_html=True)
|
113 |
-
|
114 |
-
def concat_prompt(prompt_text, text, model_name):
|
115 |
-
if 'flan' in model_name:
|
116 |
-
input_ = prompt_text + "\n\n" + text
|
117 |
-
elif 'galactica' in model_name:
|
118 |
-
input_ = text + "\n\n" + prompt_text
|
119 |
-
return input_
|
120 |
-
|
121 |
-
endpoint = "http://hiaisc.isl.iar.kit.edu/summarize"
|
122 |
-
ENDPOINTS = {"http://hiaisc.isl.iar.kit.edu/summarize": "meta-llama/Llama-2-13b-chat-hf",}
|
123 |
-
|
124 |
-
client = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/
|
125 |
-
if USE_PARAGRAPHING_MODEL:
|
126 |
-
paragrapher = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/paragraph")
|
127 |
-
summarizer = TextGenerator(endpoint)
|
128 |
-
|
129 |
-
tokenizer = AutoTokenizer.from_pretrained(ENDPOINTS[endpoint], use_fast=False)
|
130 |
-
|
131 |
-
# TLDR PROMPT
|
132 |
-
|
133 |
-
SYSTEM_PROMPT = "You are an assistant who replies with a summary to every message."
|
134 |
-
|
135 |
-
TLDR_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
|
136 |
-
{system_prompt}
|
137 |
-
<</SYS>>
|
138 |
-
|
139 |
-
{user_message} [/INST] Sure! Here is a summary of the research presentation in a single, short sentence:"""
|
140 |
-
|
141 |
-
TLDR_USER_PROMPT = "Summarize the following research presentation in a single, short sentence:\n\n{input}"
|
142 |
-
|
143 |
-
TLDR_PROMPT = TLDR_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
|
144 |
-
TLDR_PROMPT_LENGTH = tokenizer(TLDR_PROMPT, return_tensors="pt")["input_ids"].size(1)
|
145 |
-
|
146 |
-
BP_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
|
147 |
-
{system_prompt}
|
148 |
-
<</SYS>>
|
149 |
-
|
150 |
-
{user_message} [/INST] Sure! Here is a summary of the research presentation using three bullet points:\n\n\u2022"""
|
151 |
-
|
152 |
-
BP_USER_PROMPT = "Summarize the following research presentation using three bullet points:\n\n{input}"
|
153 |
-
|
154 |
-
BP_PROMPT = BP_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
|
155 |
-
BP_PROMPT_LENGTH = tokenizer(BP_PROMPT, return_tensors="pt")["input_ids"].size(1)
|
156 |
-
|
157 |
-
CONTEXT_LENGTH = 3072
|
158 |
-
MAX_SUMMARY_LENGTH = 1024
|
159 |
-
TLDR_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - TLDR_PROMPT_LENGTH - 1
|
160 |
-
BP_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - BP_PROMPT_LENGTH - 1
|
161 |
-
|
162 |
-
|
163 |
-
text_generator = TextGenerator(endpoint)
|
164 |
-
temperature = 0.7
|
165 |
-
|
166 |
-
import re
|
167 |
-
|
168 |
-
|
169 |
-
def replace_newlines(text):
|
170 |
-
updated_text = re.sub(r'\n+', r'\n\n', text)
|
171 |
-
return updated_text
|
172 |
-
|
173 |
-
def generate_summary(summarizer, generated_text_box, input_, prompt, max_input_length, prefix=""):
|
174 |
-
all_generated_text = prefix
|
175 |
-
truncater = Truncater(tokenizer, max_length=max_input_length)
|
176 |
-
input_ = truncater(input_)
|
177 |
-
input_ = prompt.format(input=input_)
|
178 |
-
for generated_text in summarizer.generate_text_stream(input_, max_new_tokens=MAX_SUMMARY_LENGTH, do_sample=True, temperature=temperature):
|
179 |
-
all_generated_text += replace_newlines(generated_text)
|
180 |
-
generated_text_box.info(all_generated_text)
|
181 |
-
print(all_generated_text)
|
182 |
-
return all_generated_text.strip()
|
183 |
-
|
184 |
-
st.header("Demo: Intelligent Recap")
|
185 |
-
|
186 |
-
if not hasattr(st, 'global_state'):
|
187 |
-
st.global_state = {'NIPS 2021 Talks': None, 'TED Talks': None}
|
188 |
-
# NIPS 2021 Talks
|
189 |
-
transcript_files = itertools.islice(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"), 15)
|
190 |
-
# get titles from metadata.json
|
191 |
-
transcripts_map = {}
|
192 |
-
for transcript_file in transcript_files:
|
193 |
-
base_path = transcript_file.parent
|
194 |
-
metadata = base_path / "metadata.json"
|
195 |
-
txt_file = base_path / "transcript_whisper_large-v2.txt"
|
196 |
-
with open(metadata) as f:
|
197 |
-
metadata = json.load(f)
|
198 |
-
title = metadata["title"]
|
199 |
-
transcript = get_transcript(txt_file)
|
200 |
-
captions = get_captions_from_vtt(transcript_file)
|
201 |
-
transcripts_map[title] = {"transcript": transcript, "captions": captions, "video": base_path / "video.mp4"}
|
202 |
-
st.global_state['NIPS 2021 Talks'] = transcripts_map
|
203 |
-
|
204 |
-
data = pd.read_json("demo_data/ted_talks.json")
|
205 |
-
video_ids = data.talk_id.tolist()
|
206 |
-
transcripts = data.text.apply(lambda x: " ".join(x)).tolist()
|
207 |
-
transcripts_map = {}
|
208 |
-
for video_id, transcript in zip(video_ids, transcripts):
|
209 |
-
metadata = get_talk_metadata(video_id)
|
210 |
-
title = metadata["data"]["video"]["title"]
|
211 |
-
presenter = metadata["data"]["video"]["presenterDisplayName"]
|
212 |
-
print(metadata["data"])
|
213 |
-
if metadata["data"]["video"]["nativeDownloads"] is None:
|
214 |
-
continue
|
215 |
-
video_url = metadata["data"]["video"]["nativeDownloads"]["medium"]
|
216 |
-
transcripts_map[title] = {"transcript": transcript, "video": video_url, "presenter": presenter}
|
217 |
-
st.global_state['TED Talks'] = transcripts_map
|
218 |
-
|
219 |
-
def get_lecture_id(path):
|
220 |
-
return int(path.parts[-2].split('-')[1])
|
221 |
-
|
222 |
-
transcript_files = Path("demo_data/lectures/").rglob("English.vtt")
|
223 |
-
sorted_path_list = sorted(transcript_files, key=get_lecture_id)
|
224 |
-
|
225 |
-
transcripts_map = {}
|
226 |
-
for transcript_file in sorted_path_list:
|
227 |
-
base_path = transcript_file.parent
|
228 |
-
lecture_id = base_path.parts[-1]
|
229 |
-
transcript = " ".join([c["text"].strip() for c in get_captions_from_vtt(transcript_file)]).replace("\n", " ")
|
230 |
-
video_path = Path(base_path, "video.mp4")
|
231 |
-
transcripts_map["Machine Translation: " + lecture_id] = {"transcript": transcript, "video": video_path}
|
232 |
-
st.global_state['KIT Lectures'] = transcripts_map
|
233 |
-
|
234 |
-
type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))
|
235 |
-
|
236 |
-
transcripts_map = st.global_state[type_of_document]
|
237 |
-
|
238 |
-
selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))
|
239 |
-
|
240 |
-
st.video(str(transcripts_map[selected_talk]['video']), format="video/mp4", start_time=0)
|
241 |
-
|
242 |
-
input_text = st.text_area("Transcript", value=transcripts_map[selected_talk]['transcript'], height=300)
|
243 |
-
|
244 |
-
toc = Toc()
|
245 |
-
|
246 |
-
summarization_todos = []
|
247 |
-
|
248 |
-
with st.expander("Adjust Thresholds"):
|
249 |
-
threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.4, step=0.05)
|
250 |
-
paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)
|
251 |
-
|
252 |
-
if st.button("Process Transcript"):
|
253 |
-
with st.sidebar:
|
254 |
-
st.header("Table of Contents")
|
255 |
-
toc.placeholder()
|
256 |
-
|
257 |
-
st.header(selected_talk, divider='rainbow')
|
258 |
-
# if 'presenter' in transcripts_map[selected_talk]:
|
259 |
-
# st.markdown(f"### *by **{transcripts_map[selected_talk]['presenter']}***")
|
260 |
-
|
261 |
-
captions = transcripts_map[selected_talk]['captions'] if 'captions' in transcripts_map[selected_talk] else None
|
262 |
-
result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
|
263 |
-
if USE_PARAGRAPHING_MODEL:
|
264 |
-
presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
|
265 |
-
paragraphs = presult['segments']
|
266 |
-
segments, titles, sentences = result['segments'], result['titles'], result['sentences']
|
267 |
-
|
268 |
-
if USE_PARAGRAPHING_MODEL:
|
269 |
-
prev_chapter_idx = 0
|
270 |
-
prev_paragraph_idx = 0
|
271 |
-
segment = []
|
272 |
-
for i, sentence in enumerate(sentences):
|
273 |
-
chapter, chapter_idx = get_sublist_by_flattened_index(segments, i)
|
274 |
-
paragraph, paragraph_idx = get_sublist_by_flattened_index(paragraphs, i)
|
275 |
-
|
276 |
-
if (chapter_idx != prev_chapter_idx and paragraph_idx == prev_paragraph_idx) or (paragraph_idx != prev_paragraph_idx and chapter_idx != prev_chapter_idx):
|
277 |
-
print("Chapter / Chapter & Paragraph")
|
278 |
-
segment_text = " ".join(segment)
|
279 |
-
toc.subheader(titles[prev_chapter_idx])
|
280 |
-
if len(segment_text) > 1200:
|
281 |
-
generated_text_box = st.info("")
|
282 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
|
283 |
-
elif len(segment_text) > 450:
|
284 |
-
generated_text_box = st.info("")
|
285 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
|
286 |
-
st.write(segment_text)
|
287 |
-
segment = []
|
288 |
-
elif paragraph_idx != prev_paragraph_idx and chapter_idx == prev_chapter_idx:
|
289 |
-
print("Paragraph")
|
290 |
-
segment.append("\n\n")
|
291 |
-
|
292 |
-
segment.append(sentence)
|
293 |
-
|
294 |
-
prev_chapter_idx = chapter_idx
|
295 |
-
prev_paragraph_idx = paragraph_idx
|
296 |
-
|
297 |
-
segment_text = " ".join(segment)
|
298 |
-
toc.subheader(titles[prev_chapter_idx])
|
299 |
-
if len(segment_text) > 1200:
|
300 |
-
generated_text_box = st.info("")
|
301 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
|
302 |
-
elif len(segment_text) > 450:
|
303 |
-
generated_text_box = st.info("")
|
304 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
|
305 |
-
st.write(segment_text)
|
306 |
-
|
307 |
-
|
308 |
-
else:
|
309 |
-
segments = [" ".join([sentence for sentence in segment]) for segment in segments]
|
310 |
-
for title, segment in zip(titles, segments):
|
311 |
-
toc.subheader(title)
|
312 |
-
if len(segment) > 1200:
|
313 |
-
generated_text_box = st.info("")
|
314 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
|
315 |
-
elif len(segment) > 450:
|
316 |
-
generated_text_box = st.info("")
|
317 |
-
summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
|
318 |
-
st.write(segment)
|
319 |
-
toc.generate()
|
320 |
-
|
321 |
-
for summarization_todo in summarization_todos:
|
322 |
-
summarization_todo()
|
|
|
1 |
+
import itertools
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from functools import partial
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import requests
|
9 |
+
import streamlit as st
|
10 |
+
import webvtt
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
|
13 |
+
from generate_text_api import TextGenerator
|
14 |
+
from model_inferences.utils.chunking import Truncater
|
15 |
+
from model_inferences.utils.files import get_captions_from_vtt, get_transcript
|
16 |
+
|
17 |
+
USE_PARAGRAPHING_MODEL = True
|
18 |
+
|
19 |
+
def get_sublist_by_flattened_index(A, i):
    """Locate the sublist of *A* that covers flattened index *i*.

    Treats ``A`` as the concatenation of its sublists and returns the pair
    ``(sublist, position_of_that_sublist_in_A)`` containing flat index ``i``,
    or ``(None, None)`` when ``i`` is out of range.
    """
    current_index = 0
    # BUG FIX: the original returned `A.index(sublist)`, which finds the
    # FIRST equal sublist — the wrong position whenever A contains duplicate
    # sublists (common for paragraph/chapter sentence groups). Track the
    # index directly with enumerate instead.
    for sublist_idx, sublist in enumerate(A):
        sublist_length = len(sublist)
        if current_index <= i < current_index + sublist_length:
            return sublist, sublist_idx
        current_index += sublist_length
    return None, None
|
27 |
+
|
28 |
+
import requests
|
29 |
+
|
30 |
+
|
31 |
+
def get_talk_metadata(video_id):
    """Fetch title, presenter, and download URL for a TED talk.

    Queries TED's public GraphQL endpoint for the given ``video_id``.
    Returns the decoded JSON response dict on success, or ``None`` on a
    non-200 response (the status and body are printed for debugging).
    """
    url = "https://www.ted.com/graphql"

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "x-operation-name": "Transcript",  # Replace with the actual operation name
    }

    data = {
        "query": """
        query GetTalk($videoId: ID!) {
            video(id: $videoId) {
            title,
            presenterDisplayName,
            nativeDownloads {medium}
            }
        }
        """,
        "variables": {
            "videoId": video_id,  # Corrected key to "videoId"
        },
    }

    # Explicit timeout: requests has none by default, and a hung TED
    # endpoint would otherwise block the whole Streamlit script forever.
    response = requests.post(url, json=data, headers=headers, timeout=30)

    if response.status_code == 200:
        return response.json()
    print(f"Error: {response.status_code}, {response.text}")
    # Explicit None (was an implicit fall-through) so callers can check it.
    return None
|
62 |
+
|
63 |
+
class OfflineTextSegmenterClient:
    """Minimal HTTP client for the offline text-segmentation service."""

    def __init__(self, host_url):
        # Normalise the base URL and point it at the /segment route.
        self.host_url = host_url.rstrip("/") + "/segment"

    def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
        """POST *text* to the service; return segments, titles and sentences."""
        request_body = json.dumps({
            'text': text,
            'captions': captions,
            'generate_titles': generate_titles,
            "prefix_titles": True,
            "threshold": threshold,
        })
        response = requests.post(
            self.host_url,
            data=request_body,
            headers={'Content-Type': 'application/json'},
        ).json()
        #segments = response["annotated_segments"] if "annotated_segments" in response else response["segments"]
        return {
            'segments': response["segments"],
            'titles': response["titles"],
            'sentences': response["sentences"],
        }
|
83 |
+
|
84 |
+
class Toc:
    """Accumulates page headings and renders a linked table of contents."""

    def __init__(self):
        self._items = []
        self._placeholder = None

    def title(self, text):
        self._markdown(text, "h1")

    def header(self, text):
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        self._markdown(text, "h3", " " * 4)

    def placeholder(self, sidebar=False):
        # Reserve an empty slot (optionally in the sidebar) that
        # generate() fills in later.
        target = st.sidebar if sidebar else st
        self._placeholder = target.empty()

    def generate(self):
        # Render the accumulated TOC entries into the reserved slot, if any.
        if self._placeholder:
            self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)

    def _markdown(self, text, level, space=""):
        # Build an anchor id: spaces/apostrophes become dashes, all other
        # non-word punctuation is dropped.
        key = re.sub(r'[^\w-]', '', text.replace(" ", "-").replace("'", "-").lower())
        st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
        self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
|
110 |
+
|
111 |
+
# custom_css = "<style type='text/css'>" + Path('style.css').read_text() + "</style>"
|
112 |
+
# st.write(custom_css, unsafe_allow_html=True)
|
113 |
+
|
114 |
+
def concat_prompt(prompt_text, text, model_name):
    """Join instruction and document in the order the model family expects.

    Flan models take the instruction before the text; Galactica takes it
    after. Raises ValueError for an unrecognised model family — the
    original fell through to an UnboundLocalError on `input_`.
    """
    if 'flan' in model_name:
        return prompt_text + "\n\n" + text
    if 'galactica' in model_name:
        return text + "\n\n" + prompt_text
    raise ValueError(f"Unknown model family: {model_name!r}")
|
120 |
+
|
121 |
+
# --- Backend services ------------------------------------------------------
endpoint = "http://hiaisc.isl.iar.kit.edu/summarize"
ENDPOINTS = {"http://hiaisc.isl.iar.kit.edu/summarize": "meta-llama/Llama-2-13b-chat-hf",}

client = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/chapterize")
if USE_PARAGRAPHING_MODEL:
    paragrapher = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/paragraph")
summarizer = TextGenerator(endpoint)

# Tokenizer matching the backend model; used locally only to count prompt
# tokens so the input can be truncated to fit the context window.
tokenizer = AutoTokenizer.from_pretrained(ENDPOINTS[endpoint], use_fast=False)

# --- TLDR (single-sentence) prompt -----------------------------------------

SYSTEM_PROMPT = "You are an assistant who replies with a summary to every message."

TLDR_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_message} [/INST] Sure! Here is a summary of the research presentation in a single, short sentence:"""

TLDR_USER_PROMPT = "Summarize the following research presentation in a single, short sentence:\n\n{input}"

TLDR_PROMPT = TLDR_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
TLDR_PROMPT_LENGTH = tokenizer(TLDR_PROMPT, return_tensors="pt")["input_ids"].size(1)

# --- Bullet-point prompt ---------------------------------------------------

BP_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_message} [/INST] Sure! Here is a summary of the research presentation using three bullet points:\n\n\u2022"""

BP_USER_PROMPT = "Summarize the following research presentation using three bullet points:\n\n{input}"

# BUG FIX: was formatted with TLDR_USER_PROMPT, so the "bullet point" prompt
# actually asked the model for a single sentence. Use the bullet-point
# user prompt the template was written for.
BP_PROMPT = BP_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=BP_USER_PROMPT)
BP_PROMPT_LENGTH = tokenizer(BP_PROMPT, return_tensors="pt")["input_ids"].size(1)

# Token budget: input must leave room for the prompt scaffold, the generated
# summary, and one token of slack inside the model's context window.
CONTEXT_LENGTH = 3072
MAX_SUMMARY_LENGTH = 1024
TLDR_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - TLDR_PROMPT_LENGTH - 1
BP_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - BP_PROMPT_LENGTH - 1


text_generator = TextGenerator(endpoint)
temperature = 0.7
|
165 |
+
|
166 |
+
import re
|
167 |
+
|
168 |
+
|
169 |
+
def replace_newlines(text):
    """Collapse every run of consecutive newlines into one blank line."""
    return re.sub(r'\n+', r'\n\n', text)
|
172 |
+
|
173 |
+
def generate_summary(summarizer, generated_text_box, input_, prompt, max_input_length, prefix=""):
    """Stream a summary of *input_* into a Streamlit box and return it.

    The input is token-truncated to *max_input_length*, substituted into
    *prompt*, and generated chunk by chunk; *generated_text_box* is updated
    live as text streams in. *prefix* seeds the output (e.g. a leading
    bullet for the bullet-point prompt).
    """
    all_generated_text = prefix
    truncater = Truncater(tokenizer, max_length=max_input_length)
    input_ = truncater(input_)
    input_ = prompt.format(input=input_)
    for generated_text in summarizer.generate_text_stream(input_, max_new_tokens=MAX_SUMMARY_LENGTH, do_sample=True, temperature=temperature):
        all_generated_text += replace_newlines(generated_text)
        generated_text_box.info(all_generated_text)
    # Log once after streaming finishes; the original printed the whole
    # accumulated text on every chunk (quadratic console output).
    print(all_generated_text)
    return all_generated_text.strip()
|
183 |
+
|
184 |
+
st.header("Demo: Intelligent Recap")

# NOTE(review): indentation below is reconstructed from a diff view — the
# corpus-loading code is assumed to live inside this first-run guard;
# verify against the original file.
# Cache the three demo corpora on the `st` module object so they are loaded
# only once per process, not on every Streamlit rerun.
if not hasattr(st, 'global_state'):
    st.global_state = {'NIPS 2021 Talks': None, 'TED Talks': None}
    # NIPS 2021 Talks
    # Only the first 15 talks found on disk are offered in the demo.
    transcript_files = itertools.islice(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"), 15)
    # get titles from metadata.json
    transcripts_map = {}
    for transcript_file in transcript_files:
        base_path = transcript_file.parent
        metadata = base_path / "metadata.json"
        txt_file = base_path / "transcript_whisper_large-v2.txt"
        with open(metadata) as f:
            metadata = json.load(f)
        title = metadata["title"]
        transcript = get_transcript(txt_file)
        captions = get_captions_from_vtt(transcript_file)
        transcripts_map[title] = {"transcript": transcript, "captions": captions, "video": base_path / "video.mp4"}
    st.global_state['NIPS 2021 Talks'] = transcripts_map

    # TED Talks: transcripts come from a local JSON dump; title, presenter
    # and video URL are fetched live per talk from TED's GraphQL API.
    data = pd.read_json("demo_data/ted_talks.json")
    video_ids = data.talk_id.tolist()
    transcripts = data.text.apply(lambda x: " ".join(x)).tolist()
    transcripts_map = {}
    for video_id, transcript in zip(video_ids, transcripts):
        metadata = get_talk_metadata(video_id)
        title = metadata["data"]["video"]["title"]
        presenter = metadata["data"]["video"]["presenterDisplayName"]
        print(metadata["data"])
        # Skip talks that have no downloadable video.
        if metadata["data"]["video"]["nativeDownloads"] is None:
            continue
        video_url = metadata["data"]["video"]["nativeDownloads"]["medium"]
        transcripts_map[title] = {"transcript": transcript, "video": video_url, "presenter": presenter}
    st.global_state['TED Talks'] = transcripts_map

    def get_lecture_id(path):
        # Lecture folders are named "...-<id>"; numeric sort key for rglob order.
        return int(path.parts[-2].split('-')[1])

    transcript_files = Path("demo_data/lectures/").rglob("English.vtt")
    sorted_path_list = sorted(transcript_files, key=get_lecture_id)

    transcripts_map = {}
    for transcript_file in sorted_path_list:
        base_path = transcript_file.parent
        lecture_id = base_path.parts[-1]
        # Flatten the VTT captions into one space-separated transcript string.
        transcript = " ".join([c["text"].strip() for c in get_captions_from_vtt(transcript_file)]).replace("\n", " ")
        video_path = Path(base_path, "video.mp4")
        transcripts_map["Machine Translation: " + lecture_id] = {"transcript": transcript, "video": video_path}
    st.global_state['KIT Lectures'] = transcripts_map

type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))

transcripts_map = st.global_state[type_of_document]

selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))

st.video(str(transcripts_map[selected_talk]['video']), format="video/mp4", start_time=0)

input_text = st.text_area("Transcript", value=transcripts_map[selected_talk]['transcript'], height=300)

toc = Toc()

# Summaries are generated lazily after the page layout is rendered: each
# entry is a ready-to-call closure that streams into its own st.info box.
summarization_todos = []

with st.expander("Adjust Thresholds"):
    threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.4, step=0.05)
    paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)

if st.button("Process Transcript"):
    with st.sidebar:
        st.header("Table of Contents")
        toc.placeholder()

    st.header(selected_talk, divider='rainbow')
    # if 'presenter' in transcripts_map[selected_talk]:
    #     st.markdown(f"### *by **{transcripts_map[selected_talk]['presenter']}***")

    # Only the NIPS corpus carries caption timing; others pass None.
    captions = transcripts_map[selected_talk]['captions'] if 'captions' in transcripts_map[selected_talk] else None
    result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
    if USE_PARAGRAPHING_MODEL:
        presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
        paragraphs = presult['segments']
    segments, titles, sentences = result['segments'], result['titles'], result['sentences']

    if USE_PARAGRAPHING_MODEL:
        # Walk the flat sentence list, re-grouping it against BOTH the
        # chapter and paragraph segmentations: a chapter boundary flushes
        # the accumulated text; a paragraph-only boundary inserts a blank line.
        prev_chapter_idx = 0
        prev_paragraph_idx = 0
        segment = []
        for i, sentence in enumerate(sentences):
            chapter, chapter_idx = get_sublist_by_flattened_index(segments, i)
            paragraph, paragraph_idx = get_sublist_by_flattened_index(paragraphs, i)

            if (chapter_idx != prev_chapter_idx and paragraph_idx == prev_paragraph_idx) or (paragraph_idx != prev_paragraph_idx and chapter_idx != prev_chapter_idx):
                print("Chapter / Chapter & Paragraph")
                segment_text = " ".join(segment)
                toc.subheader(titles[prev_chapter_idx])
                # Long chapters get a bullet-point summary, medium ones a
                # one-sentence TLDR, short ones no summary at all.
                if len(segment_text) > 1200:
                    generated_text_box = st.info("")
                    summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
                elif len(segment_text) > 450:
                    generated_text_box = st.info("")
                    summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
                st.write(segment_text)
                segment = []
            elif paragraph_idx != prev_paragraph_idx and chapter_idx == prev_chapter_idx:
                print("Paragraph")
                segment.append("\n\n")

            segment.append(sentence)

            prev_chapter_idx = chapter_idx
            prev_paragraph_idx = paragraph_idx

        # Flush the final chapter — the loop above only flushes on boundaries.
        segment_text = " ".join(segment)
        toc.subheader(titles[prev_chapter_idx])
        if len(segment_text) > 1200:
            generated_text_box = st.info("")
            summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
        elif len(segment_text) > 450:
            generated_text_box = st.info("")
            summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
        st.write(segment_text)


    else:
        # No paragraphing model: render one block of text per chapter.
        segments = [" ".join([sentence for sentence in segment]) for segment in segments]
        for title, segment in zip(titles, segments):
            toc.subheader(title)
            if len(segment) > 1200:
                generated_text_box = st.info("")
                summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
            elif len(segment) > 450:
                generated_text_box = st.info("")
                summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
            st.write(segment)
    toc.generate()

    # Run the queued summarizations last so the page layout renders
    # immediately and the info boxes fill in as text streams back.
    for summarization_todo in summarization_todos:
        summarization_todo()
|