import datetime
import gc
import json
import os

import gradio as gr
import torch
import yt_dlp
from gradio_client import Client
from langchain.chains import (
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)


def transcribe_audio(audio_path, task='transcribe', return_timestamps=True) -> str:
    """Transcribe an audio file with the hosted Whisper JAX Space."""
    text, runtime = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name='/predict_1',
    )
    return text


def format_whisper_jax_output(whisper_jax_output: str,
                              max_duration: int = 60) -> list[dict]:
    """Return a list of dicts with keys 'start', 'end' and 'text'.

    The segments from the Whisper JAX output are merged to form paragraphs.
    `max_duration` controls how many seconds of the audio's transcript are
    merged. For example, if `max_duration` is 60, each segment in the final
    output covers roughly 60 seconds of audio.
    """
    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    segments = whisper_jax_output.split('\n')
    current_start = datetime.datetime.strptime('00:00', '%M:%S')
    current_text = ''

    for i, seg in enumerate(segments):
        text = seg.split(']')[-1].strip()

        # Sometimes Whisper JAX returns None for the timestamp.
        try:
            end = datetime.datetime.strptime(seg[14:19], '%M:%S')
        except ValueError:
            end = current_start + max_duration

        if (end - current_start >= max_duration) or (i == len(segments) - 1):
            # If we have exceeded max duration or reached the last segment,
            # stop merging and append to final_output.
            current_text += text
            final_output.append({
                'start': current_start.strftime('%H:%M:%S'),
                'end': end.strftime('%H:%M:%S'),
                'text': current_text.strip(),
            })
            # Reset the current start and text.
            current_start = end
            current_text = ''
        else:
            # If we have not exceeded max duration, keep merging.
            # Join segments with a space so sentences do not run together.
            current_text += text + ' '

    return final_output


audio_file_number = 1


def yt_audio_to_text(url: str, max_duration: int = 60):
    """Download a YouTube video's audio, transcribe it and save the result to audio.json."""
    global audio_file_number
    progress = gr.Progress()
    progress(0.1)

    with yt_dlp.YoutubeDL({'extract_audio': True,
                           'format': 'bestaudio',
                           'outtmpl': f'{audio_file_number}.mp3'}) as video:
        info_dict = video.extract_info(url, download=False)
        global video_title
        video_title = info_dict['title']
        # `download` expects a list of URLs.
        video.download([url])

    progress(0.4)
    audio_file = f'{audio_file_number}.mp3'
    audio_file_number += 1

    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open('audio.json', 'w') as f:
        json.dump(result, f)


def metadata_func(record: dict, metadata: dict) -> dict:
    metadata['start'] = record.get('start')
    metadata['end'] = record.get('end')
    metadata['source'] = metadata['start'] + ' -> ' + metadata['end']
    return metadata


def load_data():
    loader = JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',
        content_key='text',
        metadata_func=metadata_func
    )
    data = loader.load()
    return data


embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)


def create_vectordb(data, k: int):
    """Return a vector database and its retriever.

    `k` is the number of retrieved documents.
    """
    vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
    retriever = vectordb.as_retriever(search_type='similarity',
                                      search_kwargs={'k': k})
    return vectordb, retriever


repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})

# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)

# Reduce
reduce_template = """The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary of the main themes \
in 150 words or less.
Answer:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Takes a list of documents, combines them into a single string,
# and passes this to the LLM.
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain,
    document_variable_name="docs"
)

# Combines and iteratively reduces the mapped documents.
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
    combine_documents_chain=combine_documents_chain,
    # Used if documents exceed the context for `StuffDocumentsChain`.
    collapse_documents_chain=combine_documents_chain,
    # The maximum number of tokens to group documents into.
    token_max=4000
)

# Combines documents by mapping a chain over them, then combining the results.
map_reduce_chain = MapReduceDocumentsChain(
    # Map chain
    llm_chain=map_chain,
    # Reduce chain
    reduce_documents_chain=reduce_documents_chain,
    # The variable name in the llm_chain to put the documents in.
    document_variable_name="docs",
    # Do not return the results of the map steps in the output.
    return_intermediate_steps=False
)


def get_summary(documents) -> str:
    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
    return summary['output_text'].strip()


contextualise_q_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
that can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.

Chat history: {chat_history}

Question: {question}

Answer:
"""
)
contextualise_q_chain = contextualise_q_prompt | llm

standalone_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.

Chat history: {chat_history}

Question: {question}

Answer:
"""
)


def format_output(answer: str) -> str:
    # Lower-case the answer and remove all whitespace.
    return ''.join(answer.lower().split())


standalone_chain = standalone_prompt | llm | format_output

qa_prompt = PromptTemplate.from_template(
    """You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say: \
"Sorry, I cannot find the answer to that question in the video."

Context: {context}

Question: {question}

Answer:
"""
)


class YouTubeChatbot:
    def __init__(self,
                 n_sources: int,
                 k: int,
                 timestamp_interval: datetime.timedelta,
                 memory: int):
        self.n_sources = n_sources
        self.k = k
        self.timestamp_interval = timestamp_interval
        self.chat_history = ConversationBufferWindowMemory(k=memory)

    def format_docs(self, docs: list) -> str:
        """Combine the retrieved documents into a single string and
        remember their start timestamps as sources.
        """
        self.sources = [doc.metadata['start'] for doc in docs]
        return '\n\n'.join(doc.page_content for doc in docs)

    def standalone_question(self, input_: dict) -> str:
        """If the question is not a standalone question,
        run contextualise_q_chain to rewrite it.
        """
        if input_['standalone'] == 'yes':
            return contextualise_q_chain
        else:
            return input_['question']

    def format_answer(self, answer: str) -> str:
        if 'cannot find the answer' in answer:
            return answer.strip()
        else:
            timestamps = self.filter_timestamps()
            answer_with_sources = (
                answer.strip()
                + ' You can find more information '
                  'at these timestamps: {}.'.format(', '.join(timestamps))
            )
            return answer_with_sources

    def filter_timestamps(self) -> list[str]:
        """Return a list of at most `n_sources` timestamps.

        The timestamps are at least `timestamp_interval` apart.
        This prevents returning a list of timestamps that are
        too close together.
""" sorted_timestamps = sorted(self.sources) filtered_timestamps = [sorted_timestamps[0]] i=1 while len(filtered_timestamps) < self.n_sources: timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1], '%H:%M:%S') try: timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S') except IndexError: break time_diff = timestamp2 - timestamp1 if time_diff>=self.timestamp_interval: filtered_timestamps.append(str(timestamp2.time())) i += 1 return filtered_timestamps def setup_chatbot(self, url: str) -> str: """Given a YouTube url, set up the chatbot. """ yt_audio_to_text(url) self.data = load_data() _, self.retriever = create_vectordb(self.data, self.k) self.qa_chain = ( RunnablePassthrough.assign(standalone=standalone_chain) | {'question':self.standalone_question, 'context':self.standalone_question|self.retriever|self.format_docs} | qa_prompt | llm) return url def get_answer(self, question: str) -> str: try: ai_msg = self.qa_chain.invoke({'question': question, 'chat_history': self.chat_history}) except AttributeError: raise AttributeError("You haven't setup the chatbot yet. " "Setup the chatbot by calling the " "instance method `setup_chatbot`.") answer = self.format_answer(ai_msg) self.chat_history.save_context({'question':question}, {'answer':answer}) return answer class YouTubeChatbotApp(YouTubeChatbot): def __init__(self, n_sources: int, k: int, timestamp_interval: datetime.timedelta, memory: int, default_youtube_url: str ): super().__init__(n_sources, k, timestamp_interval, memory) self.default_youtube_url = default_youtube_url self.gradio_chat_history = [] def greet(self) -> list[tuple[str|None, str|None]]: summary = get_summary(self.data) summary_message = f'Here is a summary of the video "{video_title}":' self.gradio_chat_history.append((None, summary_message)) self.gradio_chat_history.append((None, summary)) greeting_message = ('You can ask me anything about the video. 
                            'I will do my best to answer!')
        self.gradio_chat_history.append((None, greeting_message))
        return self.gradio_chat_history

    def question(self, user_message: str) -> tuple[str, list[tuple[str | None, str | None]]]:
        self.gradio_chat_history.append((user_message, None))
        return '', self.gradio_chat_history

    def respond(self) -> list[tuple[str | None, str | None]]:
        try:
            ai_message = self.get_answer(self.gradio_chat_history[-1][0])
        except AttributeError:
            raise gr.Error('You need to process the video '
                           'first by pressing the `Go` button.')
        self.gradio_chat_history.append((None, ai_message))
        return self.gradio_chat_history

    def clear_chat_history(self) -> list:
        self.chat_history.clear()
        self.gradio_chat_history = []
        return self.gradio_chat_history

    def launch(self, **kwargs):
        with gr.Blocks() as demo:
            # Structure
            with gr.Row():
                url_input = gr.Textbox(value=self.default_youtube_url,
                                       label='YouTube URL',
                                       scale=5)
                button = gr.Button(value='Go', scale=1)

            chatbot = gr.Chatbot()
            user_message = gr.Textbox(label='Ask a question:')
            clear = gr.ClearButton([user_message, chatbot])

            # Actions
            button.click(self.clear_chat_history,
                         inputs=[],
                         outputs=[chatbot],
                         trigger_mode='once'
                         ).then(self.setup_chatbot,
                                inputs=[url_input],
                                outputs=[url_input]
                                ).then(self.greet,
                                       inputs=[],
                                       outputs=[chatbot])

            user_message.submit(self.question,
                                inputs=[user_message],
                                outputs=[user_message, chatbot]
                                ).then(self.respond,
                                       inputs=[],
                                       outputs=[chatbot])

            clear.click(self.clear_chat_history,
                        inputs=[],
                        outputs=[chatbot])

        demo.launch(**kwargs)


if __name__ == "__main__":
    app = YouTubeChatbotApp(n_sources=3,
                            k=5,
                            timestamp_interval=datetime.timedelta(minutes=2),
                            memory=5,
                            default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8')
    app.launch()
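# A minimal sketch of driving the chatbot without the Gradio UI, using only the
# classes defined above; the URL is just the default example, and downloading
# plus transcription still happen inside `setup_chatbot`:
#
#     bot = YouTubeChatbot(n_sources=3,
#                          k=5,
#                          timestamp_interval=datetime.timedelta(minutes=2),
#                          memory=5)
#     bot.setup_chatbot('https://www.youtube.com/watch?v=4Bdc55j80l8')
#     print(bot.get_answer('What is the video about?'))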