"""Pilot Chat — Streamlit app for question answering over uploaded PDF documents."""
import os
import shutil

import openai
import streamlit as st
from llama_index import SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
def create_vector_index(documents_path: str, persist_dir: str = "./vector_index/") -> None:
    """
    Create a VectorStoreIndex from a directory of documents and persist it to disk.

    Default storage context directory: ./vector_index/

    :param documents_path: The path to the directory of documents to index.
    :param persist_dir: The directory to store the index in.
    :return: None
    """
    # Load every document found in the directory into memory.
    documents = SimpleDirectoryReader(documents_path).load_data()
    # Split text into overlapping token chunks so retrieved context stays coherent.
    # NOTE(review): chunk_size=1028 looks like a typo for 1024 — confirm intent
    # before changing, since it alters chunk boundaries of persisted indexes.
    text_splitter = TokenTextSplitter(
        separator=" ",
        chunk_size=1028,
        chunk_overlap=256,
        backup_separators=["\n"],
    )
    # Parse documents into nodes using the configured splitter.
    node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)
    nodes = node_parser.get_nodes_from_documents(documents)
    # Build the vector index (this embeds the nodes) and persist it so it can
    # be reloaded later without re-embedding.
    index = VectorStoreIndex(nodes)
    index.storage_context.persist(persist_dir=persist_dir)
def load_query_engine_from_memory(persist_dir: str = "./vector_index/") -> BaseQueryEngine:
    """
    Load a persisted vector index and return a query engine for it.

    Default storage context directory: ./vector_index/

    :param persist_dir: The directory to load the index from.
    :return: BaseQueryEngine
    """
    # Rebuild the storage context from the persisted index files.
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    index = load_index_from_storage(storage_context)
    # BUG FIX: as_query_engine() takes `similarity_top_k`, not `top_k`.
    # The original keyword was silently swallowed by **kwargs, so the
    # retriever fell back to its default top-k instead of 5.
    return index.as_query_engine(similarity_top_k=5)
# --- Page setup -----------------------------------------------------------
# NOTE(review): page_icon="๐" looks mis-encoded (likely an emoji originally);
# kept byte-for-byte — confirm the intended icon.
st.set_page_config(page_title="Pilot Chat", page_icon="๐", layout="wide")
st.header("Pilot Chat")

# --- Sidebar: credentials -------------------------------------------------
with st.sidebar:
    openai_api_key = st.text_input(
        "OpenAI API Key", key="file_qa_api_key", type="password"
    )
    # BUG FIX: only export the key once the user actually provided one.
    # The original unconditionally clobbered OPENAI_API_KEY with "" on
    # every rerun, breaking any key configured in the environment.
    if openai_api_key:
        os.environ["OPENAI_API_KEY"] = openai_api_key
        openai.api_key = openai_api_key

# --- Upload & indexing ----------------------------------------------------
uploaded_file = st.file_uploader(
    "Upload file",
    type=["pdf"],
    help="Only PDF files are supported",
)
if uploaded_file is not None:
    # Persist the upload under 'documents/' so SimpleDirectoryReader can read it.
    os.makedirs("documents/", exist_ok=True)
    destination_path = f"documents/{uploaded_file.name}"
    with open(destination_path, "wb") as buffer:
        shutil.copyfileobj(uploaded_file, buffer)
    st.info("File uploaded successfully.")
    # BUG FIX: indexing embeds the documents through the OpenAI API, so it
    # needs a key. The original attempted it unconditionally and crashed
    # with a runtime error when the key field was still empty.
    if not os.path.exists("vector_index") and openai_api_key:
        with st.spinner("Creating index..."):
            create_vector_index("documents/")
        st.info("Index created successfully.")

if openai_api_key == "":
    st.warning("Please enter an OpenAI API key.")

# --- Query engine ---------------------------------------------------------
query_engine = None
# Load the persisted index only once both prerequisites are satisfied.
if os.path.exists("vector_index") and openai_api_key != "":
    with st.spinner("Loading index..."):
        query_engine = load_query_engine_from_memory(persist_dir="./vector_index/")
    st.info("Index loaded successfully. You can now ask questions about the document.")

# --- Q&A ------------------------------------------------------------------
user_input = st.text_input("Enter a question about the document", key="file_qa_input")
if user_input and query_engine is not None:
    with st.spinner("Querying index..."):
        results = query_engine.query(user_input)
        response = results.response
        sources = results.get_formatted_sources(length=1500)
    st.subheader("Answer")
    st.write(response)
    st.subheader("Sources")
    st.warning("The sources are not guaranteed to be relevant.")
    st.info(sources)