|
from io import StringIO |
|
|
|
import streamlit as st |
|
from langchain.docstore.document import Document |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language |
|
import time |
|
|
|
import vector_db as vdb |
|
from llm_model import LLMModel |
|
|
|
|
|
def default_state():
    """Seed st.session_state with every key this app reads, if not already set.

    Safe to call on every rerun: existing values are never overwritten.
    """
    defaults = {
        "startup": True,          # True only until the first full render completes
        "messages": [],           # chat history: {"role": ..., "content": ...}
        "uploaded_docs": [],      # names of files already ingested this session
        "llm_option": "Local",    # "Local" or "Inference API"
        "answer_loading": False,  # True while an answer is being generated
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
|
|
def load_doc(file_name: str, file_content: str,
             chunk_size: int = 1000, chunk_overlap: int = 150):
    """Split one document's text into overlapping chunks for vector indexing.

    Args:
        file_name: Name of the uploaded file; stored as the chunk "source"
            metadata and used to pick the splitter language.
        file_content: Full decoded text of the file.
        chunk_size: Maximum characters per chunk (default matches the
            original hard-coded value).
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A list of langchain Documents, or None when file_name is None
        (preserved legacy behavior).
    """
    if file_name is None:
        return None

    doc = Document(page_content=file_content, metadata={"source": file_name})
    # Choose a language-aware splitter (Markdown vs. reStructuredText).
    language = get_language(file_name)
    text_splitter = RecursiveCharacterTextSplitter.from_language(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, language=language
    )
    return text_splitter.split_documents([doc])
|
|
|
|
|
def get_language(file_name: str):
    """Map a file name's extension to a langchain splitter Language.

    Only reStructuredText (.rst) gets a dedicated splitter; Markdown
    (.md/.mdx) and every unknown extension fall back to MARKDOWN, which is
    also what the original if/elif chain produced.
    """
    return Language.RST if file_name.endswith(".rst") else Language.MARKDOWN
|
|
|
|
|
@st.cache_resource()
def get_vector_db():
    """Return the app-wide VectorDB, built once and cached across reruns."""
    return vdb.VectorDB()
|
|
|
|
|
@st.cache_resource()
def get_llm_model(_db: vdb.VectorDB):
    """Build the retrieval-QA chain once and cache it across reruns.

    The leading underscore on `_db` tells Streamlit's cache not to hash the
    (unhashable) VectorDB argument.
    """
    # Retrieve the 2 most similar chunks per query.
    doc_retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    model = LLMModel(retriever=doc_retriever)
    return model.create_qa_chain()
|
|
|
|
|
|
|
def init_sidebar():
    """Render the sidebar: loading toggle, model selector, and file uploader.

    Newly uploaded files are chunked and pushed into the global vector_db;
    file names are remembered in session_state so re-runs don't re-ingest.
    """
    with st.sidebar:
        st.toggle(
            "Loading from LLM",
            # BUGFIX: pass the callback itself, not its result. The original
            # `on_change=enable_sidebar()` invoked enable_sidebar on every
            # script rerun and registered None as the callback.
            on_change=enable_sidebar,
            disabled=not st.session_state.answer_loading
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option

        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            # Skip files already ingested in this session.
            if uploaded_file.name not in st.session_state.uploaded_docs:
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                string_data = stringio.read()

                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)
|
|
|
|
|
def init_chat():
    """Replay the stored conversation so chat history survives reruns."""
    for entry in st.session_state.messages:
        role = entry["role"]
        with st.chat_message(role):
            st.markdown(entry["content"])
|
|
|
|
|
def disable_sidebar():
    """Lock the sidebar while an answer is generated, then rerun immediately.

    The rerun makes the disabled widgets render right away rather than after
    the (slow) LLM call finishes.
    """
    st.session_state.answer_loading = True
    st.rerun()
|
|
|
|
|
def enable_sidebar():
    """Re-enable the sidebar once answer generation has finished."""
    st.session_state.answer_loading = False
|
|
|
|
|
# --- App bootstrap: Streamlit re-executes this top-to-bottom on every rerun ---
st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()            # cached vector store (st.cache_resource)
default_state()                        # ensure all session_state keys exist
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()                            # replay prior chat messages
llm_model = get_llm_model(vector_db)   # cached retrieval-QA chain
st.session_state.startup = False
|
|
|
|
|
|
|
# BUGFIX: pass the callback itself, not its result. The original
# `on_submit=disable_sidebar()` executed disable_sidebar (which calls
# st.rerun()) during every script run, forcing a rerun loop, and registered
# None as the submit callback.
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    start_time = time.time()

    # Echo the user's message and persist it to the chat history.
    with st.chat_message("user"):
        st.markdown(user_prompt)
    st.session_state.messages.append({"role": "user", "content": user_prompt})

    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("WARN: Will try answer question without documents")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
            # Collect the source file names cited by the retriever.
            sources = [doc.metadata['source']
                       for doc in res['source_documents']
                       if 'source' in doc.metadata]

            time_taken = "{:.2f}".format(time.time() - start_time)
            format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
            assistant_chat.markdown(format_answer)

            # Full chunk contents behind an expander for transparency.
            source_expander = assistant_chat.expander("See full sources")
            for source_docs in res['source_documents']:
                if 'source' in source_docs.metadata:
                    format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
                    source_expander.markdown(format_source)

        st.session_state.messages.append({"role": "assistant", "content": format_answer})
        enable_sidebar()
        st.rerun()
|
|