File size: 4,368 Bytes
9c93e79
27a75c7
dbbbe03
27a75c7
 
 
 
9c93e79
27a75c7
 
 
 
 
 
9c93e79
 
 
2e71340
ea190bc
2e71340
 
0d4e49f
 
607217c
 
0d4e49f
2e71340
27a75c7
 
ea190bc
7ecf307
27a75c7
 
 
 
 
0d4e49f
 
 
 
9c93e79
 
 
 
 
e39545b
9c93e79
 
 
 
 
ef26dc3
0d4e49f
9c93e79
27a75c7
 
 
 
 
 
 
 
 
9c93e79
27a75c7
 
 
 
4905a09
 
d5cbb04
70fa901
 
4905a09
 
70fa901
4905a09
27a75c7
 
 
b25c8b8
9c93e79
ea190bc
 
b5e82ff
 
ea190bc
 
0d47333
0d4e49f
b25c8b8
9c93e79
0d4e49f
b25c8b8
9c93e79
 
 
 
 
 
 
 
 
dbbbe03
9c93e79
 
 
 
2e71340
716125b
70fa901
d5cbb04
70fa901
18a3ff8
 
 
b5e82ff
 
4905a09
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# other imports
import streamlit as st, os
import spaces
import tempfile
from pytube import YouTube
import whisper
from dotenv import load_dotenv

# langchain imports
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_openai import OpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


# initialize session state variables if not already initialized
# (chain: the RAG pipeline, transcription: raw video text, thumbnail_url: preview image)
for _state_key in ("chain", "transcription", "thumbnail_url"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None



# function for transcribing a YouTube video's audio track
def get_yt_trans(yt_obj):
    """Download the audio stream of *yt_obj*, transcribe it with Whisper,
    and store the cleaned-up text in ``st.session_state.transcription``.

    Parameters
    ----------
    yt_obj : pytube.YouTube
        An already-constructed YouTube object for the target video.
    """
    video_audio = yt_obj.streams.filter(only_audio=True).first()
    # 'base' is the small, CPU-friendly Whisper model for audio-to-text
    audio_to_text_model = whisper.load_model('base')
    # download the audio into a temp dir so nothing lingers on disk afterwards
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_file = video_audio.download(output_path=tmpdir)
        # fp16=False keeps transcription working on CPU-only machines
        st.session_state.transcription = audio_to_text_model.transcribe(
            audio_file, fp16=False
        )["text"].strip()


# function for chunking transcription text
def get_vid_text_chunks(transcription_text):
    """Split the transcription into ~300-char chunks (20-char overlap)
    suitable for embedding and retrieval."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
    return splitter.split_text(transcription_text)


# function for creating vector db
def create_vecdb(chunks):
    """Embed *chunks* with OpenAI embeddings and index them in an
    in-memory FAISS store; returns the store."""
    # vecdb.save_local('vec_store') ## if needed to store and retrieve locally
    return FAISS.from_texts(chunks, OpenAIEmbeddings())


# define api keys: load .env, then re-export under the name the OpenAI
# clients actually read. BUG FIX: the original wrote to 'OPEN_API_KEY'
# (typo) — a variable nothing reads — and would raise TypeError when
# OPENAI_API_KEY was unset (os.environ values must be str, not None).
load_dotenv()
_openai_key = os.getenv('OPENAI_API_KEY')
if _openai_key:
    os.environ['OPENAI_API_KEY'] = _openai_key

# define LLM (temperature=0 for deterministic, context-grounded answers)
llm = OpenAI(temperature=0)

# Set page configs
st.set_page_config(
    page_title="YouTube Talks",
    page_icon='📽️'
)
st.header('Query YouTube Videos!')
thumbnail_placeholder = st.empty()           # filled with the video thumbnail after submit
st.sidebar.header("URL details:")
video_url = st.sidebar.text_input("Enter video URL")
submit_button = st.sidebar.button("Submit")
user_query = st.text_input("Query the Video!")
video_query_button = st.button("Ask video!")
progress_updates = st.sidebar.empty()        # transient progress messages
main_ph = st.empty()                         # transient status in the main pane

# on button press: transcribe the video, chunk it, and build the RAG chain
if submit_button:
    if not video_url:
        # robustness: pytube crashes on an empty/blank URL — tell the user instead
        st.sidebar.error("Please enter a video URL first.")
    else:
        progress_updates.text("Transcribing Video...")
        # get video transcription
        yt_obj = YouTube(video_url)
        # display video thumbnail (remembered across reruns in session state)
        st.session_state.thumbnail_url = yt_obj.thumbnail_url
        thumbnail_placeholder.image(st.session_state.thumbnail_url)
        # get transcription (stored in st.session_state.transcription)
        get_yt_trans(yt_obj)
        st.sidebar.subheader("Transcription:")
        st.sidebar.write(st.session_state.transcription)
        progress_updates.text("Making Text Chunks...")
        # get text chunks
        chunks = get_vid_text_chunks(st.session_state.transcription)
        progress_updates.text("Creating Vector DB...")
        # create vector db
        vector_db = create_vecdb(chunks)
        # define main prompt: answer strictly from retrieved context
        p_template = '''Answer the question based on the context below. If you can't 
answer the question, reply "I don't know".

Context: {context}

Question: {question}'''

        prompt = ChatPromptTemplate.from_template(p_template)
        # define output parser
        parser = StrOutputParser()
        # main chain: retriever fills {context}, the raw query passes
        # through unchanged as {question}
        st.session_state.chain = (
            {"context": vector_db.as_retriever(), "question": RunnablePassthrough()}
            | prompt
            | llm
            | parser
        )
        progress_updates.text("Video Transcribed Successfully!!!")


if video_query_button:
    # if video not transcribed, display error. BUG FIX: the original
    # rendered the thumbnail BEFORE this check, so clicking "Ask video!"
    # before any submit called thumbnail_placeholder.image(None) (the
    # session value is still None) and crashed instead of showing the error.
    if st.session_state.chain is None:
        st.error("Please transcribe a video first by submitting a URL.")
    else:
        # keep displaying Transcription and Thumbnail in window across reruns
        st.sidebar.subheader("Transcription:")
        st.sidebar.write(st.session_state.transcription)
        thumbnail_placeholder.image(st.session_state.thumbnail_url)
        main_ph.text("Fetching Results...")
        # print results
        st.subheader("Result:")
        main_ph.text("Displaying Results...")
        st.write(st.session_state.chain.invoke(user_query))