# youtube-chatbot / app.py
import torch
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFaceHub
import yt_dlp
import json
import gradio as gr
from gradio_client import Client
import datetime
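# Hosted whisper-jax Gradio Space used for transcription via its public API.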
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)
def transcribe_audio(audio_path,
task='transcribe',
return_timestamps=True) -> str:
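    """Transcribe an audio file with the whisper-jax API and return the
    timestamped transcript text.
    """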
text, runtime = whisper_jax.predict(
audio_path,
task,
return_timestamps,
api_name='/predict_1',
)
return text
def format_whisper_jax_output(whisper_jax_output: str,
max_duration: int = 60) -> list[dict]:
"""Returns a list of dict with keys 'start', 'end', 'text'
The segments from whisper jax output are merged to form paragraphs.
`max_duration` controls how many seconds of the audio's transcripts are merged
For example, if `max_duration`=60, in the final output, each segment is roughly
60 seconds.
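    Illustrative example (the exact whisper-jax output format may vary):
        an input line looks roughly like
            [00:07.520 -> 00:15.000]  some transcribed text
        and each merged entry looks like
            {'start': '00:00:00', 'end': '00:01:00', 'text': '...'}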
"""
final_output = []
max_duration = datetime.timedelta(seconds=max_duration)
segments = whisper_jax_output.split('\n')
current_start = datetime.datetime.strptime('00:00', '%M:%S')
current_text = ''
for i, seg in enumerate(segments):
text = seg.split(']')[-1].strip()
# Sometimes whisper jax returns None for timestamp
try:
end = datetime.datetime.strptime(seg[14:19], '%M:%S')
except ValueError:
end = current_start + max_duration
if (end-current_start >= max_duration) or (i == len(segments)-1):
# If we have exceeded max duration or at the last segment,
# stop merging and append to final_output.
            # Join with a space so merged segments don't run together.
            current_text += ' ' + text
            final_output.append({
                'start': current_start.strftime('%H:%M:%S'),
                'end': end.strftime('%H:%M:%S'),
                'text': current_text.strip()
})
# Update current start and text
current_start = end
current_text = ''
else:
# If we have not exceeded max duration, keep merging.
            current_text += ' ' + text
return final_output
audio_file_number = 1
def yt_audio_to_text(url: str,
max_duration: int = 60
):
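    """Download the audio track of a YouTube video, transcribe it with
    whisper-jax, and write the timestamped transcript to 'audio.json'.
    """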
global audio_file_number
progress = gr.Progress()
progress(0.1)
with yt_dlp.YoutubeDL({'extract_audio': True,
'format': 'bestaudio',
'outtmpl': f'{audio_file_number}.mp3'
}) as video:
info_dict = video.extract_info(url, download=False)
global video_title
video_title = info_dict['title']
        video.download([url])
progress(0.4)
audio_file = f'{audio_file_number}.mp3'
audio_file_number += 1
result = transcribe_audio(audio_file, return_timestamps=True)
progress(0.7)
result = format_whisper_jax_output(result, max_duration=max_duration)
progress(0.9)
with open('audio.json', 'w') as f:
json.dump(result, f)
def metadata_func(record: dict, metadata: dict) -> dict:
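    """Copy each transcript record's start/end timestamps into the document
    metadata; 'source' becomes a 'start -> end' range string.
    """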
metadata['start'] = record.get('start')
metadata['end'] = record.get('end')
metadata['source'] = metadata['start'] + ' -> ' + metadata['end']
return metadata
def load_data():
loader = JSONLoader(
file_path='audio.json',
jq_schema='.[]',
content_key='text',
metadata_func=metadata_func
)
data = loader.load()
return data
embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
model_kwargs=embedding_model_kwargs)
def create_vectordb(data, k: int):
"""Returns a vector database, and its retriever
`k` is the number of retrieved documents
"""
vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
retriever = vectordb.as_retriever(search_type='similarity',
search_kwargs={'k': k})
return vectordb, retriever
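# Example (illustrative): build a retriever over the transcript loaded by load_data().
#     vectordb, retriever = create_vectordb(load_data(), k=5)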
repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})
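# HuggingFaceHub calls the Hugging Face Inference API; it typically needs an API
# token, e.g. via the HUGGINGFACEHUB_API_TOKEN environment variable or a Space secret.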
# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary of the main themes \
in 150 words or fewer.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
# Takes a list of documents, combines them into a single string, and passes it to the LLM
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, document_variable_name="docs"
)
# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
    combine_documents_chain=combine_documents_chain,
    # Used if the documents exceed the context window of `StuffDocumentsChain`.
    collapse_documents_chain=combine_documents_chain,
# The maximum number of tokens to group documents into.
token_max=4000
)
# Combining documents by mapping a chain over them, then combining results
map_reduce_chain = MapReduceDocumentsChain(
# Map chain
llm_chain=map_chain,
# Reduce chain
reduce_documents_chain=reduce_documents_chain,
# The variable name in the llm_chain to put the documents in
document_variable_name="docs",
    # Whether to return the results of the map steps in the output
return_intermediate_steps=False
)
def get_summary(documents) -> str:
summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
return summary['output_text'].strip()
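# Example (illustrative): summarise the transcript loaded by load_data().
#     summary = get_summary(load_data())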
contextualise_q_prompt = PromptTemplate.from_template(
"""Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
that can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
Chat history: {chat_history}
Question: {question}
Answer:
"""
)
contextualise_q_chain = contextualise_q_prompt | llm
standalone_prompt = PromptTemplate.from_template(
"""Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.
Chat history:
{chat_history}
Question:
{question}
Answer:
"""
)
def format_output(answer: str) -> str:
# All lower case and remove all whitespace
return ''.join(answer.lower().split())
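# standalone_chain should output 'yes' if the latest question stands on its own
# and 'no' if it references the chat history (format_output normalises the answer).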
standalone_chain = standalone_prompt | llm | format_output
qa_prompt = PromptTemplate.from_template(
"""You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say: \
"Sorry, I cannot find the answer to that question in the video."
Context:
{context}
Question:
{question}
Answer:
"""
)
class YouTubeChatbot:
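    """Retrieval-augmented chatbot over the transcript of a single YouTube video.
    `n_sources` is the number of timestamps cited with each answer, `k` is the
    number of transcript chunks retrieved per question, `timestamp_interval` is
    the minimum gap between cited timestamps, and `memory` is the number of
    past exchanges kept in the conversation window.
    """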
def __init__(self,
n_sources: int,
k: int,
timestamp_interval: datetime.timedelta,
memory: int,
):
self.n_sources = n_sources
self.k = k
self.timestamp_interval = timestamp_interval
self.chat_history = ConversationBufferWindowMemory(k=memory)
def format_docs(self, docs: list) -> str:
"""Combine documents
"""
self.sources = [doc.metadata['start'] for doc in docs]
return '\n\n'.join(doc.page_content for doc in docs)
def standalone_question(self, input_: dict) -> str:
"""If the question is a not a standalone question,
run contextualise_q_chain.
"""
if input_['standalone']=='yes':
return contextualise_q_chain
else:
return input_['question']
def format_answer(self, answer: str) -> str:
if 'cannot find the answer' in answer:
return answer.strip()
else:
timestamps = self.filter_timestamps()
answer_with_sources = (
answer.strip()
+ ' You can find more information '\
'at these timestamps: {}.'.format(', '.join(timestamps))
)
return answer_with_sources
def filter_timestamps(self) -> list[str]:
"""Returns a list of timestamps with length `n_sources`.
The timestamps are at least an `timestamp_interval` apart.
This prevents returning a list of timestamps that are too
close together.
"""
sorted_timestamps = sorted(self.sources)
filtered_timestamps = [sorted_timestamps[0]]
i=1
while len(filtered_timestamps) < self.n_sources:
timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1],
'%H:%M:%S')
try:
timestamp2 = datetime.datetime.strptime(sorted_timestamps[i],
'%H:%M:%S')
except IndexError:
break
time_diff = timestamp2 - timestamp1
if time_diff>=self.timestamp_interval:
filtered_timestamps.append(str(timestamp2.time()))
i += 1
return filtered_timestamps
def setup_chatbot(self, url: str) -> str:
"""Given a YouTube url, set up the chatbot.
"""
yt_audio_to_text(url)
self.data = load_data()
_, self.retriever = create_vectordb(self.data, self.k)
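        # Classify the question, rewrite it against the chat history if needed,
        # retrieve matching transcript chunks, then answer from that context.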
self.qa_chain = (
RunnablePassthrough.assign(standalone=standalone_chain)
| {'question':self.standalone_question,
'context':self.standalone_question|self.retriever|self.format_docs}
| qa_prompt
| llm)
return url
def get_answer(self, question: str) -> str:
        try:
            ai_msg = self.qa_chain.invoke(
                {'question': question,
                 'chat_history': self.chat_history.load_memory_variables({})['history']}
            )
        except AttributeError:
            raise AttributeError("You haven't set up the chatbot yet. "
                                 "Set up the chatbot by calling the "
                                 "instance method `setup_chatbot`.")
answer = self.format_answer(ai_msg)
self.chat_history.save_context({'question':question},
{'answer':answer})
return answer
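# Example usage outside the Gradio app (illustrative):
#     bot = YouTubeChatbot(n_sources=3, k=5,
#                          timestamp_interval=datetime.timedelta(minutes=2),
#                          memory=5)
#     bot.setup_chatbot('https://www.youtube.com/watch?v=...')
#     print(bot.get_answer('What is the video about?'))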
class YouTubeChatbotApp(YouTubeChatbot):
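    """Gradio front end for YouTubeChatbot. Processes a YouTube URL, greets the
    user with a summary of the video, and then answers questions about it.
    """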
def __init__(self,
n_sources: int,
k: int,
timestamp_interval: datetime.timedelta,
memory: int,
default_youtube_url: str
):
super().__init__(n_sources, k, timestamp_interval, memory)
self.default_youtube_url = default_youtube_url
self.gradio_chat_history = []
def greet(self) -> list[tuple[str|None, str|None]]:
summary = get_summary(self.data)
summary_message = f'Here is a summary of the video "{video_title}":'
self.gradio_chat_history.append((None, summary_message))
self.gradio_chat_history.append((None, summary))
greeting_message = ('You can ask me anything about the video. '
'I will do my best to answer!')
self.gradio_chat_history.append((None, greeting_message))
return self.gradio_chat_history
    def question(self, user_message: str) -> tuple[str, list[tuple[str|None, str|None]]]:
self.gradio_chat_history.append((user_message, None))
return '', self.gradio_chat_history
    def respond(self) -> list[tuple[str|None, str|None]]:
try:
ai_message = self.get_answer(self.gradio_chat_history[-1][0])
except AttributeError:
raise gr.Error('You need to process the video '
'first by pressing the `Go` button.')
self.gradio_chat_history.append((None, ai_message))
return self.gradio_chat_history
def clear_chat_history(self) -> list:
self.chat_history.clear()
self.gradio_chat_history = []
return self.gradio_chat_history
def launch(self, **kwargs):
with gr.Blocks() as demo:
# Structure
with gr.Row():
url_input = gr.Textbox(value=self.default_youtube_url,
label='YouTube URL',
scale=5)
button = gr.Button(value='Go', scale=1)
chatbot = gr.Chatbot()
user_message = gr.Textbox(label='Ask a question:')
clear = gr.ClearButton([user_message, chatbot])
# Actions
button.click(self.clear_chat_history,
inputs=[],
outputs=[chatbot],
trigger_mode='once'
).then(self.setup_chatbot,
inputs=[url_input],
outputs=[url_input]
).then(self.greet,
inputs=[],
outputs=[chatbot])
user_message.submit(self.question,
inputs=[user_message],
outputs=[user_message, chatbot]
).then(self.respond,
inputs=[],
outputs=[chatbot])
clear.click(self.clear_chat_history, inputs=[], outputs=[chatbot])
demo.launch(**kwargs)
if __name__ == "__main__":
app = YouTubeChatbotApp(n_sources=3,
k=5,
timestamp_interval=datetime.timedelta(minutes=2),
memory=5,
default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8'
)
app.launch()