# youtube-chatbot / app.py
import torch
from langchain.prompts import PromptTemplate
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.llms import HuggingFaceHub
import yt_dlp
import json
import gradio as gr
from gradio_client import Client
import datetime
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)
def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
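    '''Transcribe an audio file by calling the hosted Whisper JAX Space.
    Returns the timestamped transcript; the reported runtime is discarded.
    '''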
text, runtime = whisper_jax.predict(
audio_path,
task,
return_timestamps,
api_name='/predict_1',
)
return text
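# Each line of the Whisper JAX output is assumed to look roughly like
# "[00:00.000 -> 00:05.000]  some text": the code below keeps the text after ']'
# and slices seg[14:19] to read the end timestamp as 'MM:SS'.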
def format_whisper_jax_output(whisper_jax_output: str, max_duration: int = 60) -> list:
    '''
    Returns a list of dicts with keys 'start', 'end' and 'text'.
    The segments from the Whisper JAX output are merged into paragraphs.
    `max_duration` controls how many seconds of transcript are merged into
    each paragraph: with `max_duration=60`, each segment in the final output
    covers roughly 60 seconds of audio.
    '''
final_output = []
max_duration = datetime.timedelta(seconds=max_duration)
segments = whisper_jax_output.split('\n')
current_start = datetime.datetime.strptime('00:00', '%M:%S')
current_text = ''
for i, seg in enumerate(segments):
text = seg.split(']')[-1].strip()
end = datetime.datetime.strptime(seg[14:19], '%M:%S')
if (end - current_start > max_duration) or (i == len(segments)-1):
# If we have exceeded max duration or
# at the last segment, stop merging
# and append to final_output
current_text += text
final_output.append({'start': current_start.strftime('%H:%M:%S'),
'end': end.strftime('%H:%M:%S'),
'text': current_text
})
# Update current start and text
current_start = end
current_text = ''
        else:
            # If we have not exceeded max duration, keep merging;
            # add a space so words from adjacent segments do not run together.
            current_text += text + ' '
return final_output
audio_file_number = 1
def yt_audio_to_text(url: str,
max_duration: int = 60
):
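    '''Download the audio of a YouTube video, transcribe it with Whisper JAX,
    and write the merged transcript segments to `audio.json`.
    '''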
global audio_file_number
global progress
progress = gr.Progress()
progress(0.1)
with yt_dlp.YoutubeDL({'extract_audio': True,
'format': 'bestaudio',
'outtmpl': f'{audio_file_number}.mp3'}) as video:
info_dict = video.extract_info(url, download=False)
global video_title
video_title = info_dict['title']
        # yt-dlp expects a list of URLs to download
        video.download([url])
progress(0.4)
audio_file = f'{audio_file_number}.mp3'
audio_file_number += 1
result = transcribe_audio(audio_file, return_timestamps=True)
progress(0.7)
result = format_whisper_jax_output(result, max_duration=max_duration)
progress(0.9)
with open('audio.json', 'w') as f:
json.dump(result, f)
def metadata_func(record: dict, metadata: dict) -> dict:
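    '''Copy the start/end timestamps into the metadata and build a 'start->end' source label.'''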
metadata['start'] = record.get('start')
metadata['end'] = record.get('end')
metadata['source'] = metadata['start'] + '->' + metadata['end']
return metadata
def load_data():
loader = JSONLoader(
file_path='audio.json',
jq_schema='.[]',
content_key='text',
metadata_func=metadata_func
)
data = loader.load()
return data
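# Sentence-transformer embeddings for the transcript chunks; runs on the GPU when available.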
embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
model_kwargs=embedding_model_kwargs)
def create_vectordb(data, k: int):
'''
`k` is the number of retrieved documents
'''
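    # Index the chunks in an in-memory Chroma store and return a similarity
    # retriever that yields the top-`k` chunks for a query.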
vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
retriever = vectordb.as_retriever(search_type='similarity',
search_kwargs={'k': k})
return vectordb, retriever
repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_length': 1024})
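# The LLM is served through the Hugging Face Hub inference API; authentication
# comes from the HUGGINGFACEHUB_API_TOKEN environment variable.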
# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary of the main themes.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
# Takes a list of documents, combines them into a single string, and passes this to llm
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain, document_variable_name="docs"
)
# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
combine_documents_chain=combine_documents_chain,
# If documents exceed context for `StuffDocumentsChain`
collapse_documents_chain=combine_documents_chain,
# The maximum number of tokens to group documents into.
token_max=4000,
)
# Combining documents by mapping a chain over them, then combining results
map_reduce_chain = MapReduceDocumentsChain(
# Map chain
llm_chain=map_chain,
# Reduce chain
reduce_documents_chain=reduce_documents_chain,
# The variable name in the llm_chain to put the documents in
document_variable_name="docs",
    # Do not return the results of the map steps in the output
    return_intermediate_steps=False,
)
def get_summary():
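    '''Summarise the whole transcript with the map-reduce chain.
    Relies on the module-level `data` created in `setup_rag`.
    '''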
summary = map_reduce_chain.run(data)
return summary
contextualise_q_prompt = PromptTemplate.from_template(
'''Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
Chat history: {chat_history}
Question: {question}
Answer:
'''
)
contextualise_q_chain = contextualise_q_prompt | llm
standalone_prompt = PromptTemplate.from_template(
'''Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.
Chat history:
{chat_history}
Question:
{question}
Answer:
'''
)
def format_output(answer: str) -> str:
# All lower case and remove all whitespace
return ''.join(answer.lower().split())
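# Piping the plain function `format_output` coerces it into a RunnableLambda,
# so the LLM's yes/no verdict is normalised before it is used downstream.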
standalone_chain = standalone_prompt | llm | format_output
qa_prompt = PromptTemplate.from_template(
'''You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say:\
"Sorry, I cannot find the answer to that question in the video."
Context:
{context}
Question:
{question}
Answer:
'''
)
def format_docs(docs: list) -> str:
    '''
    Combine the retrieved documents into a single string and record their
    start timestamps in the module-level `sources` for citing later.
    '''
global sources
sources = [doc.metadata['start'] for doc in docs]
return '\n\n'.join(doc.page_content for doc in docs)
def standalone_question(input_: dict) -> str:
    '''
    If the question is not a standalone question, run contextualise_q_chain
    to rewrite it; otherwise return the question unchanged.
    '''
if input_['standalone']=='yes':
return contextualise_q_chain
else:
return input_['question']
def format_answer(answer: str,
n_sources: int=1,
timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
if 'cannot find the answer' in answer:
return answer.strip()
else:
timestamps = filter_timestamps(n_sources, timestamp_interval)
answer_with_sources = (answer.strip()
+ ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps))
)
return answer_with_sources
def filter_timestamps(n_sources: int,
timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> list:
    '''Returns a list of at most `n_sources` timestamps.
    The timestamps are at least `timestamp_interval` apart, which
    prevents returning timestamps that are too close together.
    '''
sorted_timestamps = sorted(sources)
output = [sorted_timestamps[0]]
i=1
while len(output)<n_sources:
timestamp1 = datetime.datetime.strptime(output[-1], '%H:%M:%S')
try:
timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S')
except IndexError:
break
time_diff = timestamp2 - timestamp1
if time_diff>timestamp_interval:
output.append(str(timestamp2.time()))
i += 1
return output
def setup_rag(url):
'''Given a YouTube url, set up the vector database and the RAG chain.
'''
yt_audio_to_text(url)
global data
data = load_data()
global retriever
_, retriever = create_vectordb(data, k)
global rag_chain
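    # The chain first labels the question as standalone or not, rewrites it
    # against the chat history when needed, retrieves the matching transcript
    # chunks, and answers strictly from that context.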
rag_chain = (
RunnablePassthrough.assign(standalone=standalone_chain)
| {'question':standalone_question,
'context':standalone_question|retriever|format_docs
}
| qa_prompt
| llm
)
return url
def get_answer(question: str) -> str:
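    '''Answer a question with the RAG chain and append the exchange to the chat history.'''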
global chat_history
ai_msg = rag_chain.invoke({'question': question,
'chat_history': chat_history
})
answer = format_answer(ai_msg, n_sources, timestamp_interval)
chat_history.extend([HumanMessage(content=question),
AIMessage(content=answer)])
return answer
# Chatbot settings
n_sources = 3 # Number of sources provided in the answer
k = 5 # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=2)
default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'
def greet():
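    '''Post a summary of the video and a greeting message to the Gradio chat.'''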
summary = get_summary()
global gradio_chat_history
summary_message = f'Here is a summary of the video "{video_title}":'
gradio_chat_history.append((None, summary_message))
gradio_chat_history.append((None, summary))
    greeting_message = 'You can ask me anything about the video. I will do my best to answer!'
gradio_chat_history.append((None, greeting_message))
return gradio_chat_history
def question(user_message):
global gradio_chat_history
gradio_chat_history.append((user_message, None))
return gradio_chat_history
def respond():
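    '''Generate the assistant's answer to the latest user message in the chat.'''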
global gradio_chat_history
ai_message = get_answer(gradio_chat_history[-1][0])
gradio_chat_history.append((None, ai_message))
return '', gradio_chat_history
def clear_chat_history():
global chat_history
global gradio_chat_history
chat_history = []
gradio_chat_history = []
chat_history = []
gradio_chat_history = []
with gr.Blocks() as demo:
# Structure
with gr.Row():
url_input = gr.Textbox(value=default_youtube_url,
label='YouTube URL',
scale=5)
button = gr.Button(value='Go', scale=1)
chatbot = gr.Chatbot()
user_message = gr.Textbox(label='Ask a question:')
clear = gr.ClearButton([user_message, chatbot])
# Actions
button.click(setup_rag,
inputs=[url_input],
outputs=[url_input],
trigger_mode='once').then(greet,
inputs=[],
outputs=[chatbot])
user_message.submit(question,
inputs=[user_message],
outputs=[chatbot]).then(respond,
inputs=[],
outputs=[user_message, chatbot])
clear.click(clear_chat_history)
demo.launch()