Spaces:
Sleeping
Sleeping
import torch | |
from langchain import PromptTemplate | |
from langchain.document_loaders import JSONLoader | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.runnables import RunnablePassthrough, RunnableLambda | |
from langchain_core.messages import AIMessage, HumanMessage | |
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain | |
from langchain.llms import HuggingFaceHub | |
import yt_dlp | |
import json | |
import gc | |
import gradio as gr | |
from gradio_client import Client | |
import datetime | |
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)


def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
    '''
    Send an audio file to the hosted Whisper JAX Space and return its transcript.

    The '/predict_1' endpoint replies with a pair; only the first element
    (the transcript text) is kept. The second element (presumably the
    runtime, going by the original variable name) is discarded.
    '''
    prediction = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name='/predict_1',
    )
    transcript, _runtime = prediction
    return transcript
def _whisper_timestamp_to_seconds(stamp: str) -> int:
    '''Convert 'MM:SS.mmm' or 'HH:MM:SS.mmm' to whole seconds (fraction truncated).'''
    whole = stamp.strip().split('.')[0]
    seconds = 0
    for part in whole.split(':'):
        seconds = seconds * 60 + int(part)
    return seconds


def _seconds_to_hms(total_seconds: int) -> str:
    '''Format a number of seconds as a zero-padded 'HH:MM:SS' string.'''
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f'{hours:02d}:{minutes:02d}:{seconds:02d}'


def format_whisper_jax_output(whisper_jax_output: str, max_duration: int = 60) -> list:
    '''
    Returns a list of dict with keys 'start', 'end', 'text'.

    The segments from Whisper JAX output are merged to form paragraphs.
    `max_duration` controls how many seconds of the audio's transcript are
    merged: with `max_duration=60`, each output paragraph covers roughly
    60 seconds.

    Each input line is expected to look like ``[MM:SS.mmm -> MM:SS.mmm]  text``
    (an ``HH:`` prefix on the timestamps is also accepted). Timestamps are
    truncated to whole seconds and emitted as 'HH:MM:SS' strings. Segment
    texts are joined with single spaces. Blank lines are ignored, and an empty
    transcript yields an empty list.
    '''
    # Blank lines (e.g. a trailing newline) carry no timestamp and would
    # otherwise break the timestamp parsing below.
    segments = [seg for seg in whisper_jax_output.split('\n') if seg.strip()]
    final_output = []
    current_start = 0  # seconds
    current_texts = []
    for i, seg in enumerate(segments):
        # Split on the FIRST ']' so text containing ']' is kept intact.
        stamp, _, text = seg.partition(']')
        end = _whisper_timestamp_to_seconds(stamp.split('->')[-1])
        current_texts.append(text.strip())
        if end - current_start > max_duration or i == len(segments) - 1:
            # Exceeded the max duration, or reached the last segment:
            # close the current paragraph and start a new one at `end`.
            final_output.append({
                'start': _seconds_to_hms(current_start),
                'end': _seconds_to_hms(end),
                # Join with spaces; plain concatenation fused the last word of
                # one segment with the first word of the next.
                'text': ' '.join(t for t in current_texts if t),
            })
            current_start = end
            current_texts = []
    return final_output
audio_file_number = 1


def yt_audio_to_text(url: str,
                     max_duration: int = 60
                     ):
    '''
    Download a YouTube video's audio, transcribe it and save the merged
    paragraphs to 'audio.json'.

    Side effects:
      - sets the module globals `progress` and `video_title`;
      - downloads to '<audio_file_number>.mp3' and then increments
        `audio_file_number`, so the next call uses a new file name;
      - overwrites 'audio.json' with the output of
        `format_whisper_jax_output(..., max_duration=max_duration)`.
    '''
    global audio_file_number, progress, video_title
    # NOTE(review): Gradio normally only tracks a `gr.Progress` passed as a
    # default argument of the event handler; created here it may not update
    # the UI. Verify.
    progress = gr.Progress()
    progress(0.1)
    # NOTE(review): 'extract_audio' does not appear to be a YoutubeDL Python
    # option (audio extraction is usually a postprocessor), and 'bestaudio'
    # keeps the native container even though the file is named .mp3. Confirm.
    download_options = {
        'extract_audio': True,
        'format': 'bestaudio',
        'outtmpl': f'{audio_file_number}.mp3',
    }
    with yt_dlp.YoutubeDL(download_options) as downloader:
        metadata = downloader.extract_info(url, download=False)
        video_title = metadata['title']
        downloader.download(url)
    progress(0.4)
    audio_file = f'{audio_file_number}.mp3'
    audio_file_number += 1
    transcript = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)
    paragraphs = format_whisper_jax_output(transcript, max_duration=max_duration)
    progress(0.9)
    with open('audio.json', 'w') as f:
        json.dump(paragraphs, f)
def metadata_func(record: dict, metadata: dict) -> dict:
    '''
    JSONLoader hook: copy a paragraph's 'start' and 'end' timestamps into
    the document metadata and add a 'start->end' `source` label.

    `metadata` is updated in place and also returned.
    '''
    start, end = record.get('start'), record.get('end')
    metadata.update(start=start, end=end)
    metadata['source'] = '->'.join((start, end))
    return metadata
def load_data():
    '''
    Load the paragraphs saved in 'audio.json' as LangChain documents.

    Each element of the top-level JSON array ('.[]') becomes one document
    whose content is its 'text' field; `metadata_func` attaches the timestamps.
    '''
    return JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',
        content_key='text',
        metadata_func=metadata_func,
    ).load()
embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)


def create_vectordb(data, k: int):
    '''
    Embed `data` into a fresh Chroma vector store and wrap it in a retriever.

    `k` is the number of retrieved documents.

    Returns a `(vectordb, retriever)` tuple; the retriever does plain
    similarity search.
    '''
    import uuid  # local import: only needed to name the collection

    # Give every store its own collection. Without an explicit name Chroma
    # falls back to a default collection, and in-process (ephemeral) Chroma
    # clients can share state, so chunks of a previously loaded video could
    # otherwise still be retrieved after a new video is set up.
    collection_name = f'video-{uuid.uuid4().hex}'
    vectordb = Chroma.from_documents(documents=data,
                                     embedding=embeddings,
                                     collection_name=collection_name)
    retriever = vectordb.as_retriever(search_type='similarity',
                                      search_kwargs={'k': k})
    return vectordb, retriever
repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
# NOTE(review): text-generation endpoints usually take `max_new_tokens` as the
# length limit; confirm that `max_length` is honoured here.
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs=dict(max_length=1024))

# Map step: each transcript chunk is summarised on its own.
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(prompt=map_prompt, llm=llm)

# Reduce step: the per-chunk summaries are merged into one.
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary of the main themes.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(prompt=reduce_prompt, llm=llm)

# Stuffs a list of documents into the reduce prompt's {docs} slot as a single
# string and passes it to the LLM.
combine_documents_chain = StuffDocumentsChain(
    document_variable_name='docs',
    llm_chain=reduce_chain,
)

# Iteratively reduces the mapped summaries. The same stuff chain is the final
# combine step and also collapses groups of documents when they exceed the
# context; `token_max` is the maximum number of tokens per group.
reduce_documents_chain = ReduceDocumentsChain(
    token_max=4000,
    collapse_documents_chain=combine_documents_chain,
    combine_documents_chain=combine_documents_chain,
)

# Full pipeline: run `map_chain` over every document, then reduce the results.
# The per-chunk (intermediate) summaries are not included in the output.
map_reduce_chain = MapReduceDocumentsChain(
    llm_chain=map_chain,
    reduce_documents_chain=reduce_documents_chain,
    document_variable_name='docs',
    return_intermediate_steps=False,
)
def get_summary():
    '''Summarise the loaded transcript with the map-reduce chain.

    Reads the module-level `data` (the documents loaded by `setup_rag`).
    '''
    return map_reduce_chain.run(data)
# Rewrites a follow-up question so it can be understood without the chat.
_CONTEXTUALISE_Q_TEMPLATE = '''Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
Chat history: {chat_history}
Question: {question}
Answer:
'''
contextualise_q_prompt = PromptTemplate.from_template(_CONTEXTUALISE_Q_TEMPLATE)
contextualise_q_chain = contextualise_q_prompt | llm

# Asks the LLM for a bare 'yes'/'no': is the latest question already standalone?
_STANDALONE_TEMPLATE = '''Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.
Chat history:
{chat_history}
Question:
{question}
Answer:
'''
standalone_prompt = PromptTemplate.from_template(_STANDALONE_TEMPLATE)
def format_output(answer: str) -> str:
    '''Normalise an LLM reply for comparison.

    Lower-cases the text and deletes every whitespace character, so that
    ' Yes ' becomes 'yes'.
    '''
    return ''.join(ch for ch in answer.lower() if not ch.isspace())
standalone_chain = standalone_prompt | llm | format_output

# Answers strictly from the retrieved transcript context. `format_answer`
# detects the fallback reply through the substring 'cannot find the answer',
# so keep the two in sync. (The trailing backslash joins 'say:' directly to the
# quoted sentence, with no space in between.)
_QA_TEMPLATE = '''You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say:\
"Sorry, I cannot find the answer to that question in the video."
Context:
{context}
Question:
{question}
Answer:
'''
qa_prompt = PromptTemplate.from_template(_QA_TEMPLATE)
def format_docs(docs: list) -> str:
    '''
    Join retrieved documents into one context string, separated by blank lines.

    Side effect: stores each document's metadata['start'] timestamp, in
    retrieval order, in the module global `sources`, which `filter_timestamps`
    reads afterwards.
    '''
    global sources
    start_times = []
    for doc in docs:
        start_times.append(doc.metadata['start'])
    sources = start_times
    return '\n\n'.join([doc.page_content for doc in docs])
def standalone_question(input_: dict) -> str:
    '''
    Return the question to use for retrieval and answering.

    `input_['standalone']` holds the normalised reply of `standalone_chain`
    (lower-cased, whitespace removed). A reply starting with 'yes' (e.g. 'yes'
    or 'yes.') means the question is already standalone, and it is returned
    unchanged. Any other reply means it references the chat history, so
    `contextualise_q_chain` is returned to rewrite it. This relies on
    LangChain's RunnableLambda invoking a returned Runnable with the same input.
    '''
    # The original condition was inverted: standalone questions were
    # contextualised and history-dependent ones were used as-is.
    verdict = input_.get('standalone') or ''
    if verdict.startswith('yes'):
        return input_['question']
    return contextualise_q_chain
def format_answer(answer: str,
                  n_sources: int=1,
                  timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
    '''
    Strip the LLM answer and append the timestamps to cite.

    If the answer contains the 'cannot find the answer' fallback it is only
    stripped. Otherwise up to `n_sources` timestamps, at least
    `timestamp_interval` apart, are chosen by `filter_timestamps` and
    appended as a sentence.
    '''
    cleaned = answer.strip()
    if 'cannot find the answer' in answer:
        return cleaned
    joined = ', '.join(filter_timestamps(n_sources, timestamp_interval))
    return f'{cleaned} You can find more information at these timestamps: {joined}.'
def filter_timestamps(n_sources: int,
                      timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5),
                      timestamps: 'list[str] | None' = None) -> list:
    '''Returns a list of at most `n_sources` timestamps.

    Timestamps are 'HH:MM:SS' strings. They are scanned in chronological
    order starting from the earliest, and a timestamp is kept only if it is
    more than `timestamp_interval` after the previously kept one. This
    prevents returning timestamps that are too close together.

    `timestamps` defaults to the module-level `sources` recorded by
    `format_docs` for the latest retrieval. An empty input, or
    `n_sources <= 0`, returns an empty list (previously an IndexError, or one
    item for non-positive `n_sources`).
    '''
    if timestamps is None:
        timestamps = sources
    if n_sources <= 0 or not timestamps:
        return []
    # Zero-padded 'HH:MM:SS' strings sort chronologically as plain strings.
    ordered = sorted(timestamps)
    selected = [ordered[0]]
    last_kept = datetime.datetime.strptime(ordered[0], '%H:%M:%S')
    for candidate in ordered[1:]:
        if len(selected) >= n_sources:
            break
        parsed = datetime.datetime.strptime(candidate, '%H:%M:%S')
        if parsed - last_kept > timestamp_interval:
            selected.append(candidate)
            last_kept = parsed
    return selected
def setup_rag(url):
    '''Given a YouTube url, set up the vector database and the RAG chain.

    Transcribes the video, loads the transcript into the global `data`,
    builds a fresh `retriever` over it and assembles `rag_chain`. Both chat
    histories are reset, so the previous video's conversation is no longer
    used to reformulate questions about the new one. Returns `url` unchanged;
    it is written back to the URL textbox.
    '''
    global data, retriever, rag_chain, chat_history, gradio_chat_history
    yt_audio_to_text(url)
    data = load_data()
    _, retriever = create_vectordb(data, k)
    # A new video starts a new conversation.
    chat_history = []
    gradio_chat_history = []
    rag_chain = (
        RunnablePassthrough.assign(standalone=standalone_chain)
        | {'question': standalone_question,
           'context': standalone_question | retriever | format_docs}
        | qa_prompt
        | llm
    )
    return url
def get_answer(question: str) -> str:
    '''
    Answer `question` with the RAG chain and record the exchange.

    The answer is formatted by `format_answer` (which appends cited
    timestamps unless it is the fallback reply). The question and that
    formatted answer are appended to the global `chat_history`, then the
    answer is returned.
    '''
    global chat_history
    payload = {'question': question, 'chat_history': chat_history}
    raw_answer = rag_chain.invoke(payload)
    answer = format_answer(raw_answer, n_sources, timestamp_interval)
    chat_history += [HumanMessage(content=question), AIMessage(content=answer)]
    return answer
# Chatbot settings
n_sources = 3  # number of timestamps cited in each answer
k = 5  # number of documents returned by the retriever
timestamp_interval = datetime.timedelta(seconds=120)  # minimum gap between cited timestamps
default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'
def greet():
    '''
    Post the video summary and a welcome message into the Gradio chat.

    Appends three bot-only turns to `gradio_chat_history` and returns it:
    the summary header, the summary, and the greeting.
    '''
    global gradio_chat_history
    summary = get_summary()
    bot_turns = [
        f'Here is a summary of the video "{video_title}":',
        summary,
        'You can ask me anything about the video. I will do my best to answer!',
    ]
    gradio_chat_history.extend((None, turn) for turn in bot_turns)
    return gradio_chat_history
def question(user_message):
    '''Add the user's message to the Gradio chat as a turn awaiting a reply.'''
    global gradio_chat_history
    gradio_chat_history += [(user_message, None)]
    return gradio_chat_history
def respond():
    '''
    Answer the most recent user turn and append the reply as a bot turn.

    Returns an empty string (which clears the question textbox) and the
    updated chat history.
    '''
    global gradio_chat_history
    last_user_message = gradio_chat_history[-1][0]
    reply = get_answer(last_user_message)
    gradio_chat_history.append((None, reply))
    return '', gradio_chat_history
def clear_chat_history():
    '''Forget both conversations: the LangChain message history used by the
    RAG chain and the (user, bot) pairs shown in the Gradio chatbot.'''
    global chat_history, gradio_chat_history
    chat_history, gradio_chat_history = [], []


# Conversation state. These are module-level, so presumably shared by every
# browser session served by this process.
chat_history = []
gradio_chat_history = []
with gr.Blocks() as demo:
    # Layout: the URL box and "Go" button share one row; the chat and the
    # question box follow underneath.
    with gr.Row():
        url_input = gr.Textbox(value=default_youtube_url,
                               label='YouTube URL',
                               scale=5)
        button = gr.Button(value='Go', scale=1)
    chatbot = gr.Chatbot()
    user_message = gr.Textbox(label='Ask a question:')
    clear = gr.ClearButton([user_message, chatbot])

    # "Go": transcribe and index the video, then post its summary.
    setup_event = button.click(setup_rag,
                               inputs=[url_input],
                               outputs=[url_input],
                               trigger_mode='once')
    setup_event.then(greet, inputs=[], outputs=[chatbot])

    # Submitting a question first shows it in the chat, then fetches the answer.
    question_event = user_message.submit(question,
                                         inputs=[user_message],
                                         outputs=[chatbot])
    question_event.then(respond, inputs=[], outputs=[user_message, chatbot])

    # The ClearButton empties the widgets; this also resets the stored histories.
    clear.click(clear_chat_history)

demo.launch()