Spaces:
Running
Running
import itertools | |
import json | |
import re | |
from functools import partial | |
from pathlib import Path | |
import pandas as pd | |
import requests | |
import streamlit as st | |
import webvtt | |
from transformers import AutoTokenizer | |
from generate_text_api import TextGenerator | |
from model_inferences.utils.chunking import Truncater | |
from model_inferences.utils.files import get_captions_from_vtt, get_transcript | |
# Feature flag: when True, the paragraph-segmentation service is also queried
# and its output is used to insert paragraph breaks inside each chapter.
USE_PARAGRAPHING_MODEL: bool = True
def get_sublist_by_flattened_index(A, i):
    """Return the sublist of ``A`` covering flattened position ``i``, and its index.

    Positions count elements across all sublists in order, as if ``A`` were
    flattened. Empty sublists occupy no positions and are skipped.

    Args:
        A: A sequence of sized sequences (e.g. a list of sentence lists).
        i: Zero-based position in the flattened sequence.

    Returns:
        ``(sublist, index)`` for the sublist that covers position ``i``, or
        ``(None, None)`` when ``i`` is negative or past the end.
    """
    offset = 0
    for index, sublist in enumerate(A):
        end = offset + len(sublist)
        if offset <= i < end:
            # Use the enumerate index rather than A.index(sublist): index()
            # returns the first *equal* sublist, which is wrong when A holds
            # duplicate sublists, and costs an extra O(len(A)) scan.
            return sublist, index
        offset = end
    return None, None
import requests | |
def get_talk_metadata(video_id, timeout=30):
    """Fetch title, presenter name and download links for a TED talk.

    Sends a GraphQL query to ``https://www.ted.com/graphql``.

    Args:
        video_id: TED talk id, sent as the ``videoId`` GraphQL variable.
        timeout: Seconds to wait for the HTTP request. Without a timeout a
            stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON response body, or ``None`` if the request failed,
        returned a non-200 status, or the body was not valid JSON.
    """
    url = "https://www.ted.com/graphql"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "x-operation-name": "Transcript",  # Replace with the actual operation name
    }
    data = {
        "query": """
        query GetTalk($videoId: ID!) {
            video(id: $videoId) {
                title,
                presenterDisplayName,
                nativeDownloads {medium}
            }
        }
        """,
        "variables": {"videoId": video_id},
    }
    try:
        response = requests.post(url, json=data, headers=headers, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return None
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return None
    try:
        return response.json()
    except ValueError as err:  # body was not JSON
        print(f"Error: {err}")
        return None
class OfflineTextSegmenterClient:
    """Minimal HTTP client for a remote text-segmentation service.

    Requests go to ``<host_url>/segment``. The JSON reply must contain the
    keys ``segments``, ``titles`` and ``sentences``.
    """

    def __init__(self, host_url, timeout=None):
        """Create a client.

        Args:
            host_url: Base URL of the service; a trailing slash is tolerated.
            timeout: Optional per-request timeout in seconds, forwarded to
                ``requests.post`` (``None`` waits indefinitely).
        """
        self.host_url = host_url.rstrip("/") + "/segment"
        self.timeout = timeout

    def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
        """Ask the service to segment ``text``.

        Args:
            text: Text to segment.
            captions: Optional captions forwarded to the service as-is.
            generate_titles: Whether the service should also produce titles.
            threshold: Segmentation threshold forwarded to the service.

        Returns:
            A dict with the ``segments``, ``titles`` and ``sentences`` values
            from the service response.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        payload = {
            'text': text,
            'captions': captions,
            'generate_titles': generate_titles,
            "prefix_titles": True,
            "threshold": threshold,
        }
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.post(
            self.host_url, data=json.dumps(payload), headers=headers, timeout=self.timeout
        )
        # Fail with a clear HTTPError here instead of an opaque
        # JSONDecodeError/KeyError when an error page is parsed below.
        response.raise_for_status()
        body = response.json()
        return {'segments': body["segments"], 'titles': body["titles"], 'sentences': body["sentences"]}
class Toc:
    """Build a linked table of contents while headings are written to the page.

    The heading methods (``title``/``header``/``subheader``) render the heading
    immediately with an HTML ``id`` and remember a link to it. ``generate``
    writes all links, as an indented markdown list, into the element reserved
    by ``placeholder``.
    """

    def __init__(self):
        self._items = []  # markdown list lines, one per heading
        self._placeholder = None  # element receiving the TOC; set by placeholder()
        self._used_keys = set()  # anchor ids already emitted, kept unique

    def title(self, text):
        """Write an ``h1`` heading and add a top-level TOC entry."""
        self._markdown(text, "h1")

    def header(self, text):
        """Write an ``h2`` heading and add a TOC entry indented by two spaces."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        """Write an ``h3`` heading and add a TOC entry indented by four spaces."""
        self._markdown(text, "h3", " " * 4)

    def placeholder(self, sidebar=False):
        """Reserve the element (in the sidebar if ``sidebar``) that ``generate`` fills."""
        self._placeholder = st.sidebar.empty() if sidebar else st.empty()

    def generate(self):
        """Render the collected entries into the placeholder; no-op if none was reserved."""
        if self._placeholder:
            self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)

    def _anchor_key(self, text):
        """Return a URL-safe anchor id derived from ``text``, unique within this TOC."""
        key = re.sub(r'[^\w-]', '', text.replace(" ", "-").replace("'", "-").lower())
        if not key:
            # Titles made only of punctuation would otherwise yield id='' and href='#'.
            key = "section"
        # Repeated headings (e.g. two chapters with the same generated title)
        # would share one id, making every link jump to the first of them.
        candidate = key
        suffix = 1
        while candidate in self._used_keys:
            candidate = f"{key}-{suffix}"
            suffix += 1
        self._used_keys.add(candidate)
        return candidate

    def _markdown(self, text, level, space=""):
        """Write ``text`` as an HTML heading ``level`` and record a TOC link indented by ``space``."""
        import html  # local import: only this method needs it

        key = self._anchor_key(text)
        # Headings include model-generated titles and are rendered with
        # unsafe_allow_html, so escape them to keep "<", ">" or "&" in a title
        # from breaking the page markup.
        label = html.escape(text)
        st.markdown(f"<{level} id='{key}'>{label}</{level}>", unsafe_allow_html=True)
        self._items.append(f"{space}* <a href='#{key}'>{label}</a>")
# Optional custom styling, currently disabled:
# custom_css = "<style type='text/css'>" + Path('style.css').read_text() + "</style>"
# st.write(custom_css, unsafe_allow_html=True)
def concat_prompt(prompt_text, text, model_name):
    """Combine an instruction and its input text in the order a model family expects.

    Args:
        prompt_text: Instruction for the model.
        text: Input text the instruction applies to.
        model_name: Model identifier, matched by substring ("flan" is checked
            before "galactica").

    Returns:
        ``prompt_text`` followed by ``text`` for FLAN models, or ``text``
        followed by ``prompt_text`` for Galactica models, separated by a
        blank line.

    Raises:
        ValueError: If ``model_name`` matches neither family.
    """
    if 'flan' in model_name:
        return prompt_text + "\n\n" + text
    if 'galactica' in model_name:
        return text + "\n\n" + prompt_text
    raise ValueError(f"Unsupported model for prompt concatenation: {model_name!r}")
# Summarisation endpoint and the Hugging Face model name it is mapped to; the
# tokenizer loaded for that model is used for prompt-length accounting and
# input truncation.
endpoint = "http://hiaisc.isl.iar.kit.edu/summarize"
ENDPOINTS = {endpoint: "meta-llama/Llama-2-13b-chat-hf"}

# Segmentation services: `client` yields chapters; `paragrapher` (only when the
# flag is set) yields paragraphs.
client = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/chapterize")
if USE_PARAGRAPHING_MODEL:
    paragrapher = OfflineTextSegmenterClient("http://hiaisc.isl.iar.kit.edu/paragraph")

summarizer = TextGenerator(endpoint)
tokenizer = AutoTokenizer.from_pretrained(
    ENDPOINTS[endpoint],
    use_fast=False,
)
# Llama-2 chat prompts. Each template pre-fills the assistant turn so the model
# continues directly with the summary. The ``{input}`` placeholder survives the
# ``.format`` below (it is inside the substituted user message) and is filled
# with the truncated transcript by generate_summary.
SYSTEM_PROMPT = "You are an assistant who replies with a summary to every message."

# TLDR PROMPT: one-sentence summary.
TLDR_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST] Sure! Here is a summary of the research presentation in a single, short sentence:"""
TLDR_USER_PROMPT = "Summarize the following research presentation in a single, short sentence:\n\n{input}"
TLDR_PROMPT = TLDR_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=TLDR_USER_PROMPT)
# Token count of the prompt scaffold, used to budget how much input fits.
TLDR_PROMPT_LENGTH = tokenizer(TLDR_PROMPT, return_tensors="pt")["input_ids"].size(1)

# Bullet-point prompt: the assistant turn ends with a bullet, so generation
# starts inside the first of three bullet points.
BP_PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST] Sure! Here is a summary of the research presentation using three bullet points:\n\n\u2022"""
BP_USER_PROMPT = "Summarize the following research presentation using three bullet points:\n\n{input}"
# Must use BP_USER_PROMPT (not TLDR_USER_PROMPT): otherwise the user turn asks
# for a single sentence while the assistant prefix promises three bullets.
BP_PROMPT = BP_PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_message=BP_USER_PROMPT)
BP_PROMPT_LENGTH = tokenizer(BP_PROMPT, return_tensors="pt")["input_ids"].size(1)
# Token budgeting for each summary request: CONTEXT_LENGTH is split between the
# prompt scaffold, the truncated input and up to MAX_SUMMARY_LENGTH generated
# tokens, with one further token held back.
CONTEXT_LENGTH = 3072
MAX_SUMMARY_LENGTH = 1024
_INPUT_BUDGET = CONTEXT_LENGTH - MAX_SUMMARY_LENGTH - 1
TLDR_MAX_INPUT_LENGTH = _INPUT_BUDGET - TLDR_PROMPT_LENGTH
BP_MAX_INPUT_LENGTH = _INPUT_BUDGET - BP_PROMPT_LENGTH

# NOTE(review): `text_generator` looks unused here; generate_summary is given
# `summarizer` instead.
text_generator = TextGenerator(endpoint)
# Sampling temperature read by generate_summary.
temperature = 0.7
import re | |
def replace_newlines(text):
    """Turn every run of one or more newlines in ``text`` into exactly two newlines.

    A single line break therefore becomes a blank line (a markdown paragraph
    break), and longer runs collapse to one blank line.
    """
    paragraph_break = "\n\n"
    return re.sub(r"\n+", paragraph_break, text)
def generate_summary(summarizer, generated_text_box, input_, prompt, max_input_length, prefix=""):
    """Stream a summary of ``input_`` into a Streamlit element and return it.

    Args:
        summarizer: Client whose ``generate_text_stream(prompt, **kwargs)``
            yields text chunks.
        generated_text_box: Element updated via ``.info(...)`` with the text
            generated so far.
        input_: Text to summarise; first truncated by ``Truncater``.
        prompt: Prompt template with an ``{input}`` placeholder.
        max_input_length: ``max_length`` passed to ``Truncater`` (callers in
            this script pass token budgets).
        prefix: Text shown before the generated output, e.g. a bullet when
            the prompt's pre-filled answer already opened one.

    Returns:
        ``prefix`` plus the generated text, with newline runs normalised by
        ``replace_newlines``, stripped of surrounding whitespace.
    """
    truncater = Truncater(tokenizer, max_length=max_input_length)
    full_prompt = prompt.format(input=truncater(input_))

    generated = ""
    display_text = prefix
    for chunk in summarizer.generate_text_stream(
        full_prompt, max_new_tokens=MAX_SUMMARY_LENGTH, do_sample=True, temperature=temperature
    ):
        generated += chunk
        # Normalise the accumulated text rather than each chunk: a newline run
        # split across two chunks would otherwise become two blank lines.
        display_text = prefix + replace_newlines(generated)
        generated_text_box.info(display_text)
    # Log the finished summary once instead of re-printing it on every chunk.
    print(display_text)
    return display_text.strip()
st.header("Demo: Intelligent Recap")


def get_lecture_id(path):
    """Return the numeric lecture id encoded in a caption file's parent directory name.

    The id is the second ``-``-separated field of that directory name; it is
    used to sort lectures numerically.
    """
    return int(path.parts[-2].split('-')[1])


def _load_nips_talks(limit=15):
    """Load up to ``limit`` NeurIPS 2021 talks from ``demo_data/nips-2021/``.

    Each talk directory must contain ``transcript_whisper_large-v2.vtt`` and
    ``.txt``, a ``metadata.json`` with a ``title``, and ``video.mp4``.
    Returns a dict keyed by title.
    """
    # Sort before slicing so the same talks are picked on every machine;
    # rglob's order depends on the filesystem.
    vtt_files = sorted(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"))[:limit]
    talks = {}
    for vtt_file in vtt_files:
        base_path = vtt_file.parent
        with open(base_path / "metadata.json", encoding="utf-8") as f:
            title = json.load(f)["title"]
        talks[title] = {
            "transcript": get_transcript(base_path / "transcript_whisper_large-v2.txt"),
            "captions": get_captions_from_vtt(vtt_file),
            "video": base_path / "video.mp4",
        }
    return talks


def _load_ted_talks():
    """Load TED talks from ``demo_data/ted_talks.json``, enriched with TED API metadata.

    Talks whose metadata cannot be fetched, or that have no downloadable
    video, are skipped. Returns a dict keyed by title.
    """
    data = pd.read_json("demo_data/ted_talks.json")
    talks = {}
    for video_id, sentences in zip(data.talk_id.tolist(), data.text.tolist()):
        metadata = get_talk_metadata(video_id)
        # get_talk_metadata returns None on HTTP errors; also tolerate a
        # response without data/video instead of crashing app start-up.
        video = ((metadata or {}).get("data") or {}).get("video")
        if not video or video.get("nativeDownloads") is None:
            continue
        talks[video["title"]] = {
            "transcript": " ".join(sentences),
            "video": video["nativeDownloads"]["medium"],
            "presenter": video["presenterDisplayName"],
        }
    return talks


def _load_kit_lectures():
    """Load KIT "Machine Translation" lectures from ``demo_data/lectures/``, ordered by lecture id."""
    vtt_files = sorted(Path("demo_data/lectures/").rglob("English.vtt"), key=get_lecture_id)
    lectures = {}
    for vtt_file in vtt_files:
        base_path = vtt_file.parent
        transcript = " ".join(c["text"].strip() for c in get_captions_from_vtt(vtt_file)).replace("\n", " ")
        lectures["Machine Translation: " + base_path.parts[-1]] = {
            "transcript": transcript,
            "video": Path(base_path, "video.mp4"),
        }
    return lectures


# Load every demo corpus once per server process and cache it on the `st`
# module object (imported modules persist across Streamlit reruns). Otherwise
# each widget interaction would re-read all files and re-query TED's API.
if not hasattr(st, 'global_state'):
    # Assign only after every loader succeeded, so a failed load is retried on
    # the next run instead of leaving None entries behind.
    st.global_state = {
        'NIPS 2021 Talks': _load_nips_talks(),
        'TED Talks': _load_ted_talks(),
        'KIT Lectures': _load_kit_lectures(),
    }
type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))
transcripts_map = st.global_state[type_of_document]
if not transcripts_map:
    # e.g. a demo_data folder has no matching files: selectbox would return
    # None and the lookups below would raise KeyError.
    st.warning("No documents are available for this category.")
    st.stop()

selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))
selected_entry = transcripts_map[selected_talk]
st.video(str(selected_entry['video']), format="video/mp4", start_time=0)
input_text = st.text_area("Transcript", value=selected_entry['transcript'], height=300)

# Filled while rendering: headings for the sidebar table of contents, and
# summary jobs that run after the page body has been drawn.
toc = Toc()
summarization_todos = []

with st.expander("Adjust Thresholds"):
    threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.4, step=0.05)
    paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)
# Section length (in characters) above which a bullet-point summary, or else a
# one-sentence summary, is generated for that section.
BULLET_SUMMARY_MIN_CHARS = 1200
TLDR_SUMMARY_MIN_CHARS = 450


def _flat_owner_indices(groups):
    """Return, for each position of ``groups`` flattened, the index of the group holding it.

    >>> _flat_owner_indices([["a", "b"], [], ["c"]])
    [0, 0, 2]
    """
    return [index for index, group in enumerate(groups) for _ in group]


def _group_into_chapters(sentences, chapters, paragraphs):
    """Join ``sentences`` into one text per chapter, with blank lines between paragraphs.

    Args:
        sentences: Transcript sentences, in order.
        chapters: Chapter segmentation; sentence ``i`` belongs to the chapter
            whose sublist covers flattened position ``i``.
        paragraphs: Paragraph segmentation, indexed the same way.

    Returns:
        A list of ``(chapter_index, text)`` pairs in document order.
    """
    # Precompute position -> owner once instead of scanning the segment lists
    # for every sentence (quadratic in the transcript length).
    chapter_of = _flat_owner_indices(chapters)
    paragraph_of = _flat_owner_indices(paragraphs)
    grouped = []
    current = []
    prev_chapter = prev_paragraph = 0
    for i, sentence in enumerate(sentences):
        # If a segmentation covers fewer sentences than `sentences`, keep the
        # sentence in the current chapter/paragraph instead of failing later.
        chapter = chapter_of[i] if i < len(chapter_of) else prev_chapter
        paragraph = paragraph_of[i] if i < len(paragraph_of) else prev_paragraph
        if chapter != prev_chapter:
            grouped.append((prev_chapter, " ".join(current)))
            current = []
        elif paragraph != prev_paragraph:
            current.append("\n\n")
        current.append(sentence)
        prev_chapter, prev_paragraph = chapter, paragraph
    if current:
        grouped.append((prev_chapter, " ".join(current)))
    return grouped


def _render_section(title, text):
    """Add ``title`` to the TOC, reserve a summary box sized to ``text``, and write ``text``."""
    toc.subheader(title)
    if len(text) > BULLET_SUMMARY_MIN_CHARS:
        summary_box = st.info("")
        summarization_todos.append(partial(
            generate_summary, summarizer, summary_box, text, BP_PROMPT, BP_MAX_INPUT_LENGTH, prefix="\u2022"
        ))
    elif len(text) > TLDR_SUMMARY_MIN_CHARS:
        summary_box = st.info("")
        summarization_todos.append(partial(
            generate_summary, summarizer, summary_box, text, TLDR_PROMPT, TLDR_MAX_INPUT_LENGTH
        ))
    st.write(text)


if st.button("Process Transcript"):
    with st.sidebar:
        st.header("Table of Contents")
        toc.placeholder()
    st.header(selected_talk, divider='rainbow')

    captions = transcripts_map[selected_talk].get('captions')
    result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
    segments, titles, sentences = result['segments'], result['titles'], result['sentences']

    if USE_PARAGRAPHING_MODEL:
        presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
        for chapter_idx, chapter_text in _group_into_chapters(sentences, segments, presult['segments']):
            _render_section(titles[chapter_idx], chapter_text)
    else:
        for title, chapter in zip(titles, segments):
            _render_section(title, " ".join(chapter))

    # Show the table of contents first, then stream each summary into the box
    # reserved above its section.
    toc.generate()
    for summarization_todo in summarization_todos:
        summarization_todo()