import streamlit as st
# from dotenv import load_dotenv
from pathlib import Path
from tempfile import NamedTemporaryFile

from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS

# Load environment variables (HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN;
# the AssemblyAI loader reads ASSEMBLYAI_API_KEY)
# load_dotenv()

# Function to create a prompt for retrieval QA chain
def create_qa_prompt() -> PromptTemplate:
    template = """\n\nHuman: Use the following pieces of context to answer the question at the end. If the answer is not clear, say I DON'T KNOW
{context}
Question: {question}
\n\nAssistant:
Answer:"""

    return PromptTemplate(template=template, input_variables=["context", "question"])
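
# Illustrative only (not called by the app): formatting the prompt with
# placeholder values shows the Human/Assistant framing that chat-tuned models
# such as zephyr-7b-beta expect.
# create_qa_prompt().format(context="<transcript text>", question="<user question>")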

# Function to create documents from a list of URLs
def create_docs(urls_list):
    documents = []
    for url in urls_list:
        st.write(f'Transcribing {url}')
        documents.append(AssemblyAIAudioTranscriptLoader(file_path=url).load()[0])
    return documents
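
# Note: load() blocks until AssemblyAI finishes the transcription, so long
# files can keep the Streamlit run busy for a while.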

# Function to create a Hugging Face embeddings model
def make_embedder():
    # 'device' and 'normalize_embeddings' options belong to the local
    # HuggingFaceEmbeddings class; the Hub-hosted endpoint does not take
    # them, so they are omitted here.
    return HuggingFaceHubEmbeddings(
        repo_id="sentence-transformers/all-mpnet-base-v2",
        task="feature-extraction"
    )
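
# A minimal sketch of the FAISS indexing path that the commented-out code in
# main() points at (hypothetical helper, not wired up by default): split the
# transcript, trim metadata, and build a searchable index.
def make_vectorstore(docs):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(docs)
    # Keep only the audio URL so the index payload stays small.
    for text in texts:
        text.metadata = {"audio_url": text.metadata["audio_url"]}
    return FAISS.from_documents(texts, make_embedder())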

# Function to create the LLM used for question answering (renamed from
# make_qa_chain, which misleadingly returned a bare LLM; the retrieval-QA
# wrapper it was named after is sketched below)
def make_llm():
    return HuggingFaceHub(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.01,
            "repetition_penalty": 1.5,
        },
    )
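
# A minimal sketch of the retrieval variant (hypothetical wiring, assuming a
# FAISS index built by make_vectorstore above): MMR retrieval plus the custom
# prompt, returning source documents for attribution.
def make_retrieval_qa_chain(db):
    return RetrievalQA.from_chain_type(
        make_llm(),
        retriever=db.as_retriever(search_type="mmr", search_kwargs={'fetch_k': 3}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": create_qa_prompt()},
    )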

# Streamlit UI
def main():
    st.set_page_config(page_title="Audio Query Chatbot", page_icon=":microphone:", layout="wide")

    # Left pane - Audio file upload
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("Upload Audio File")
        uploaded_file = st.file_uploader("Choose a WAV or MP3 file", type=["wav", "mp3"], key="audio_uploader")

        if uploaded_file is not None:
            # Preserve the uploaded file's real extension instead of always
            # writing '.mp3' (WAV uploads were previously mislabeled).
            suffix = Path(uploaded_file.name).suffix or ".mp3"
            with NamedTemporaryFile(suffix=suffix) as temp:
                temp.write(uploaded_file.getvalue())
                temp.seek(0)
                docs = create_docs([temp.name])

                st.success('Audio file transcribed successfully!')

                # Note: Streamlit reruns this script on every interaction, so
                # the file is re-transcribed on each submit; caching `docs` in
                # st.session_state would avoid repeated AssemblyAI calls.

                # To enable the retrieval variant instead, build the index here:
                # db = make_vectorstore(docs)
                # qa_chain = make_retrieval_qa_chain(db)

    # Right pane - Chatbot Interface
    with col2:
        st.header("Chatbot Interface")

        if uploaded_file is not None:
            with st.form(key="form"):
                user_input = st.text_input("Ask your question", key="user_input")

                # st.form already submits on Enter inside a text_input; the
                # markup below only adds visual spacing.
                st.markdown("<div><br></div>", unsafe_allow_html=True)  # Adds some space
                st.markdown(
                    """<style>
                    #form input {margin-bottom: 15px;}
                    </style>""", unsafe_allow_html=True
                )

                submit = st.form_submit_button("Submit Question")

            # Display the result once the form is submitted
            if submit:
                # "stuff" packs the entire transcript into a single prompt;
                # for long audio, the retrieval variant sketched above scales
                # better.
                llm = make_llm()
                chain = load_qa_chain(llm, chain_type="stuff")
                result = chain.run(question=user_input, input_documents=docs)
                st.success("Query Result:")
                st.write(f"User: {user_input}")
                st.write(f"Assistant: {result}")

                # Retrieval variant with source attribution, if enabled above:
                # result = qa_chain({"query": user_input})
                # st.write(f"Assistant: {result['result']}")
                # st.subheader("Source Documents:")
                # for idx, elt in enumerate(result['source_documents']):
                #     st.write(f"Source {idx + 1}:")
                #     st.write(f"Filepath: {elt.metadata['audio_url']}")
                #     st.write(f"Contents: {elt.page_content}")

if __name__ == "__main__":
    main()