import torch
from langchain import PromptTemplate
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.llms import HuggingFaceHub
import yt_dlp
import json
import gc
import gradio as gr
from gradio_client import Client
import datetime

# Hosted Whisper JAX Space used for speech-to-text (a remote service call).
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)


def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
    '''Transcribe `audio_path` with the Whisper JAX Space.

    Calls the Space's '/predict_1' endpoint with (audio_path, task,
    return_timestamps). The endpoint returns (text, runtime); only the text
    is returned here.
    '''
    text, _runtime = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name='/predict_1',
    )
    return text


def _parse_whisper_timestamp(timestamp: str) -> datetime.datetime:
    '''Parse 'MM:SS[.fff]' or 'HH:MM:SS[.fff]' into a datetime on 1900-01-01.

    Fractional seconds are dropped. The result matches what
    ``datetime.datetime.strptime(..., '%M:%S')`` yields for the short form,
    so downstream ``strftime('%H:%M:%S')`` output is unchanged.
    '''
    seconds = 0
    for part in timestamp.strip().split('.')[0].split(':'):
        seconds = seconds * 60 + int(part)
    return datetime.datetime(1900, 1, 1) + datetime.timedelta(seconds=seconds)


def format_whisper_jax_output(whisper_jax_output: str, max_duration: int = 60) -> list:
    '''
    Returns a list of dict with keys 'start', 'end', 'text'.

    The segments from whisper jax output are merged to form paragraphs.
    `max_duration` controls how many seconds of the audio's transcripts are merged.
    For example, if `max_duration`=60, in the final output, each segment is roughly 60 seconds.

    Each input line is assumed to look like '[<start> -> <end>] <text>', where
    the timestamps are 'MM:SS.mmm' (or 'HH:MM:SS.mmm' for long audio) —
    NOTE(review): format inferred from the parsing below; confirm against the
    Space's output. Blank lines are ignored. Merged segment texts are joined
    with a single space.
    '''
    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    # Drop blank lines (e.g. a trailing newline): they carry no timestamp and
    # would otherwise fail to parse.
    segments = [seg for seg in whisper_jax_output.split('\n') if seg.strip()]
    current_start = datetime.datetime.strptime('00:00', '%M:%S')
    current_texts = []
    for i, seg in enumerate(segments):
        # Split only at the first ']' so transcript text containing ']' survives.
        header, _, text = seg.partition(']')
        text = text.strip()
        end = _parse_whisper_timestamp(header.split('->')[-1])
        if text:
            current_texts.append(text)
        if (end - current_start > max_duration) or (i == len(segments) - 1):
            # If we have exceeded max duration or are at the last segment,
            # stop merging and append to final_output.
            final_output.append({'start': current_start.strftime('%H:%M:%S'),
                                 'end': end.strftime('%H:%M:%S'),
                                 # Join with spaces: the per-segment strip()
                                 # removed the separating whitespace.
                                 'text': ' '.join(current_texts),
                                 })
            # Update current start and text
            current_start = end
            current_texts = []
    return final_output


audio_file_number = 1


def yt_audio_to_text(url: str, max_duration: int = 60):
    '''Download a YouTube video's audio, transcribe it and write 'audio.json'.

    Side effects: sets the globals `progress` and `video_title`, increments
    `audio_file_number`, writes '<audio_file_number>.mp3' and 'audio.json'
    in the working directory. `max_duration` is forwarded to
    `format_whisper_jax_output`.
    '''
    global audio_file_number
    global progress
    global video_title
    progress = gr.Progress()
    progress(0.1)

    audio_file = f'{audio_file_number}.mp3'
    # NOTE(review): 'bestaudio' is saved under a '.mp3' name without
    # conversion; the transcription service presumably sniffs the real format.
    ydl_opts = {'extract_audio': True, 'format': 'bestaudio', 'outtmpl': audio_file}
    with yt_dlp.YoutubeDL(ydl_opts) as video:
        info_dict = video.extract_info(url, download=False)
        video_title = info_dict['title']
        # YoutubeDL.download is documented to take a list of URLs.
        video.download([url])
    progress(0.4)

    audio_file_number += 1
    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)
    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)
    with open('audio.json', 'w') as f:
        json.dump(result, f)


def metadata_func(record: dict, metadata: dict) -> dict:
    '''Copy 'start'/'end' from a JSON record into document metadata.

    Also sets metadata['source'] to '<start>-><end>'.
    '''
    metadata['start'] = record.get('start')
    metadata['end'] = record.get('end')
    metadata['source'] = metadata['start'] + '->' + metadata['end']
    return metadata


def load_data():
    '''Load 'audio.json' as one document per transcript paragraph.'''
    loader = JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',
        content_key='text',
        metadata_func=metadata_func,
    )
    data = loader.load()
    return data


embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)


def create_vectordb(data, k: int, collection_name: str = None):
    '''
    Index `data` in a Chroma vector store and build a similarity retriever.

    `k` is the number of retrieved documents.
    `collection_name` names the Chroma collection. When omitted, a fresh unique
    name is generated for every call. Without one, every call would use Chroma's
    default collection name. In-process Chroma clients may then share that
    collection, which would mix documents from previously loaded videos into
    retrieval.

    Returns (vectordb, retriever).
    '''
    import uuid

    if collection_name is None:
        collection_name = f'video_{uuid.uuid4().hex}'
    vectordb = Chroma.from_documents(documents=data, embedding=embeddings,
                                     collection_name=collection_name)
    retriever = vectordb.as_retriever(search_type='similarity', search_kwargs={'k': k})
    return vectordb, retriever


repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_length': 1024})

# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)

# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill it into a final, consolidated summary of the main themes.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Takes a list of documents, combines them into a single string, and passes this to llm
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)

# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
    combine_documents_chain=combine_documents_chain,
    # If documents exceed context for `StuffDocumentsChain`
    collapse_documents_chain=combine_documents_chain,
    # The maximum number of tokens to group documents into.
    token_max=4000,
)

# Combining documents by mapping a chain over them, then combining results
map_reduce_chain = MapReduceDocumentsChain(
    # Map chain
    llm_chain=map_chain,
    # Reduce chain
    reduce_documents_chain=reduce_documents_chain,
    # The variable name in the llm_chain to put the documents in
    document_variable_name="docs",
    # Return the results of the map steps in the output
    return_intermediate_steps=False,
)


def get_summary():
    '''Summarise the currently loaded transcript (global `data`) with map-reduce.'''
    summary = map_reduce_chain.run(data)
    return summary


contextualise_q_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.

Chat history: {chat_history}

Question: {question}

Answer:
'''
)

contextualise_q_chain = contextualise_q_prompt | llm

standalone_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.

Chat history: {chat_history}

Question: {question}

Answer:
'''
)


def format_output(answer: str) -> str:
    '''Lower-case `answer` and remove all whitespace from it.'''
    return ''.join(answer.lower().split())


standalone_chain = standalone_prompt | llm | format_output

qa_prompt = PromptTemplate.from_template(
    '''You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say:\
"Sorry, I cannot find the answer to that question in the video."

Context: {context}

Question: {question}

Answer:
'''
)


def format_docs(docs: list) -> str:
    '''
    Combine retrieved documents into one context string.

    Side effect: stores each document's metadata['start'] in the global
    `sources`, which `filter_timestamps` reads later.
    '''
    global sources
    sources = [doc.metadata['start'] for doc in docs]
    return '\n\n'.join(doc.page_content for doc in docs)


def standalone_question(input_: dict) -> str:
    '''
    Resolve the question to use for retrieval and answering.

    `input_['standalone']` is the normalised output of `standalone_chain`.
    Per the standalone prompt, 'yes' means the question is already
    standalone. Such a question is returned unchanged, as is any question
    asked with an empty chat history.

    Otherwise `contextualise_q_chain` is returned. Inside a LangChain
    pipeline, a Runnable returned from a function is invoked with the same
    input, which rewrites the question into a standalone one.
    '''
    # startswith tolerates trailing punctuation such as 'yes.'.
    if input_['standalone'].startswith('yes') or not input_.get('chat_history'):
        return input_['question']
    return contextualise_q_chain


def format_answer(answer: str, n_sources: int = 1,
                  timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=5)) -> str:
    '''Strip `answer` and, unless it is the "cannot find" reply, append timestamps.

    Up to `n_sources` timestamps, at least `timestamp_interval` apart, are
    taken from the last retrieval (see `filter_timestamps`). If none are
    available, the answer is returned without the timestamp sentence.
    '''
    if 'cannot find the answer' in answer:
        return answer.strip()
    timestamps = filter_timestamps(n_sources, timestamp_interval)
    if not timestamps:
        return answer.strip()
    return (answer.strip()
            + ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps)))


def filter_timestamps(n_sources: int,
                      timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=5)) -> list:
    '''Returns a list of at most `n_sources` timestamps from the global `sources`.

    The timestamps are more than `timestamp_interval` apart. This prevents
    returning a list of timestamps that are too close together. Timestamps are
    'HH:MM:SS' strings, so lexicographic sorting is chronological. Returns an
    empty list when nothing was retrieved or `n_sources` <= 0.
    '''
    if not sources or n_sources <= 0:
        return []
    sorted_timestamps = sorted(sources)
    output = [sorted_timestamps[0]]
    i = 1
    while len(output) < n_sources and i < len(sorted_timestamps):
        timestamp1 = datetime.datetime.strptime(output[-1], '%H:%M:%S')
        timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S')
        if timestamp2 - timestamp1 > timestamp_interval:
            output.append(str(timestamp2.time()))
        i += 1
    return output


def setup_rag(url):
    '''Given a YouTube url, set up the vector database and the RAG chain.

    Sets the globals `data`, `retriever` and `rag_chain`. Returns `url`
    unchanged, which the UI writes back to the URL textbox.
    '''
    global data
    global retriever
    global rag_chain
    yt_audio_to_text(url)
    data = load_data()
    _, retriever = create_vectordb(data, k)
    rag_chain = (
        RunnablePassthrough.assign(standalone=standalone_chain)
        # Resolve the standalone question exactly once. The QA prompt and the
        # retriever then see the same text, and the LLM is not called twice.
        | RunnablePassthrough.assign(question=standalone_question)
        | {'question': lambda x: x['question'],
           'context': (lambda x: x['question']) | retriever | format_docs,
           }
        | qa_prompt
        | llm
    )
    return url


def get_answer(question: str) -> str:
    '''Answer `question` with the RAG chain and record the turn in `chat_history`.'''
    global chat_history
    ai_msg = rag_chain.invoke({'question': question,
                               'chat_history': chat_history,
                               })
    answer = format_answer(ai_msg, n_sources, timestamp_interval)
    chat_history.extend([HumanMessage(content=question), AIMessage(content=answer)])
    return answer


# Chatbot settings
n_sources = 3  # Number of sources provided in the answer
k = 5  # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=2)
default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'


def greet():
    '''Append the video summary and a greeting to the Gradio chat history.'''
    summary = get_summary()
    global gradio_chat_history
    summary_message = f'Here is a summary of the video "{video_title}":'
    gradio_chat_history.append((None, summary_message))
    gradio_chat_history.append((None, summary))
    greeting_message = 'You can ask me anything about the video. I will do my best to answer!'
    gradio_chat_history.append((None, greeting_message))
    return gradio_chat_history


def question(user_message):
    '''Show the user's message in the Gradio chat history.'''
    global gradio_chat_history
    gradio_chat_history.append((user_message, None))
    return gradio_chat_history


def respond():
    '''Answer the latest user message; clear the input box.'''
    global gradio_chat_history
    ai_message = get_answer(gradio_chat_history[-1][0])
    gradio_chat_history.append((None, ai_message))
    return '', gradio_chat_history


def clear_chat_history():
    '''Reset both the LLM chat history and the Gradio display history.'''
    global chat_history
    global gradio_chat_history
    chat_history = []
    gradio_chat_history = []


chat_history = []
gradio_chat_history = []

with gr.Blocks() as demo:
    # Structure
    with gr.Row():
        url_input = gr.Textbox(value=default_youtube_url, label='YouTube URL', scale=5)
        button = gr.Button(value='Go', scale=1)
    chatbot = gr.Chatbot()
    user_message = gr.Textbox(label='Ask a question:')
    clear = gr.ClearButton([user_message, chatbot])

    # Actions
    button.click(setup_rag,
                 inputs=[url_input],
                 outputs=[url_input],
                 trigger_mode='once').then(greet,
                                           inputs=[],
                                           outputs=[chatbot])
    user_message.submit(question,
                        inputs=[user_message],
                        outputs=[chatbot]).then(respond,
                                                inputs=[],
                                                outputs=[user_message, chatbot])
    clear.click(clear_chat_history)

demo.launch()