retkowski committed on
Commit
9383286
1 Parent(s): 06f911c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +322 -322
app.py CHANGED
@@ -1,322 +1,322 @@
1
- import itertools
2
- import json
3
- import re
4
- from functools import partial
5
- from pathlib import Path
6
-
7
- import pandas as pd
8
- import requests
9
- import streamlit as st
10
- import webvtt
11
- from transformers import AutoTokenizer
12
-
13
- from generate_text_api import TextGenerator
14
- from model_inferences.utils.chunking import Truncater
15
- from model_inferences.utils.files import get_captions_from_vtt, get_transcript
16
-
17
- USE_PARAGRAPHING_MODEL = True
18
-
19
- def get_sublist_by_flattened_index(A, i):
20
- current_index = 0
21
- for sublist in A:
22
- sublist_length = len(sublist)
23
- if current_index <= i < current_index + sublist_length:
24
- return sublist, A.index(sublist)
25
- current_index += sublist_length
26
- return None, None
27
-
28
- import requests
29
-
30
-
31
- def get_talk_metadata(video_id):
32
- url = "https://www.ted.com/graphql"
33
-
34
- headers = {
35
- "Content-Type": "application/json",
36
- "Accept": "application/json",
37
- "x-operation-name": "Transcript", # Replace with the actual operation name
38
- }
39
-
40
- data = {
41
- "query": """
42
- query GetTalk($videoId: ID!) {
43
- video(id: $videoId) {
44
- title,
45
- presenterDisplayName,
46
- nativeDownloads {medium}
47
- }
48
- }
49
- """,
50
- "variables": {
51
- "videoId": video_id, # Corrected key to "videoId"
52
- },
53
- }
54
-
55
- response = requests.post(url, json=data, headers=headers)
56
-
57
- if response.status_code == 200:
58
- result = response.json()
59
- return result
60
- else:
61
- print(f"Error: {response.status_code}, {response.text}")
62
-
63
- class OfflineTextSegmenterClient:
64
- def __init__(self, host_url):
65
- self.host_url = host_url.rstrip("/") + "/segment"
66
-
67
- def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
68
- payload = {
69
- 'text': text,
70
- 'captions': captions,
71
- 'generate_titles': generate_titles,
72
- "prefix_titles": True,
73
- "threshold": threshold,
74
- }
75
-
76
- headers = {
77
- 'Content-Type': 'application/json'
78
- }
79
-
80
- response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
81
- #segments = response["annotated_segments"] if "annotated_segments" in response else response["segments"]
82
- return {'segments':response["segments"], 'titles': response["titles"], 'sentences': response["sentences"]}
83
-
84
- class Toc:
85
-
86
- def __init__(self):
87
- self._items = []
88
- self._placeholder = None
89
-
90
- def title(self, text):
91
- self._markdown(text, "h1")
92
-
93
- def header(self, text):
94
- self._markdown(text, "h2", " " * 2)
95
-
96
- def subheader(self, text):
97
- self._markdown(text, "h3", " " * 4)
98
-
99
- def placeholder(self, sidebar=False):
100
- self._placeholder = st.sidebar.empty() if sidebar else st.empty()
101
-
102
- def generate(self):
103
- if self._placeholder:
104
- self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)
105
-
106
- def _markdown(self, text, level, space=""):
107
- key = re.sub(r'[^\w-]', '', text.replace(" ", "-").replace("'", "-").lower())
108
- st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
109
- self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
110
-
111
- # custom_css = "<style type='text/css'>" + Path('style.css').read_text() + "</style>"
112
- # st.write(custom_css, unsafe_allow_html=True)
113
-
114
- def concat_prompt(prompt_text, text, model_name):
115
- if 'flan' in model_name:
116
- input_ = prompt_text + "\n\n" + text
117
- elif 'galactica' in model_name:
118
- input_ = text + "\n\n" + prompt_text
119
- return input_
120
-
121
- endpoint = "http://hiaisc.isl.iar.kit.edu/summarize"
122
- ENDPOINTS = {"http://hiaisc.isl.iar.kit.edu/summarize": "meta-llama/Llama-2-13b-chat-hf",}
123
-
124
- client = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/chapter")
125
- if USE_PARAGRAPHING_MODEL:
126
- paragrapher = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/paragraph")
127
- summarizer = TextGenerator(endpoint)
128
-
129
- tokenizer = AutoTokenizer.from_pretrained(ENDPOINTS[endpoint], use_fast=False)
130
-
131
- # TLDR PROMPT
132
-
133
- SYSTEM_PROMPT = "You are an assistant who replies with a summary to every message."
134
-
135
- TLDR_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
136
- {system_prompt}
137
- <</SYS>>
138
-
139
- {user_message} [/INST] Sure! Here is a summary of the research presentation in a single, short sentence:"""
140
-
141
- TLDR_USER_PROMPT = "Summarize the following research presentation in a single, short sentence:\n\n{input}"
142
-
143
- TLDR_PROMPT = TLDR_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
144
- TLDR_PROMPT_LENGTH = tokenizer(TLDR_PROMPT, return_tensors="pt")["input_ids"].size(1)
145
-
146
- BP_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
147
- {system_prompt}
148
- <</SYS>>
149
-
150
- {user_message} [/INST] Sure! Here is a summary of the research presentation using three bullet points:\n\n\u2022"""
151
-
152
- BP_USER_PROMPT = "Summarize the following research presentation using three bullet points:\n\n{input}"
153
-
154
- BP_PROMPT = BP_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
155
- BP_PROMPT_LENGTH = tokenizer(BP_PROMPT, return_tensors="pt")["input_ids"].size(1)
156
-
157
- CONTEXT_LENGTH = 3072
158
- MAX_SUMMARY_LENGTH = 1024
159
- TLDR_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - TLDR_PROMPT_LENGTH - 1
160
- BP_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - BP_PROMPT_LENGTH - 1
161
-
162
-
163
- text_generator = TextGenerator(endpoint)
164
- temperature = 0.7
165
-
166
- import re
167
-
168
-
169
- def replace_newlines(text):
170
- updated_text = re.sub(r'\n+', r'\n\n', text)
171
- return updated_text
172
-
173
- def generate_summary(summarizer, generated_text_box, input_, prompt, max_input_length, prefix=""):
174
- all_generated_text = prefix
175
- truncater = Truncater(tokenizer, max_length=max_input_length)
176
- input_ = truncater(input_)
177
- input_ = prompt.format(input=input_)
178
- for generated_text in summarizer.generate_text_stream(input_, max_new_tokens=MAX_SUMMARY_LENGTH, do_sample=True, temperature=temperature):
179
- all_generated_text += replace_newlines(generated_text)
180
- generated_text_box.info(all_generated_text)
181
- print(all_generated_text)
182
- return all_generated_text.strip()
183
-
184
- st.header("Demo: Intelligent Recap")
185
-
186
- if not hasattr(st, 'global_state'):
187
- st.global_state = {'NIPS 2021 Talks': None, 'TED Talks': None}
188
- # NIPS 2021 Talks
189
- transcript_files = itertools.islice(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"), 15)
190
- # get titles from metadata.json
191
- transcripts_map = {}
192
- for transcript_file in transcript_files:
193
- base_path = transcript_file.parent
194
- metadata = base_path / "metadata.json"
195
- txt_file = base_path / "transcript_whisper_large-v2.txt"
196
- with open(metadata) as f:
197
- metadata = json.load(f)
198
- title = metadata["title"]
199
- transcript = get_transcript(txt_file)
200
- captions = get_captions_from_vtt(transcript_file)
201
- transcripts_map[title] = {"transcript": transcript, "captions": captions, "video": base_path / "video.mp4"}
202
- st.global_state['NIPS 2021 Talks'] = transcripts_map
203
-
204
- data = pd.read_json("demo_data/ted_talks.json")
205
- video_ids = data.talk_id.tolist()
206
- transcripts = data.text.apply(lambda x: " ".join(x)).tolist()
207
- transcripts_map = {}
208
- for video_id, transcript in zip(video_ids, transcripts):
209
- metadata = get_talk_metadata(video_id)
210
- title = metadata["data"]["video"]["title"]
211
- presenter = metadata["data"]["video"]["presenterDisplayName"]
212
- print(metadata["data"])
213
- if metadata["data"]["video"]["nativeDownloads"] is None:
214
- continue
215
- video_url = metadata["data"]["video"]["nativeDownloads"]["medium"]
216
- transcripts_map[title] = {"transcript": transcript, "video": video_url, "presenter": presenter}
217
- st.global_state['TED Talks'] = transcripts_map
218
-
219
- def get_lecture_id(path):
220
- return int(path.parts[-2].split('-')[1])
221
-
222
- transcript_files = Path("demo_data/lectures/").rglob("English.vtt")
223
- sorted_path_list = sorted(transcript_files, key=get_lecture_id)
224
-
225
- transcripts_map = {}
226
- for transcript_file in sorted_path_list:
227
- base_path = transcript_file.parent
228
- lecture_id = base_path.parts[-1]
229
- transcript = " ".join([c["text"].strip() for c in get_captions_from_vtt(transcript_file)]).replace("\n", " ")
230
- video_path = Path(base_path, "video.mp4")
231
- transcripts_map["Machine Translation: " + lecture_id] = {"transcript": transcript, "video": video_path}
232
- st.global_state['KIT Lectures'] = transcripts_map
233
-
234
- type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))
235
-
236
- transcripts_map = st.global_state[type_of_document]
237
-
238
- selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))
239
-
240
- st.video(str(transcripts_map[selected_talk]['video']), format="video/mp4", start_time=0)
241
-
242
- input_text = st.text_area("Transcript", value=transcripts_map[selected_talk]['transcript'], height=300)
243
-
244
- toc = Toc()
245
-
246
- summarization_todos = []
247
-
248
- with st.expander("Adjust Thresholds"):
249
- threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.4, step=0.05)
250
- paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)
251
-
252
- if st.button("Process Transcript"):
253
- with st.sidebar:
254
- st.header("Table of Contents")
255
- toc.placeholder()
256
-
257
- st.header(selected_talk, divider='rainbow')
258
- # if 'presenter' in transcripts_map[selected_talk]:
259
- # st.markdown(f"### *by **{transcripts_map[selected_talk]['presenter']}***")
260
-
261
- captions = transcripts_map[selected_talk]['captions'] if 'captions' in transcripts_map[selected_talk] else None
262
- result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
263
- if USE_PARAGRAPHING_MODEL:
264
- presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
265
- paragraphs = presult['segments']
266
- segments, titles, sentences = result['segments'], result['titles'], result['sentences']
267
-
268
- if USE_PARAGRAPHING_MODEL:
269
- prev_chapter_idx = 0
270
- prev_paragraph_idx = 0
271
- segment = []
272
- for i, sentence in enumerate(sentences):
273
- chapter, chapter_idx = get_sublist_by_flattened_index(segments, i)
274
- paragraph, paragraph_idx = get_sublist_by_flattened_index(paragraphs, i)
275
-
276
- if (chapter_idx != prev_chapter_idx and paragraph_idx == prev_paragraph_idx) or (paragraph_idx != prev_paragraph_idx and chapter_idx != prev_chapter_idx):
277
- print("Chapter / Chapter & Paragraph")
278
- segment_text = " ".join(segment)
279
- toc.subheader(titles[prev_chapter_idx])
280
- if len(segment_text) > 1200:
281
- generated_text_box = st.info("")
282
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
283
- elif len(segment_text) > 450:
284
- generated_text_box = st.info("")
285
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
286
- st.write(segment_text)
287
- segment = []
288
- elif paragraph_idx != prev_paragraph_idx and chapter_idx == prev_chapter_idx:
289
- print("Paragraph")
290
- segment.append("\n\n")
291
-
292
- segment.append(sentence)
293
-
294
- prev_chapter_idx = chapter_idx
295
- prev_paragraph_idx = paragraph_idx
296
-
297
- segment_text = " ".join(segment)
298
- toc.subheader(titles[prev_chapter_idx])
299
- if len(segment_text) > 1200:
300
- generated_text_box = st.info("")
301
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
302
- elif len(segment_text) > 450:
303
- generated_text_box = st.info("")
304
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
305
- st.write(segment_text)
306
-
307
-
308
- else:
309
- segments = [" ".join([sentence for sentence in segment]) for segment in segments]
310
- for title, segment in zip(titles, segments):
311
- toc.subheader(title)
312
- if len(segment) > 1200:
313
- generated_text_box = st.info("")
314
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
315
- elif len(segment) > 450:
316
- generated_text_box = st.info("")
317
- summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
318
- st.write(segment)
319
- toc.generate()
320
-
321
- for summarization_todo in summarization_todos:
322
- summarization_todo()
 
1
+ import itertools
2
+ import json
3
+ import re
4
+ from functools import partial
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+ import requests
9
+ import streamlit as st
10
+ import webvtt
11
+ from transformers import AutoTokenizer
12
+
13
+ from generate_text_api import TextGenerator
14
+ from model_inferences.utils.chunking import Truncater
15
+ from model_inferences.utils.files import get_captions_from_vtt, get_transcript
16
+
17
+ USE_PARAGRAPHING_MODEL = True
18
+
19
+ def get_sublist_by_flattened_index(A, i):
20
+ current_index = 0
21
+ for sublist in A:
22
+ sublist_length = len(sublist)
23
+ if current_index <= i < current_index + sublist_length:
24
+ return sublist, A.index(sublist)
25
+ current_index += sublist_length
26
+ return None, None
27
+
28
+ import requests
29
+
30
+
31
+ def get_talk_metadata(video_id):
32
+ url = "https://www.ted.com/graphql"
33
+
34
+ headers = {
35
+ "Content-Type": "application/json",
36
+ "Accept": "application/json",
37
+ "x-operation-name": "Transcript", # Replace with the actual operation name
38
+ }
39
+
40
+ data = {
41
+ "query": """
42
+ query GetTalk($videoId: ID!) {
43
+ video(id: $videoId) {
44
+ title,
45
+ presenterDisplayName,
46
+ nativeDownloads {medium}
47
+ }
48
+ }
49
+ """,
50
+ "variables": {
51
+ "videoId": video_id, # Corrected key to "videoId"
52
+ },
53
+ }
54
+
55
+ response = requests.post(url, json=data, headers=headers)
56
+
57
+ if response.status_code == 200:
58
+ result = response.json()
59
+ return result
60
+ else:
61
+ print(f"Error: {response.status_code}, {response.text}")
62
+
63
+ class OfflineTextSegmenterClient:
64
+ def __init__(self, host_url):
65
+ self.host_url = host_url.rstrip("/") + "/segment"
66
+
67
+ def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
68
+ payload = {
69
+ 'text': text,
70
+ 'captions': captions,
71
+ 'generate_titles': generate_titles,
72
+ "prefix_titles": True,
73
+ "threshold": threshold,
74
+ }
75
+
76
+ headers = {
77
+ 'Content-Type': 'application/json'
78
+ }
79
+
80
+ response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
81
+ #segments = response["annotated_segments"] if "annotated_segments" in response else response["segments"]
82
+ return {'segments':response["segments"], 'titles': response["titles"], 'sentences': response["sentences"]}
83
+
84
+ class Toc:
85
+
86
+ def __init__(self):
87
+ self._items = []
88
+ self._placeholder = None
89
+
90
+ def title(self, text):
91
+ self._markdown(text, "h1")
92
+
93
+ def header(self, text):
94
+ self._markdown(text, "h2", " " * 2)
95
+
96
+ def subheader(self, text):
97
+ self._markdown(text, "h3", " " * 4)
98
+
99
+ def placeholder(self, sidebar=False):
100
+ self._placeholder = st.sidebar.empty() if sidebar else st.empty()
101
+
102
+ def generate(self):
103
+ if self._placeholder:
104
+ self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)
105
+
106
+ def _markdown(self, text, level, space=""):
107
+ key = re.sub(r'[^\w-]', '', text.replace(" ", "-").replace("'", "-").lower())
108
+ st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
109
+ self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
110
+
111
+ # custom_css = "<style type='text/css'>" + Path('style.css').read_text() + "</style>"
112
+ # st.write(custom_css, unsafe_allow_html=True)
113
+
114
+ def concat_prompt(prompt_text, text, model_name):
115
+ if 'flan' in model_name:
116
+ input_ = prompt_text + "\n\n" + text
117
+ elif 'galactica' in model_name:
118
+ input_ = text + "\n\n" + prompt_text
119
+ return input_
120
+
121
+ endpoint = "http://hiaisc.isl.iar.kit.edu/summarize"
122
+ ENDPOINTS = {"http://hiaisc.isl.iar.kit.edu/summarize": "meta-llama/Llama-2-13b-chat-hf",}
123
+
124
+ client = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/chapterize")
125
+ if USE_PARAGRAPHING_MODEL:
126
+ paragrapher = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/paragraph")
127
+ summarizer = TextGenerator(endpoint)
128
+
129
+ tokenizer = AutoTokenizer.from_pretrained(ENDPOINTS[endpoint], use_fast=False)
130
+
131
+ # TLDR PROMPT
132
+
133
+ SYSTEM_PROMPT = "You are an assistant who replies with a summary to every message."
134
+
135
+ TLDR_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
136
+ {system_prompt}
137
+ <</SYS>>
138
+
139
+ {user_message} [/INST] Sure! Here is a summary of the research presentation in a single, short sentence:"""
140
+
141
+ TLDR_USER_PROMPT = "Summarize the following research presentation in a single, short sentence:\n\n{input}"
142
+
143
+ TLDR_PROMPT = TLDR_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
144
+ TLDR_PROMPT_LENGTH = tokenizer(TLDR_PROMPT, return_tensors="pt")["input_ids"].size(1)
145
+
146
+ BP_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
147
+ {system_prompt}
148
+ <</SYS>>
149
+
150
+ {user_message} [/INST] Sure! Here is a summary of the research presentation using three bullet points:\n\n\u2022"""
151
+
152
+ BP_USER_PROMPT = "Summarize the following research presentation using three bullet points:\n\n{input}"
153
+
154
+ BP_PROMPT = BP_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
155
+ BP_PROMPT_LENGTH = tokenizer(BP_PROMPT, return_tensors="pt")["input_ids"].size(1)
156
+
157
+ CONTEXT_LENGTH = 3072
158
+ MAX_SUMMARY_LENGTH = 1024
159
+ TLDR_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - TLDR_PROMPT_LENGTH - 1
160
+ BP_MAX_INPUT_LENGTH = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - BP_PROMPT_LENGTH - 1
161
+
162
+
163
+ text_generator = TextGenerator(endpoint)
164
+ temperature = 0.7
165
+
166
+ import re
167
+
168
+
169
+ def replace_newlines(text):
170
+ updated_text = re.sub(r'\n+', r'\n\n', text)
171
+ return updated_text
172
+
173
+ def generate_summary(summarizer, generated_text_box, input_, prompt, max_input_length, prefix=""):
174
+ all_generated_text = prefix
175
+ truncater = Truncater(tokenizer, max_length=max_input_length)
176
+ input_ = truncater(input_)
177
+ input_ = prompt.format(input=input_)
178
+ for generated_text in summarizer.generate_text_stream(input_, max_new_tokens=MAX_SUMMARY_LENGTH, do_sample=True, temperature=temperature):
179
+ all_generated_text += replace_newlines(generated_text)
180
+ generated_text_box.info(all_generated_text)
181
+ print(all_generated_text)
182
+ return all_generated_text.strip()
183
+
184
+ st.header("Demo: Intelligent Recap")
185
+
186
+ if not hasattr(st, 'global_state'):
187
+ st.global_state = {'NIPS 2021 Talks': None, 'TED Talks': None}
188
+ # NIPS 2021 Talks
189
+ transcript_files = itertools.islice(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"), 15)
190
+ # get titles from metadata.json
191
+ transcripts_map = {}
192
+ for transcript_file in transcript_files:
193
+ base_path = transcript_file.parent
194
+ metadata = base_path / "metadata.json"
195
+ txt_file = base_path / "transcript_whisper_large-v2.txt"
196
+ with open(metadata) as f:
197
+ metadata = json.load(f)
198
+ title = metadata["title"]
199
+ transcript = get_transcript(txt_file)
200
+ captions = get_captions_from_vtt(transcript_file)
201
+ transcripts_map[title] = {"transcript": transcript, "captions": captions, "video": base_path / "video.mp4"}
202
+ st.global_state['NIPS 2021 Talks'] = transcripts_map
203
+
204
+ data = pd.read_json("demo_data/ted_talks.json")
205
+ video_ids = data.talk_id.tolist()
206
+ transcripts = data.text.apply(lambda x: " ".join(x)).tolist()
207
+ transcripts_map = {}
208
+ for video_id, transcript in zip(video_ids, transcripts):
209
+ metadata = get_talk_metadata(video_id)
210
+ title = metadata["data"]["video"]["title"]
211
+ presenter = metadata["data"]["video"]["presenterDisplayName"]
212
+ print(metadata["data"])
213
+ if metadata["data"]["video"]["nativeDownloads"] is None:
214
+ continue
215
+ video_url = metadata["data"]["video"]["nativeDownloads"]["medium"]
216
+ transcripts_map[title] = {"transcript": transcript, "video": video_url, "presenter": presenter}
217
+ st.global_state['TED Talks'] = transcripts_map
218
+
219
+ def get_lecture_id(path):
220
+ return int(path.parts[-2].split('-')[1])
221
+
222
+ transcript_files = Path("demo_data/lectures/").rglob("English.vtt")
223
+ sorted_path_list = sorted(transcript_files, key=get_lecture_id)
224
+
225
+ transcripts_map = {}
226
+ for transcript_file in sorted_path_list:
227
+ base_path = transcript_file.parent
228
+ lecture_id = base_path.parts[-1]
229
+ transcript = " ".join([c["text"].strip() for c in get_captions_from_vtt(transcript_file)]).replace("\n", " ")
230
+ video_path = Path(base_path, "video.mp4")
231
+ transcripts_map["Machine Translation: " + lecture_id] = {"transcript": transcript, "video": video_path}
232
+ st.global_state['KIT Lectures'] = transcripts_map
233
+
234
+ type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))
235
+
236
+ transcripts_map = st.global_state[type_of_document]
237
+
238
+ selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))
239
+
240
+ st.video(str(transcripts_map[selected_talk]['video']), format="video/mp4", start_time=0)
241
+
242
+ input_text = st.text_area("Transcript", value=transcripts_map[selected_talk]['transcript'], height=300)
243
+
244
+ toc = Toc()
245
+
246
+ summarization_todos = []
247
+
248
+ with st.expander("Adjust Thresholds"):
249
+ threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.4, step=0.05)
250
+ paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)
251
+
252
+ if st.button("Process Transcript"):
253
+ with st.sidebar:
254
+ st.header("Table of Contents")
255
+ toc.placeholder()
256
+
257
+ st.header(selected_talk, divider='rainbow')
258
+ # if 'presenter' in transcripts_map[selected_talk]:
259
+ # st.markdown(f"### *by **{transcripts_map[selected_talk]['presenter']}***")
260
+
261
+ captions = transcripts_map[selected_talk]['captions'] if 'captions' in transcripts_map[selected_talk] else None
262
+ result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
263
+ if USE_PARAGRAPHING_MODEL:
264
+ presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
265
+ paragraphs = presult['segments']
266
+ segments, titles, sentences = result['segments'], result['titles'], result['sentences']
267
+
268
+ if USE_PARAGRAPHING_MODEL:
269
+ prev_chapter_idx = 0
270
+ prev_paragraph_idx = 0
271
+ segment = []
272
+ for i, sentence in enumerate(sentences):
273
+ chapter, chapter_idx = get_sublist_by_flattened_index(segments, i)
274
+ paragraph, paragraph_idx = get_sublist_by_flattened_index(paragraphs, i)
275
+
276
+ if (chapter_idx != prev_chapter_idx and paragraph_idx == prev_paragraph_idx) or (paragraph_idx != prev_paragraph_idx and chapter_idx != prev_chapter_idx):
277
+ print("Chapter / Chapter & Paragraph")
278
+ segment_text = " ".join(segment)
279
+ toc.subheader(titles[prev_chapter_idx])
280
+ if len(segment_text) > 1200:
281
+ generated_text_box = st.info("")
282
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
283
+ elif len(segment_text) > 450:
284
+ generated_text_box = st.info("")
285
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
286
+ st.write(segment_text)
287
+ segment = []
288
+ elif paragraph_idx != prev_paragraph_idx and chapter_idx == prev_chapter_idx:
289
+ print("Paragraph")
290
+ segment.append("\n\n")
291
+
292
+ segment.append(sentence)
293
+
294
+ prev_chapter_idx = chapter_idx
295
+ prev_paragraph_idx = paragraph_idx
296
+
297
+ segment_text = " ".join(segment)
298
+ toc.subheader(titles[prev_chapter_idx])
299
+ if len(segment_text) > 1200:
300
+ generated_text_box = st.info("")
301
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
302
+ elif len(segment_text) > 450:
303
+ generated_text_box = st.info("")
304
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
305
+ st.write(segment_text)
306
+
307
+
308
+ else:
309
+ segments = [" ".join([sentence for sentence in segment]) for segment in segments]
310
+ for title, segment in zip(titles, segments):
311
+ toc.subheader(title)
312
+ if len(segment) > 1200:
313
+ generated_text_box = st.info("")
314
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"))
315
+ elif len(segment) > 450:
316
+ generated_text_box = st.info("")
317
+ summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH))
318
+ st.write(segment)
319
+ toc.generate()
320
+
321
+ for summarization_todo in summarization_todos:
322
+ summarization_todo()