import torch
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFaceHub
import yt_dlp
import json
import gradio as gr
from gradio_client import Client
import datetime

# Use the hosted Whisper JAX Space for transcription.
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)

def transcribe_audio(audio_path,
                     task: str = 'transcribe',
                     return_timestamps: bool = True) -> str:
    # The Space returns (text, runtime); we only need the text.
    text, runtime = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name='/predict_1',
    )
    return text
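
# Example (hypothetical output; assumed here to be newline-separated
# segments of the form '[MM:SS.mmm -> MM:SS.mmm] text'):
# transcribe_audio('1.mp3')
# -> '[00:00.000 -> 00:05.000] Hello and welcome.\n[00:05.000 -> ...'
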
def format_whisper_jax_output(whisper_jax_output: str,
                              max_duration: int = 60) -> list[dict]:
    """Returns a list of dicts with keys 'start', 'end', 'text'.

    The segments from the Whisper JAX output are merged to form paragraphs.
    `max_duration` controls how many seconds of transcript are merged.
    For example, if `max_duration=60`, each segment in the final output
    covers roughly 60 seconds of audio.
    """
    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    segments = whisper_jax_output.split('\n')
    current_start = datetime.datetime.strptime('00:00', '%M:%S')
    current_text = ''

    for i, seg in enumerate(segments):
        text = seg.split(']')[-1].strip()

        # Sometimes Whisper JAX returns None for the timestamp.
        try:
            # Segments look like '[MM:SS.mmm -> MM:SS.mmm] text',
            # so characters 14:19 hold the end timestamp's 'MM:SS'.
            end = datetime.datetime.strptime(seg[14:19], '%M:%S')
        except ValueError:
            end = current_start + max_duration

        if (end - current_start >= max_duration) or (i == len(segments) - 1):
            # If we have exceeded max duration or are at the last segment,
            # stop merging and append to final_output.
            current_text += text
            final_output.append({
                'start': current_start.strftime('%H:%M:%S'),
                'end': end.strftime('%H:%M:%S'),
                'text': current_text
            })
            # Reset the current start and text.
            current_start = end
            current_text = ''
        else:
            # If we have not exceeded max duration, keep merging,
            # separating segments with a space.
            current_text += text + ' '

    return final_output
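
# A worked example (hypothetical transcript, using the segment format
# assumed above):
# format_whisper_jax_output(
#     '[00:00.000 -> 00:30.000] Hello and welcome.\n'
#     '[00:30.000 -> 01:10.000] Today we talk about LangChain.',
#     max_duration=60)
# -> [{'start': '00:00:00', 'end': '00:01:10',
#      'text': 'Hello and welcome. Today we talk about LangChain.'}]
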
audio_file_number = 1

def yt_audio_to_text(url: str,
                     max_duration: int = 60):
    global audio_file_number
    progress = gr.Progress()
    progress(0.1)

    # Download only the audio track.
    with yt_dlp.YoutubeDL({'extract_audio': True,
                           'format': 'bestaudio',
                           'outtmpl': f'{audio_file_number}.mp3'}) as video:
        info_dict = video.extract_info(url, download=False)
        global video_title
        video_title = info_dict['title']
        # `download` expects a list of URLs.
        video.download([url])

    progress(0.4)
    audio_file = f'{audio_file_number}.mp3'
    audio_file_number += 1

    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open('audio.json', 'w') as f:
        json.dump(result, f)

def metadata_func(record: dict, metadata: dict) -> dict:
    metadata['start'] = record.get('start')
    metadata['end'] = record.get('end')
    metadata['source'] = metadata['start'] + ' -> ' + metadata['end']
    return metadata

def load_data():
    loader = JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',
        content_key='text',
        metadata_func=metadata_func
    )
    data = loader.load()
    return data
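
# Note: JSONLoader evaluates `jq_schema` with the `jq` package, so it must
# be installed. Each resulting Document looks roughly like:
# Document(page_content='...', metadata={'start': '00:00:00',
#          'end': '00:01:00', 'source': '00:00:00 -> 00:01:00', ...})
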
embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)

def create_vectordb(data, k: int):
    """Returns a vector database and its retriever.

    `k` is the number of documents to retrieve.
    """
    vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
    retriever = vectordb.as_retriever(search_type='similarity',
                                      search_kwargs={'k': k})
    return vectordb, retriever
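
# A minimal usage sketch (assumes `audio.json` already exists):
# data = load_data()
# vectordb, retriever = create_vectordb(data, k=5)
# docs = retriever.get_relevant_documents('What is the video about?')
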
repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})

# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary of the main themes \
in 150 words or less.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Takes a list of documents, combines them into a single string,
# and passes this to the LLM.
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)
# Combines and iteratively reduces the mapped documents.
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
    combine_documents_chain=combine_documents_chain,
    # Used if the documents exceed the context for `StuffDocumentsChain`.
    collapse_documents_chain=combine_documents_chain,
    # The maximum number of tokens to group documents into.
    token_max=4000
)

# Combines documents by mapping a chain over them, then combining the results.
map_reduce_chain = MapReduceDocumentsChain(
    # Map chain
    llm_chain=map_chain,
    # Reduce chain
    reduce_documents_chain=reduce_documents_chain,
    # The variable name in the llm_chain to put the documents in
    document_variable_name="docs",
    # Do not return the results of the map steps in the output.
    return_intermediate_steps=False
)
def get_summary(documents) -> str:
    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
    return summary['output_text'].strip()
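
# Usage sketch: summarise the whole transcript loaded above.
# summary = get_summary(load_data())
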
contextualise_q_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
that can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
Chat history: {chat_history}
Question: {question}
Answer:
"""
)
contextualise_q_chain = contextualise_q_prompt | llm

standalone_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question, \
identify whether the question is a standalone question or whether it \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.
Chat history:
{chat_history}
Question:
{question}
Answer:
"""
)

def format_output(answer: str) -> str:
    # Lower-case the answer and remove all whitespace.
    return ''.join(answer.lower().split())

standalone_chain = standalone_prompt | llm | format_output
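
# `format_output` normalises the raw completion before the routing check
# below, e.g. format_output(' Yes\n') -> 'yes'.
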
qa_prompt = PromptTemplate.from_template(
    """You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say: \
"Sorry, I cannot find the answer to that question in the video."
Context:
{context}
Question:
{question}
Answer:
"""
)
class YouTubeChatbot:
    def __init__(self,
                 n_sources: int,
                 k: int,
                 timestamp_interval: datetime.timedelta,
                 memory: int):
        self.n_sources = n_sources
        self.k = k
        self.timestamp_interval = timestamp_interval
        self.chat_history = ConversationBufferWindowMemory(k=memory)

    def format_docs(self, docs: list) -> str:
        """Combine the retrieved documents into a single string,
        remembering their start timestamps as sources.
        """
        self.sources = [doc.metadata['start'] for doc in docs]
        return '\n\n'.join(doc.page_content for doc in docs)
    def standalone_question(self, input_: dict) -> str:
        """If the question is not a standalone question,
        run contextualise_q_chain to rewrite it.
        """
        if input_['standalone'] == 'no':
            return contextualise_q_chain
        else:
            return input_['question']
    def format_answer(self, answer: str) -> str:
        if 'cannot find the answer' in answer:
            return answer.strip()
        else:
            timestamps = self.filter_timestamps()
            answer_with_sources = (
                answer.strip()
                + ' You can find more information '
                  'at these timestamps: {}.'.format(', '.join(timestamps))
            )
            return answer_with_sources

    def filter_timestamps(self) -> list[str]:
        """Returns a list of timestamps with length `n_sources`.

        The timestamps are at least `timestamp_interval` apart.
        This prevents returning a list of timestamps that are too
        close together.
        """
        sorted_timestamps = sorted(self.sources)
        filtered_timestamps = [sorted_timestamps[0]]
        i = 1
        while len(filtered_timestamps) < self.n_sources:
            timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1],
                                                    '%H:%M:%S')
            try:
                timestamp2 = datetime.datetime.strptime(sorted_timestamps[i],
                                                        '%H:%M:%S')
            except IndexError:
                break
            time_diff = timestamp2 - timestamp1
            if time_diff >= self.timestamp_interval:
                filtered_timestamps.append(str(timestamp2.time()))
            i += 1
        return filtered_timestamps
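
    # A worked example (hypothetical sources): with n_sources=2 and a
    # 2-minute interval, sources ['00:01:00', '00:02:00', '00:04:00']
    # filter to ['00:01:00', '00:04:00'].
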
    def setup_chatbot(self, url: str) -> str:
        """Given a YouTube URL, set up the chatbot."""
        yt_audio_to_text(url)
        self.data = load_data()
        _, self.retriever = create_vectordb(self.data, self.k)
        self.qa_chain = (
            RunnablePassthrough.assign(standalone=standalone_chain)
            | {'question': self.standalone_question,
               'context': self.standalone_question | self.retriever | self.format_docs}
            | qa_prompt
            | llm
        )
        return url
    def get_answer(self, question: str) -> str:
        # Load the buffered conversation as a string for the prompt.
        chat_history = self.chat_history.load_memory_variables({})['history']
        try:
            ai_msg = self.qa_chain.invoke({'question': question,
                                           'chat_history': chat_history})
        except AttributeError:
            raise AttributeError("You haven't set up the chatbot yet. "
                                 "Set up the chatbot by calling the "
                                 "instance method `setup_chatbot`.")
        answer = self.format_answer(ai_msg)
        self.chat_history.save_context({'question': question},
                                       {'answer': answer})
        return answer
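
# A minimal usage sketch outside the Gradio app (hypothetical URL):
# bot = YouTubeChatbot(n_sources=3, k=5,
#                      timestamp_interval=datetime.timedelta(minutes=2),
#                      memory=5)
# bot.setup_chatbot('https://www.youtube.com/watch?v=...')
# print(bot.get_answer('What is the video about?'))
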
class YouTubeChatbotApp(YouTubeChatbot):
    def __init__(self,
                 n_sources: int,
                 k: int,
                 timestamp_interval: datetime.timedelta,
                 memory: int,
                 default_youtube_url: str):
        super().__init__(n_sources, k, timestamp_interval, memory)
        self.default_youtube_url = default_youtube_url
        self.gradio_chat_history = []

    def greet(self) -> list[tuple[str | None, str | None]]:
        summary = get_summary(self.data)
        summary_message = f'Here is a summary of the video "{video_title}":'
        self.gradio_chat_history.append((None, summary_message))
        self.gradio_chat_history.append((None, summary))
        greeting_message = ('You can ask me anything about the video. '
                            'I will do my best to answer!')
        self.gradio_chat_history.append((None, greeting_message))
        return self.gradio_chat_history

    def question(self, user_message: str) -> tuple[str, list[tuple[str | None, str | None]]]:
        self.gradio_chat_history.append((user_message, None))
        return '', self.gradio_chat_history
    def respond(self) -> list[tuple[str | None, str | None]]:
        try:
            ai_message = self.get_answer(self.gradio_chat_history[-1][0])
        except AttributeError:
            raise gr.Error('You need to process the video '
                           'first by pressing the `Go` button.')
        self.gradio_chat_history.append((None, ai_message))
        return self.gradio_chat_history

    def clear_chat_history(self) -> list:
        self.chat_history.clear()
        self.gradio_chat_history = []
        return self.gradio_chat_history
    def launch(self, **kwargs):
        with gr.Blocks() as demo:
            # Structure
            with gr.Row():
                url_input = gr.Textbox(value=self.default_youtube_url,
                                       label='YouTube URL',
                                       scale=5)
                button = gr.Button(value='Go', scale=1)
            chatbot = gr.Chatbot()
            user_message = gr.Textbox(label='Ask a question:')
            clear = gr.ClearButton([user_message, chatbot])

            # Actions
            button.click(self.clear_chat_history,
                         inputs=[],
                         outputs=[chatbot],
                         trigger_mode='once'
                         ).then(self.setup_chatbot,
                                inputs=[url_input],
                                outputs=[url_input]
                                ).then(self.greet,
                                       inputs=[],
                                       outputs=[chatbot])
            user_message.submit(self.question,
                                inputs=[user_message],
                                outputs=[user_message, chatbot]
                                ).then(self.respond,
                                       inputs=[],
                                       outputs=[chatbot])
            clear.click(self.clear_chat_history, inputs=[], outputs=[chatbot])

        demo.launch(**kwargs)
if __name__ == "__main__":
    app = YouTubeChatbotApp(n_sources=3,
                            k=5,
                            timestamp_interval=datetime.timedelta(minutes=2),
                            memory=5,
                            default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8')
    app.launch()