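"""ReferenceBot: a Streamlit app that answers questions about an uploaded
pdf, docx, or txt document, with source citations, using Azure OpenAI via
LangChain.

Expects the following settings in Streamlit secrets or the environment:
OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_VERSION, ENGINE (the chat
deployment) and ENGINE_EMBEDDING (the embedding deployment).
"""
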
import os

import streamlit as st
from langchain.chat_models import AzureChatOpenAI

from knowledge_gpt.components.sidebar import sidebar
from knowledge_gpt.core.caching import bootstrap_caching
from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.embedding import embed_files
from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.qa import query_folder
from knowledge_gpt.ui import display_file_read_error
from knowledge_gpt.ui import is_file_valid
from knowledge_gpt.ui import is_query_valid
from knowledge_gpt.ui import wrap_doc_in_html

st.set_page_config(page_title="ReferenceBot", page_icon="📖", layout="wide")

# Load all secrets into environment variables. The existence check avoids
# the warning Streamlit emits when st.secrets is accessed without a
# secrets.toml file.
if os.path.exists(
    os.path.dirname(os.path.abspath(__file__)) + "/../.streamlit/secrets.toml"
):
    for key, value in st.secrets.items():
        os.environ[key] = value


def main():
    EMBEDDING = "openai"
    VECTOR_STORE = "faiss"
    MODEL_LIST = ["gpt-3.5-turbo", "gpt-4"]

    # Uncomment to enable debug mode
    # MODEL_LIST.insert(0, "debug")

    st.header("📖 ReferenceBot")

    # Enable caching for expensive functions
    bootstrap_caching()

    sidebar()

    uploaded_file = st.file_uploader(
        "Upload a pdf, docx, or txt file",
        type=["pdf", "docx", "txt"],
        help="Scanned documents are not supported yet!",
    )

    model: str = st.selectbox("Model", options=MODEL_LIST)  # type: ignore
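    # NOTE: with Azure, the deployed chat model is fixed by the ENGINE
    # environment variable; this selection currently only toggles the
    # "debug" backend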

    with st.expander("Advanced Options"):
        return_all_chunks = st.checkbox("Show all chunks retrieved from vector search")
        show_full_doc = st.checkbox("Show parsed contents of the document")

    if not uploaded_file:
        st.stop()

    try:
        file = read_file(uploaded_file)
    except Exception as e:
        display_file_read_error(e, file_name=uploaded_file.name)
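        # display_file_read_error is expected to halt the script via
        # st.stop(), so `file` is always bound past this point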

    # Validate the parsed file before spending time chunking it
    if not is_file_valid(file):
        st.stop()

    chunked_file = chunk_file(file, chunk_size=300, chunk_overlap=0)

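    # When "debug" is selected, stub "debug" embedding and vector store
    # backends are used in place of the real OpenAI-backed ones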
    with st.spinner("Indexing document... This may take a while⏳"):
        folder_index = embed_files(
            files=[chunked_file],
            embedding=EMBEDDING if model != "debug" else "debug",
            vector_store=VECTOR_STORE if model != "debug" else "debug",
            deployment=os.environ["ENGINE_EMBEDDING"],
            model=os.environ["ENGINE"],
            openai_api_key=os.environ["OPENAI_API_KEY"],
            openai_api_base=os.environ["OPENAI_API_BASE"],
            openai_api_type="azure",
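            # Azure OpenAI embedding deployments have historically accepted
            # only a single input per request, hence the batch size of 1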
            chunk_size=1,
        )

    with st.form(key="qa_form"):
        query = st.text_area("Ask a question about the document")
        submit = st.form_submit_button("Submit")

    if show_full_doc:
        with st.expander("Document"):
            # Hack to get around st.markdown rendering LaTeX
            st.markdown(f"<p>{wrap_doc_in_html(file.docs)}</p>", unsafe_allow_html=True)

    if submit:
        if not is_query_valid(query):
            st.stop()

        # Output Columns
        answer_col, sources_col = st.columns(2)

        with st.spinner("Setting up AzureChatOpenAI bot..."):
            llm = AzureChatOpenAI(
                openai_api_base=os.environ["OPENAI_API_BASE"],
                openai_api_version=os.environ["OPENAI_API_VERSION"],
                deployment_name=os.environ["ENGINE"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_api_type="azure",
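                # temperature=0 keeps answers as deterministic as possible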
                temperature=0,
            )

        with st.spinner("Querying folder to get result..."):
            result = query_folder(
                folder_index=folder_index,
                query=query,
                return_all=return_all_chunks,
                llm=llm,
            )

        with answer_col:
            st.markdown("#### Answer")
            st.markdown(result.answer)

        with sources_col:
            st.markdown("#### Sources")
            for source in result.sources:
                st.markdown(source.page_content)
                st.markdown(source.metadata["source"])
                st.markdown("---")


if __name__ == "__main__":
    main()