# youtube-chatbot / app.py
import torch
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFaceHub
from langchain.chains import (
LLMChain,
StuffDocumentsChain,
MapReduceDocumentsChain,
ReduceDocumentsChain,
)
from gradio_client import Client
import gradio as gr
import yt_dlp
import json
import gc
import datetime
import os
import numpy as np
"""Prepare data"""
whisper_jax_api = "https://sanchit-gandhi-whisper-jax.hf.space/"
whisper_jax = Client(whisper_jax_api)
def transcribe_audio(audio_path, task="transcribe", return_timestamps=True) -> str:
text, runtime = whisper_jax.predict(
audio_path,
task,
return_timestamps,
api_name="/predict_1",
)
return text
def format_whisper_jax_output(
whisper_jax_output: str, max_duration: int = 60
) -> list[dict]:
"""Whisper JAX outputs are in the format
'[00:00.000 -> 00:00.000] text\n[00:00.000 -> 00:00.000] text'.
Returns a list of dict with keys 'start', 'end', 'text'
The segments from whisper jax output are merged to form paragraphs.
    `max_duration` controls how many seconds of the audio's transcript are merged.
    For example, if `max_duration`=60, each segment in the final output is roughly
    60 seconds long.
"""
final_output = []
max_duration = datetime.timedelta(seconds=max_duration)
segments = whisper_jax_output.split("\n")
current_start = datetime.datetime.strptime("00:00", "%M:%S")
current_text = ""
for i, seg in enumerate(segments):
text = seg.split("]")[-1].strip()
current_text += " " + text
# Sometimes whisper jax returns None for timestamp
try:
end = datetime.datetime.strptime(seg[14:19], "%M:%S")
except ValueError:
end = current_start + max_duration
if i == len(segments) - 1:
final_output.append(
{
"start": current_start.strftime("%H:%M:%S"),
"end": end.strftime("%H:%M:%S"),
"text": current_text.strip(),
}
)
else:
if end - current_start >= max_duration and current_text[-1] == ".":
# If we have exceeded max duration, check whether we have
# reached the end of a sentence. If not, keep merging.
final_output.append(
{
"start": current_start.strftime("%H:%M:%S"),
"end": end.strftime("%H:%M:%S"),
"text": current_text.strip(),
}
)
# Update current start and text
current_start = end
current_text = ""
return final_output
def yt_audio_to_text(url: str, max_duration: int = 60):
"""Given a YouTube url, download audio and transcribe it to text. Reformat
the output from Whisper JAX and save the final result in a json file.
"""
progress = gr.Progress()
progress(0.1)
with yt_dlp.YoutubeDL(
{"extract_audio": True, "format": "bestaudio", "outtmpl": "audio.mp3"}
) as video:
info_dict = video.extract_info(url, download=False)
global video_title
video_title = info_dict["title"]
        video.download([url])
progress(0.4)
audio_file = "audio.mp3"
result = transcribe_audio(audio_file, return_timestamps=True)
progress(0.7)
result = format_whisper_jax_output(result, max_duration=max_duration)
progress(0.9)
with open("audio.json", "w") as f:
json.dump(result, f)
os.remove(audio_file)
"""Load data"""
def metadata_func(record: dict, metadata: dict) -> dict:
"""This function is used to tell the Langchain loader the keys that
contain metadata and extract them.
"""
metadata["start"] = record.get("start")
metadata["end"] = record.get("end")
metadata["source"] = metadata["start"] + " -> " + metadata["end"]
return metadata
def load_data():
loader = JSONLoader(
file_path="audio.json",
jq_schema=".[]",
content_key="text",
metadata_func=metadata_func,
)
data = loader.load()
os.remove("audio.json")
return data
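# Each loaded Document holds one transcript paragraph; illustrative shape
# (values are hypothetical):
#   Document(page_content="Hello and welcome. ...",
#            metadata={"start": "00:00:00", "end": "00:01:05",
#                      "source": "00:00:00 -> 00:01:05", ...})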
"""Create embeddings and vector store"""
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_name, model_kwargs=embedding_model_kwargs
)
def create_vectordb(data, n_retrieved_docs: int, collection_name="YouTube"):
"""Returns a retriever which is used to fetch relevant documents from
the vector database.
`n_retrieved_docs` is the number of retrieved documents.
"""
vectordb = Chroma.from_documents(
documents=data, embedding=embeddings, collection_name=collection_name
)
n_docs = len(vectordb.get()["ids"])
retriever = vectordb.as_retriever(
search_type="mmr", search_kwargs={"k": n_retrieved_docs, "fetch_k": n_docs}
)
return retriever
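# Minimal usage sketch (assumes `data` comes from load_data(); the query string
# is hypothetical and this is not executed here):
#   retriever = create_vectordb(data, n_retrieved_docs=5)
#   docs = retriever.get_relevant_documents("What is the main topic?")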
"""Load LLM"""
repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"max_new_tokens": 1000})
"""Summarisation"""
# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary of the main themes \
in 150 words or less.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
# Takes a list of documents, combines them into a single string, and passes this to llm
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, document_variable_name="docs"
)
# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
combine_documents_chain=combine_documents_chain,
# If documents exceed context for `StuffDocumentsChain`
collapse_documents_chain=combine_documents_chain,
# The maximum number of tokens to group documents into.
token_max=4000,
)
# Combining documents by mapping a chain over them, then combining results
map_reduce_chain = MapReduceDocumentsChain(
# Map chain
llm_chain=map_chain,
# Reduce chain
reduce_documents_chain=reduce_documents_chain,
# The variable name in the llm_chain to put the documents in
document_variable_name="docs",
    # Do not return the results of the map steps in the output
return_intermediate_steps=False,
)
def get_summary(documents) -> str:
summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
return summary["output_text"].strip()
"""Contextualising the question"""
contextualise_q_prompt = PromptTemplate.from_template(
"""Given a chat history and the latest user question \
which might reference the chat history, formulate a \
standalone question that can be understood without \
the chat history. Do NOT answer the question, just \
reformulate it if needed and otherwise return it as is.
Chat history: {chat_history}
Question: {question}
Answer:
"""
)
contextualise_q_chain = contextualise_q_prompt | llm
"""Standalone question chain"""
standalone_prompt = PromptTemplate.from_template(
"""Given a chat history and the latest user question, \
identify whether the question is a standalone question \
or the question references the chat history. Answer 'yes' \
if the question is a standalone question, and 'no' if the \
question references the chat history. Do not answer \
anything other than 'yes' or 'no'.
Chat history:
{chat_history}
Question:
{question}
Answer:
"""
)
def format_output(answer: str) -> str:
"""All lower case and remove all whitespace to ensure
that the answer given by the LLM is either 'yes' or 'no'.
"""
return "".join(answer.lower().split())
standalone_chain = standalone_prompt | llm | format_output
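# format_output normalises the raw LLM output, e.g. (illustrative):
#   format_output(" Yes \n") -> "yes"
# so standalone_chain always yields exactly "yes" or "no".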
"""Q&A chain"""
qa_prompt = PromptTemplate.from_template(
"""You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say:\
"Sorry, I cannot find the answer to that question in the video."
Context:
{context}
Question:
{question}
Answer:
"""
)
class YouTubeChatbot:
instance_count = 0
def __init__(
self,
n_sources: int = 3,
n_retrieved_docs: int = 5,
timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=2),
memory: int = 5,
):
YouTubeChatbot.instance_count += 1
self.chatbot_id = YouTubeChatbot.instance_count
self.n_sources = n_sources
self.n_retrieved_docs = n_retrieved_docs
self.timestamp_interval = timestamp_interval
self.chat_history = ConversationBufferWindowMemory(k=memory)
self.retriever = None
self.qa_chain = None
def format_docs(self, docs: list) -> str:
"""Combine documents into a single string which will be included
in the prompt given to the LLM.
"""
self.sources = [doc.metadata["start"] for doc in docs]
return "\n\n".join(doc.page_content for doc in docs)
def standalone_question(self, input_: dict) -> str:
"""If the question is a not a standalone question,
run contextualise_q_chain.
"""
if input_["standalone"] == "yes":
return contextualise_q_chain
else:
return input_["question"]
def format_answer(self, answer: str) -> str:
"""Add timestamps to answers.
"""
if "cannot find the answer" in answer:
return answer.strip()
else:
timestamps = self.filter_timestamps()
answer_with_sources = (
answer.strip() + " You can find more information "
"at these timestamps: {}.".format(", ".join(timestamps))
)
return answer_with_sources
def filter_timestamps(self) -> list[str]:
"""Returns a list of timestamps with length less or
equal to `n_sources`. The timestamps are at least an
`timestamp_interval` apart. This prevents returning
a list of timestamps that are too close together.
"""
filtered_timestamps = np.array(
[datetime.datetime.strptime(self.sources[0], "%H:%M:%S")]
)
i = 1
while len(filtered_timestamps) < self.n_sources:
try:
new_timestamp = datetime.datetime.strptime(self.sources[i], "%H:%M:%S")
except IndexError:
break
absolute_time_difference = abs(new_timestamp - filtered_timestamps)
if all(absolute_time_difference >= self.timestamp_interval):
filtered_timestamps = np.append(filtered_timestamps, new_timestamp)
i += 1
filtered_timestamps = [
timestamp.strftime("%H:%M:%S") for timestamp in filtered_timestamps
]
filtered_timestamps.sort()
return filtered_timestamps
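    # Illustrative behaviour (hypothetical sources, 2-minute interval):
    # with self.sources = ["00:10:00", "00:10:30", "00:14:00"], the second
    # timestamp is dropped as too close to the first, giving
    # ["00:10:00", "00:14:00"].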
def process_video(self, url: str, data=None, retriever=None):
"""Given a YouTube URL, transcribe YouTube audio to text.
Then set up the vector database.
"""
yt_audio_to_text(url)
data = load_data()
if retriever is not None:
# If we already have documents in the vector store, delete them.
ids = retriever.vectorstore.get()["ids"]
retriever.vectorstore.delete(ids)
retriever = create_vectordb(
data, self.n_retrieved_docs,
collection_name=f"Chatbot{self.chatbot_id}"
)
return url, data, retriever
def setup_qa_chain(self, retriever, qa_chain=None):
qa_chain = (
RunnablePassthrough.assign(standalone=standalone_chain)
| {
"question": self.standalone_question,
"context": self.standalone_question | retriever | self.format_docs,
}
| qa_prompt
| llm
)
return retriever, qa_chain
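    # The chain above first labels the question as standalone or not (via
    # standalone_chain); non-standalone questions are rewritten with
    # contextualise_q_chain before retrieval, so the retriever always sees a
    # self-contained question.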
def setup_chatbot(self, url: str):
_, _, self.retriever = self.process_video(url=url, retriever=self.retriever)
_, self.qa_chain = self.setup_qa_chain(retriever=self.retriever)
def get_answer(self, question: str) -> str:
try:
ai_msg = self.qa_chain.invoke(
{"question": question, "chat_history": self.chat_history}
)
except AttributeError:
raise AttributeError(
"You haven't setup the chatbot yet. "
"Setup the chatbot by calling the "
"instance method `setup_chatbot`."
)
self.chat_history.save_context({"question": question}, {"answer": ai_msg})
answer = self.format_answer(ai_msg)
return answer
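# Minimal programmatic usage sketch for YouTubeChatbot (URL and question are
# hypothetical; not executed here):
#   bot = YouTubeChatbot(n_sources=3, n_retrieved_docs=5)
#   bot.setup_chatbot("https://www.youtube.com/watch?v=...")
#   print(bot.get_answer("What is the video about?"))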
"""Web app"""
class YouTubeChatbotApp(YouTubeChatbot):
def __init__(
self,
n_sources: int,
n_retrieved_docs: int,
timestamp_interval: datetime.timedelta,
memory: int,
default_youtube_url: str,
):
super().__init__(n_sources, n_retrieved_docs, timestamp_interval, memory)
self.default_youtube_url = default_youtube_url
self.memory = memory
self.chat_history = None
self.data = None
self.retriever = None
self.qa_chain = None
# Gradio components
self.url_input = None
self.url_button = None
self.app_chat_history = None
self.chatbot = None
self.user_input = None
self.clear_button = None
def greet(self, data, app_chat_history) -> dict:
"""Summarise the video and greet the user.
"""
summary_message = f'Here is a summary of the video "{video_title}":'
app_chat_history.append((None, summary_message))
summary = get_summary(data)
self.data = gr.State(None)
app_chat_history.append((None, summary))
greeting_message = (
"You can ask me anything about the video. " "I will do my best to answer!"
)
app_chat_history.append((None, greeting_message))
return {self.app_chat_history: app_chat_history, self.chatbot: app_chat_history}
def question(self, user_question: str, app_chat_history) -> dict:
"""Display the question asked by the user in the chat window,
and delete from the input textbox.
"""
app_chat_history.append((user_question, None))
return {
self.user_input: "",
self.app_chat_history: app_chat_history,
self.chatbot: app_chat_history,
}
def respond(self, qa_chain, chat_history, app_chat_history) -> dict:
"""Respond to user's latest question"""
question = app_chat_history[-1][0]
try:
ai_msg = qa_chain.invoke(
{"question": question, "chat_history": chat_history}
)
except AttributeError:
raise gr.Error(
"You need to process the video " "first by pressing the `Go` button."
)
chat_history.save_context({"question": question}, {"answer": ai_msg})
answer = self.format_answer(ai_msg)
app_chat_history.append((None, answer))
return {
self.qa_chain: qa_chain,
self.chat_history: chat_history,
self.app_chat_history: app_chat_history,
self.chatbot: app_chat_history,
}
def clear_chat_history(self, chat_history, app_chat_history):
chat_history.clear()
app_chat_history = []
return {
self.chat_history: chat_history,
self.app_chat_history: app_chat_history,
self.chatbot: app_chat_history,
}
def launch(self, **kwargs):
with gr.Blocks() as demo:
self.chat_history = gr.State(ConversationBufferWindowMemory(k=self.memory))
self.app_chat_history = gr.State([])
self.data = gr.State()
self.retriever = gr.State()
self.qa_chain = gr.State()
# App structure
with gr.Row():
self.url_input = gr.Textbox(
value=self.default_youtube_url, label="YouTube URL", scale=5
)
self.url_button = gr.Button(value="Go", scale=1)
self.chatbot = gr.Chatbot()
self.user_input = gr.Textbox(label="Ask a question:")
self.clear_button = gr.Button(value="Clear")
# App actions
# When a new url is given, clear past chat history and process
# the new video. Set up the Q&A chain with the new video's data.
# Provide a summary of the new video.
self.url_button.click(
self.clear_chat_history,
inputs=[self.chat_history, self.app_chat_history],
outputs=[self.chat_history, self.app_chat_history, self.chatbot],
trigger_mode="once",
).then(
self.process_video,
inputs=[self.url_input, self.data, self.retriever],
outputs=[self.url_input, self.data, self.retriever],
).then(
self.setup_qa_chain,
inputs=[self.retriever, self.qa_chain],
outputs=[self.retriever, self.qa_chain],
).then(
self.greet,
inputs=[self.data, self.app_chat_history],
outputs=[self.app_chat_history, self.chatbot],
)
# When a user asks a question, display the question in the chat
# window and remove it from the text input area. Then respond
# with the Q&A chain.
self.user_input.submit(
self.question,
inputs=[self.user_input, self.app_chat_history],
outputs=[self.user_input, self.app_chat_history, self.chatbot],
queue=False,
).then(
self.respond,
inputs=[self.qa_chain, self.chat_history, self.app_chat_history],
outputs=[
self.qa_chain,
self.chat_history,
self.app_chat_history,
self.chatbot,
],
)
# When the `Clear` button is clicked, clear the chat history from
# the chat window.
self.clear_button.click(
self.clear_chat_history,
inputs=[self.chat_history, self.app_chat_history],
outputs=[self.chat_history, self.app_chat_history, self.chatbot],
queue=False,
)
demo.launch(**kwargs)
if __name__ == "__main__":
app = YouTubeChatbotApp(
n_sources=3,
n_retrieved_docs=5,
timestamp_interval=datetime.timedelta(minutes=2),
memory=5,
default_youtube_url="https://www.youtube.com/watch?v=SZorAJ4I-sA",
)
app.launch()