"""Streamlit app: transcribe a YouTube video with Whisper and answer questions
about it through a LangChain retrieval chain (FAISS vector store + OpenAI LLM).
"""
# other imports
import os
import tempfile

import streamlit as st
import spaces  # NOTE(review): not referenced in this file; kept (and kept ahead of `whisper`) to preserve the original import order — confirm whether the hosting environment needs it
from pytube import YouTube
import whisper
from dotenv import load_dotenv

# langchain imports
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_openai import OpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Set page configs. This is the first Streamlit command on purpose: older
# Streamlit versions reject st.set_page_config after any other st.* call.
st.set_page_config(
    page_title="YouTube Talks",
    page_icon='📽️',
)

# define api keys: load_dotenv() exports the variables in a local .env file
# (including OPENAI_API_KEY, which OpenAI / OpenAIEmbeddings read) into
# os.environ. Fail early with a readable message instead of crashing later.
load_dotenv()
if not os.getenv('OPENAI_API_KEY'):
    st.error("OPENAI_API_KEY is not set. Add it to your environment or a .env file.")
    st.stop()

# define LLM
llm = OpenAI(temperature=0)

# define main prompt
p_template = '''Answer the question based on the context below. If you can't answer the question, reply "I don't know".

Context: {context}

Question: {question}'''

# initialize session state variables if not already initialized; they persist
# across Streamlit reruns so a query can reuse an earlier transcription.
for _state_key in ('chain', 'transcription', 'thumbnail_url'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None


@st.cache_resource
def _load_whisper_model(model_name='base'):
    """Load a Whisper model once per process instead of on every submit."""
    return whisper.load_model(model_name)


# function for transcribing a YouTube video
def get_yt_trans(yt_obj):
    """Download the audio of *yt_obj*, transcribe it with Whisper, and store it.

    The stripped transcription text is written to
    ``st.session_state.transcription`` and also returned.

    Raises:
        ValueError: if the video exposes no audio-only stream.
    """
    video_audio = yt_obj.streams.filter(only_audio=True).first()
    if video_audio is None:
        raise ValueError("No audio-only stream is available for this video.")
    # define model for audio to text conversion (cached across reruns)
    audio_to_text_model = _load_whisper_model('base')
    # download into a temporary directory; transcription must happen inside
    # the `with` block because the file is deleted when the block exits
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_file = video_audio.download(output_path=tmpdir)
        result = audio_to_text_model.transcribe(audio_file, fp16=False)
    st.session_state.transcription = result["text"].strip()
    return st.session_state.transcription


# function for chunking transcription text
def get_vid_text_chunks(transcription_text):
    """Split the transcription into ~300-character chunks with 20-char overlap."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
    return text_splitter.split_text(transcription_text)


# function for creating vector db
def create_vecdb(chunks):
    """Embed *chunks* with OpenAI embeddings and index them in an in-memory FAISS store."""
    return FAISS.from_texts(chunks, OpenAIEmbeddings())


def _format_docs(docs):
    """Join retrieved documents' text so the prompt gets plain text, not Document reprs."""
    return "\n\n".join(doc.page_content for doc in docs)


def _build_chain(vector_db):
    """Build the retrieval -> prompt -> LLM -> string chain over *vector_db*."""
    prompt = ChatPromptTemplate.from_template(p_template)
    # define output parser
    parser = StrOutputParser()
    retriever = vector_db.as_retriever()
    return (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | parser
    )


st.header('Query YouTube Videos!')
thumbnail_placeholder = st.empty()

st.sidebar.header("URL details:")
video_url = st.sidebar.text_input("Enter video URL")
submit_button = st.sidebar.button("Submit")

user_query = st.text_input("Query the Video!")
video_query_button = st.button("Ask video!")

progress_updates = st.sidebar.empty()
main_ph = st.empty()

# on button press
if submit_button:
    if not video_url.strip():
        st.sidebar.error("Please enter a video URL.")
    else:
        main_ph.text("Transcribing Video, please wait...")
        progress_updates.text("Transcribing Video...")

        # get video object and display its thumbnail
        yt_obj = YouTube(video_url)
        st.session_state.thumbnail_url = yt_obj.thumbnail_url
        thumbnail_placeholder.image(st.session_state.thumbnail_url)

        # get transcription
        try:
            get_yt_trans(yt_obj)
        except ValueError as err:
            progress_updates.text("Transcription failed.")
            main_ph.empty()
            st.error(str(err))
            st.stop()

        st.sidebar.subheader("Transcription:")
        st.sidebar.write(st.session_state.transcription)

        progress_updates.text("Making Text Chunks...")
        chunks = get_vid_text_chunks(st.session_state.transcription)

        progress_updates.text("Creating Vector DB...")
        vector_db = create_vecdb(chunks)

        # define main chain
        st.session_state.chain = _build_chain(vector_db)

        progress_updates.text("Video Transcribed Successfully!!!")
        main_ph.text("Video Transcribed Successfully!")

if video_query_button:
    # keep displaying transcription and thumbnail, but only if a video was
    # processed earlier (otherwise these values are still None)
    if st.session_state.transcription is not None:
        st.sidebar.subheader("Transcription:")
        st.sidebar.write(st.session_state.transcription)
    if st.session_state.thumbnail_url:
        thumbnail_placeholder.image(st.session_state.thumbnail_url)

    # if video not transcribed display error
    if st.session_state.chain is None:
        st.error("Please transcribe a video first by submitting a URL.")
    elif not user_query.strip():
        st.warning("Please enter a question to ask the video.")
    else:
        main_ph.text("Fetching Results...")
        answer = st.session_state.chain.invoke(user_query)
        # print results
        main_ph.text("Displaying Results...")
        st.subheader("Result:")
        st.write(answer)