Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +47 -0
- populate_database.py +110 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
import os

from fastapi import FastAPI
from langchain.prompts import ChatPromptTemplate
from langchain.vectorstores.chroma import Chroma
from langchain_community.llms import LlamaCpp
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler

from get_embedding_function import get_embedding_function
|
8 |
+
|
9 |
+
# Directory passed to Chroma as its persist_directory.
CHROMA_PATH = "chroma"

# Location of the GGUF weights handed to LlamaCpp. The default keeps the
# original hard-coded path; set the MODEL_PATH environment variable to load
# the model from somewhere else without editing this file.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "/content/drive/MyDrive/mistral-7b-instruct-v0.2.Q4_K_M.gguf",
)

# Prompt skeleton; {context} and {question} are filled in per request.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""

# Vector store and embedding function are created once at import time and
# shared by every request.
embedding_function = get_embedding_function()
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

# Stream generated tokens to stdout while the model runs.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
model = LlamaCpp(
    model_path=MODEL_PATH,
    temperature=0.75,
    max_tokens=2000,
    top_p=1,
    callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)

app = FastAPI()
|
34 |
+
|
35 |
+
@app.get("/query")
async def getAnswer(query_text: str = "What's up?"):
    """Answer a question from the chunks most similar to it in the store.

    Args:
        query_text: The question to answer. Being a plain ``str`` parameter
            of the route function, FastAPI reads it from the ``query_text``
            query-string parameter. The default is the placeholder question
            this endpoint previously hard-coded, so parameterless calls
            behave as before.

    Returns:
        The value returned by ``model.invoke`` for the assembled prompt.
    """
    # Retrieve the five nearest chunks; the scores are not used.
    results = db.similarity_search_with_score(query_text, k=5)

    # Join the chunk texts with a horizontal-rule separator to form the context.
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)

    # NOTE(review): model.invoke is called without await inside this async
    # handler, so the handler does not yield while generation runs — confirm
    # that this is acceptable for the expected request concurrency.
    return model.invoke(prompt)
|
populate_database.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
|
5 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
6 |
+
from langchain.schema.document import Document
|
7 |
+
from get_embedding_function import get_embedding_function
|
8 |
+
from langchain.vectorstores.chroma import Chroma
|
9 |
+
|
10 |
+
|
11 |
+
# Directory of the Chroma store: passed to Chroma as persist_directory and
# removed by clear_database().
CHROMA_PATH = "chroma"
# Directory handed to PyPDFDirectoryLoader as the source of input documents.
DATA_PATH = "data"
|
13 |
+
|
14 |
+
|
15 |
+
def main():
    """Command-line entry point: optionally reset, then populate the store.

    Parses ``sys.argv``. When ``--reset`` is given, the Chroma directory is
    removed first. The documents are then loaded, split into chunks and added
    to the store.
    """
    # Check if the database should be cleared (using the --reset flag).
    parser = argparse.ArgumentParser()
    parser.add_argument("--reset", action="store_true", help="Reset the database.")
    args = parser.parse_args()
    if args.reset:
        print("✨ Clearing Database")
        clear_database()

    # Create (or update) the data store.
    documents = load_documents()
    chunks = split_documents(documents)
    add_to_chroma(chunks)
|
29 |
+
|
30 |
+
|
31 |
+
def load_documents():
    """Return whatever PyPDFDirectoryLoader loads from DATA_PATH."""
    return PyPDFDirectoryLoader(DATA_PATH).load()
|
34 |
+
|
35 |
+
|
36 |
+
def split_documents(documents: list[Document]):
    """Split ``documents`` into chunks and return the splitter's result.

    Uses RecursiveCharacterTextSplitter with a chunk size of 800 and an
    overlap of 80, both measured with ``len``.
    """
    splitter_options = {
        "chunk_size": 800,
        "chunk_overlap": 80,
        "length_function": len,
        "is_separator_regex": False,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    return splitter.split_documents(documents)
|
44 |
+
|
45 |
+
|
46 |
+
def add_to_chroma(chunks: list[Document]):
    """Add to the Chroma store every chunk whose ID is not stored yet.

    Each chunk is first given an ``id`` metadata entry by
    ``calculate_chunk_ids``. Chunks whose ID is already present in the store
    are skipped; the rest are added under their IDs and the store is
    persisted.
    """
    # Open the store kept under CHROMA_PATH.
    store = Chroma(
        persist_directory=CHROMA_PATH, embedding_function=get_embedding_function()
    )

    # Tag every chunk with its ID.
    labelled_chunks = calculate_chunk_ids(chunks)

    # Collect the IDs the store already holds.
    stored = store.get(include=[])  # IDs are always included by default
    known_ids = set(stored["ids"])
    print(f"Number of existing documents in DB: {len(known_ids)}")

    # Keep only the chunks the store has not seen.
    fresh_chunks = [
        chunk for chunk in labelled_chunks if chunk.metadata["id"] not in known_ids
    ]
    if not fresh_chunks:
        print("✅ No new documents to add")
        return

    print(f"👉 Adding new documents: {len(fresh_chunks)}")
    fresh_ids = [chunk.metadata["id"] for chunk in fresh_chunks]
    store.add_documents(fresh_chunks, ids=fresh_ids)
    store.persist()
|
73 |
+
|
74 |
+
|
75 |
+
def calculate_chunk_ids(chunks):
    """Assign each chunk an ``id`` metadata entry and return ``chunks``.

    IDs look like "data/monopoly.pdf:6:2", i.e.
    Page Source : Page Number : Chunk Index, where the index counts the
    chunks seen so far for that source/page pair, starting at 0.

    The index is tracked per page rather than reset whenever the page
    changes, so chunks of one page that are not adjacent in ``chunks`` still
    get distinct IDs. For input where each page's chunks are contiguous the
    IDs are the same as a running counter would produce.

    Args:
        chunks: Iterable of objects with a mutable ``metadata`` mapping; the
            "source" and "page" entries are read with ``.get`` (a missing
            entry is rendered as "None").

    Returns:
        The same ``chunks`` object, with ``metadata["id"]`` set on every item.
    """
    # Next unused chunk index for each "source:page" key.
    next_index_by_page = {}

    for chunk in chunks:
        source = chunk.metadata.get("source")
        page = chunk.metadata.get("page")
        page_id = f"{source}:{page}"

        chunk_index = next_index_by_page.get(page_id, 0)
        next_index_by_page[page_id] = chunk_index + 1

        # Add it to the page meta-data.
        chunk.metadata["id"] = f"{page_id}:{chunk_index}"

    return chunks
|
102 |
+
|
103 |
+
|
104 |
+
def clear_database():
    """Delete the CHROMA_PATH directory tree if that path exists."""
    if not os.path.exists(CHROMA_PATH):
        return
    shutil.rmtree(CHROMA_PATH)
|
107 |
+
|
108 |
+
|
109 |
+
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pypdf
|
2 |
+
langchain
|
3 |
+
chromadb
|
4 |
+
pytest
|
5 |
+
uvicorn
|
6 |
+
python-multipart
|
7 |
+
fastapi
|
8 |
+
requests
|
9 |
+
python-dotenv
|
10 |
+
llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
|